rslearn 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +420 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +14 -20
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +377 -1
- rslearn/main.py +3 -3
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/olmoearth_pretrain/model.py +2 -5
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/METADATA +58 -25
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/RECORD +48 -49
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/WHEEL +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py
CHANGED
|
@@ -1,19 +1,73 @@
|
|
|
1
1
|
"""Classes for storing configuration of a dataset."""
|
|
2
2
|
|
|
3
|
+
import copy
|
|
4
|
+
import functools
|
|
3
5
|
import json
|
|
6
|
+
import warnings
|
|
4
7
|
from datetime import timedelta
|
|
5
|
-
from enum import
|
|
6
|
-
from typing import Any
|
|
8
|
+
from enum import StrEnum
|
|
9
|
+
from typing import TYPE_CHECKING, Annotated, Any
|
|
7
10
|
|
|
11
|
+
import jsonargparse
|
|
8
12
|
import numpy as np
|
|
9
13
|
import numpy.typing as npt
|
|
10
14
|
import pytimeparse
|
|
15
|
+
from pydantic import (
|
|
16
|
+
BaseModel,
|
|
17
|
+
BeforeValidator,
|
|
18
|
+
ConfigDict,
|
|
19
|
+
Field,
|
|
20
|
+
PlainSerializer,
|
|
21
|
+
field_validator,
|
|
22
|
+
model_validator,
|
|
23
|
+
)
|
|
11
24
|
from rasterio.enums import Resampling
|
|
25
|
+
from upath import UPath
|
|
12
26
|
|
|
27
|
+
from rslearn.log_utils import get_logger
|
|
13
28
|
from rslearn.utils import PixelBounds, Projection
|
|
29
|
+
from rslearn.utils.raster_format import RasterFormat
|
|
30
|
+
from rslearn.utils.vector_format import VectorFormat
|
|
14
31
|
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from rslearn.data_sources.data_source import DataSource
|
|
15
34
|
|
|
16
|
-
|
|
35
|
+
logger = get_logger("__name__")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def ensure_timedelta(v: Any) -> Any:
|
|
39
|
+
"""Ensure the value is a timedelta.
|
|
40
|
+
|
|
41
|
+
If the value is a string, we try to parse it with pytimeparse.
|
|
42
|
+
|
|
43
|
+
This function is meant to be used like Annotated[timedelta, BeforeValidator(ensure_timedelta)].
|
|
44
|
+
"""
|
|
45
|
+
if isinstance(v, timedelta):
|
|
46
|
+
return v
|
|
47
|
+
if isinstance(v, str):
|
|
48
|
+
return pytimeparse.parse(v)
|
|
49
|
+
raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def ensure_optional_timedelta(v: Any) -> Any:
|
|
53
|
+
"""Like ensure_timedelta, but allows None as a value."""
|
|
54
|
+
if v is None:
|
|
55
|
+
return None
|
|
56
|
+
if isinstance(v, timedelta):
|
|
57
|
+
return v
|
|
58
|
+
if isinstance(v, str):
|
|
59
|
+
return pytimeparse.parse(v)
|
|
60
|
+
raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def serialize_optional_timedelta(v: timedelta | None) -> str | None:
|
|
64
|
+
"""Serialize an optional timedelta for compatibility with pytimeparse."""
|
|
65
|
+
if v is None:
|
|
66
|
+
return None
|
|
67
|
+
return str(v.total_seconds()) + "s"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class DType(StrEnum):
|
|
17
71
|
"""Data type of a raster."""
|
|
18
72
|
|
|
19
73
|
UINT8 = "uint8"
|
|
@@ -49,61 +103,28 @@ class DType(Enum):
|
|
|
49
103
|
raise ValueError(f"unable to handle numpy dtype {self}")
|
|
50
104
|
|
|
51
105
|
|
|
52
|
-
|
|
53
|
-
"
|
|
54
|
-
"bilinear": Resampling.bilinear,
|
|
55
|
-
"cubic": Resampling.cubic,
|
|
56
|
-
"cubic_spline": Resampling.cubic_spline,
|
|
57
|
-
}
|
|
106
|
+
class ResamplingMethod(StrEnum):
|
|
107
|
+
"""An enum representing the rasterio Resampling."""
|
|
58
108
|
|
|
109
|
+
NEAREST = "nearest"
|
|
110
|
+
BILINEAR = "bilinear"
|
|
111
|
+
CUBIC = "cubic"
|
|
112
|
+
CUBIC_SPLINE = "cubic_spline"
|
|
59
113
|
|
|
60
|
-
|
|
61
|
-
|
|
114
|
+
def get_rasterio_resampling(self) -> Resampling:
|
|
115
|
+
"""Get the rasterio Resampling corresponding to this ResamplingMethod."""
|
|
116
|
+
return RESAMPLING_METHODS[self]
|
|
62
117
|
|
|
63
|
-
def __init__(self, name: str, config_dict: dict[str, Any]) -> None:
|
|
64
|
-
"""Initialize a new RasterFormatConfig.
|
|
65
118
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
@staticmethod
|
|
74
|
-
def from_config(config: dict[str, Any]) -> "RasterFormatConfig":
|
|
75
|
-
"""Create a RasterFormatConfig from config dict.
|
|
76
|
-
|
|
77
|
-
Args:
|
|
78
|
-
config: the config dict for this RasterFormatConfig
|
|
79
|
-
"""
|
|
80
|
-
return RasterFormatConfig(name=config["name"], config_dict=config)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
class VectorFormatConfig:
|
|
84
|
-
"""A configuration specifying a VectorFormat."""
|
|
85
|
-
|
|
86
|
-
def __init__(self, name: str, config_dict: dict[str, Any] = {}) -> None:
|
|
87
|
-
"""Initialize a new VectorFormatConfig.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
name: the name of the VectorFormat to use.
|
|
91
|
-
config_dict: configuration to pass to the VectorFormat.
|
|
92
|
-
"""
|
|
93
|
-
self.name = name
|
|
94
|
-
self.config_dict = config_dict
|
|
95
|
-
|
|
96
|
-
@staticmethod
|
|
97
|
-
def from_config(config: dict[str, Any]) -> "VectorFormatConfig":
|
|
98
|
-
"""Create a VectorFormatConfig from config dict.
|
|
99
|
-
|
|
100
|
-
Args:
|
|
101
|
-
config: the config dict for this VectorFormatConfig
|
|
102
|
-
"""
|
|
103
|
-
return VectorFormatConfig(name=config["name"], config_dict=config)
|
|
119
|
+
RESAMPLING_METHODS = {
|
|
120
|
+
ResamplingMethod.NEAREST: Resampling.nearest,
|
|
121
|
+
ResamplingMethod.BILINEAR: Resampling.bilinear,
|
|
122
|
+
ResamplingMethod.CUBIC: Resampling.cubic,
|
|
123
|
+
ResamplingMethod.CUBIC_SPLINE: Resampling.cubic_spline,
|
|
124
|
+
}
|
|
104
125
|
|
|
105
126
|
|
|
106
|
-
class BandSetConfig:
|
|
127
|
+
class BandSetConfig(BaseModel):
|
|
107
128
|
"""A configuration for a band set in a raster layer.
|
|
108
129
|
|
|
109
130
|
Each band set specifies one or more bands that should be stored together.
|
|
@@ -111,97 +132,67 @@ class BandSetConfig:
|
|
|
111
132
|
bands.
|
|
112
133
|
"""
|
|
113
134
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
135
|
+
dtype: DType = Field(description="Pixel value type to store the data under")
|
|
136
|
+
bands: list[str] = Field(
|
|
137
|
+
default_factory=lambda: [],
|
|
138
|
+
description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
|
|
139
|
+
)
|
|
140
|
+
num_bands: int | None = Field(
|
|
141
|
+
default=None,
|
|
142
|
+
description="The number of bands in this band set. The bands will be named B0, B1, B2, etc.",
|
|
143
|
+
)
|
|
144
|
+
format: dict[str, Any] = Field(
|
|
145
|
+
default_factory=lambda: {
|
|
146
|
+
"class_path": "rslearn.utils.raster_format.GeotiffRasterFormat"
|
|
147
|
+
},
|
|
148
|
+
description="jsonargparse configuration for the RasterFormat to store the tiles in.",
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
# Store images at a resolution higher or lower than the window resolution. This
|
|
152
|
+
# enables keeping source data at its native resolution, either to save storage
|
|
153
|
+
# space (for lower resolution data) or to retain details (for higher resolution
|
|
154
|
+
# data). If positive, store data at the window resolution divided by
|
|
155
|
+
# 2^(zoom_offset) (higher resolution). If negative, store data at the window
|
|
156
|
+
# resolution multiplied by 2^(-zoom_offset) (lower resolution).
|
|
157
|
+
zoom_offset: int = Field(
|
|
158
|
+
default=0,
|
|
159
|
+
description="Store data at the window resolution multiplied by 2^(-zoom_offset).",
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
remap: dict[str, Any] | None = Field(
|
|
163
|
+
default=None,
|
|
164
|
+
description="Optional jsonargparse configuration for a Remapper to remap pixel values.",
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
# Optional list of names for the different possible values of each band. The length
|
|
168
|
+
# of this list must equal the number of bands. For example, [["forest", "desert"]]
|
|
169
|
+
# means that it is a single-band raster where values can be 0 (forest) or 1
|
|
170
|
+
# (desert).
|
|
171
|
+
class_names: list[list[str]] | None = Field(
|
|
172
|
+
default=None,
|
|
173
|
+
description="Optional list of names for the different possible values of each band.",
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
# Optional list of nodata values for this band set. This is used during
|
|
177
|
+
# materialization when creating mosaics, to determine which parts of the source
|
|
178
|
+
# images should be copied.
|
|
179
|
+
nodata_vals: list[float] | None = Field(
|
|
180
|
+
default=None, description="Optional nodata value for each band."
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
@model_validator(mode="after")
|
|
184
|
+
def after_validator(self) -> "BandSetConfig":
|
|
185
|
+
"""Ensure the BandSetConfig is valid, and handle the num_bands field."""
|
|
186
|
+
if (len(self.bands) == 0 and self.num_bands is None) or (
|
|
187
|
+
len(self.bands) != 0 and self.num_bands is not None
|
|
154
188
|
):
|
|
155
|
-
raise ValueError("exactly one of bands and num_bands must be
|
|
156
|
-
if bands is None:
|
|
157
|
-
assert num_bands is not None
|
|
158
|
-
bands = [f"B{idx}" for idx in range(num_bands)]
|
|
159
|
-
|
|
160
|
-
if class_names is not None and len(bands) != len(class_names):
|
|
161
|
-
raise ValueError(
|
|
162
|
-
f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
|
|
163
|
-
)
|
|
189
|
+
raise ValueError("exactly one of bands and num_bands must be specified")
|
|
164
190
|
|
|
165
|
-
self.
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
self.zoom_offset = zoom_offset
|
|
169
|
-
self.remap = remap
|
|
170
|
-
self.class_names = class_names
|
|
171
|
-
self.nodata_vals = nodata_vals
|
|
191
|
+
if self.num_bands is not None:
|
|
192
|
+
self.bands = [f"B{band_idx}" for band_idx in range(self.num_bands)]
|
|
193
|
+
self.num_bands = None
|
|
172
194
|
|
|
173
|
-
|
|
174
|
-
self.format = {"name": "geotiff"}
|
|
175
|
-
else:
|
|
176
|
-
self.format = format
|
|
177
|
-
|
|
178
|
-
def serialize(self) -> dict[str, Any]:
|
|
179
|
-
"""Serialize this BandSetConfig to a config dict."""
|
|
180
|
-
return self.config_dict
|
|
181
|
-
|
|
182
|
-
@staticmethod
|
|
183
|
-
def from_config(config: dict[str, Any]) -> "BandSetConfig":
|
|
184
|
-
"""Create a BandSetConfig from config dict.
|
|
185
|
-
|
|
186
|
-
Args:
|
|
187
|
-
config: the config dict for this BandSetConfig
|
|
188
|
-
"""
|
|
189
|
-
kwargs = dict(
|
|
190
|
-
config_dict=config,
|
|
191
|
-
dtype=DType(config["dtype"]),
|
|
192
|
-
)
|
|
193
|
-
for k in [
|
|
194
|
-
"bands",
|
|
195
|
-
"num_bands",
|
|
196
|
-
"format",
|
|
197
|
-
"zoom_offset",
|
|
198
|
-
"remap",
|
|
199
|
-
"class_names",
|
|
200
|
-
"nodata_vals",
|
|
201
|
-
]:
|
|
202
|
-
if k in config:
|
|
203
|
-
kwargs[k] = config[k]
|
|
204
|
-
return BandSetConfig(**kwargs) # type: ignore
|
|
195
|
+
return self
|
|
205
196
|
|
|
206
197
|
def get_final_projection_and_bounds(
|
|
207
198
|
self, projection: Projection, bounds: PixelBounds
|
|
@@ -236,30 +227,79 @@ class BandSetConfig:
|
|
|
236
227
|
)
|
|
237
228
|
return projection, bounds
|
|
238
229
|
|
|
230
|
+
@field_validator("format", mode="before")
|
|
231
|
+
@classmethod
|
|
232
|
+
def convert_format_from_legacy(cls, v: dict[str, Any]) -> dict[str, Any]:
|
|
233
|
+
"""Support legacy format of the RasterFormat.
|
|
239
234
|
|
|
240
|
-
|
|
235
|
+
The legacy format sets 'name' instead of 'class_path', and uses custom parsing
|
|
236
|
+
for the init_args.
|
|
237
|
+
"""
|
|
238
|
+
if "name" not in v:
|
|
239
|
+
# New version, it is all good.
|
|
240
|
+
return v
|
|
241
|
+
|
|
242
|
+
warnings.warn(
|
|
243
|
+
"`format = {'name': ...}` is deprecated; "
|
|
244
|
+
"use `{'class_path': '...', 'init_args': {...}}` instead.",
|
|
245
|
+
DeprecationWarning,
|
|
246
|
+
)
|
|
247
|
+
logger.warning(
|
|
248
|
+
"BandSet.format uses legacy format; support will be removed after 2026-03-01."
|
|
249
|
+
)
|
|
250
|
+
|
|
251
|
+
legacy_name_to_class_path = {
|
|
252
|
+
"image_tile": "rslearn.utils.raster_format.ImageTileRasterFormat",
|
|
253
|
+
"geotiff": "rslearn.utils.raster_format.GeotiffRasterFormat",
|
|
254
|
+
"single_image": "rslearn.utils.raster_format.SingleImageRasterFormat",
|
|
255
|
+
}
|
|
256
|
+
if v["name"] not in legacy_name_to_class_path:
|
|
257
|
+
raise ValueError(
|
|
258
|
+
f"could not parse legacy format with unknown raster format {v['name']}"
|
|
259
|
+
)
|
|
260
|
+
init_args = dict(v)
|
|
261
|
+
class_path = legacy_name_to_class_path[init_args.pop("name")]
|
|
262
|
+
|
|
263
|
+
return dict(
|
|
264
|
+
class_path=class_path,
|
|
265
|
+
init_args=init_args,
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
def instantiate_raster_format(self) -> RasterFormat:
|
|
269
|
+
"""Instantiate the RasterFormat specified by this BandSetConfig."""
|
|
270
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
271
|
+
|
|
272
|
+
init_jsonargparse()
|
|
273
|
+
parser = jsonargparse.ArgumentParser()
|
|
274
|
+
parser.add_argument("--raster_format", type=RasterFormat)
|
|
275
|
+
cfg = parser.parse_object({"raster_format": self.format})
|
|
276
|
+
raster_format = parser.instantiate_classes(cfg).raster_format
|
|
277
|
+
return raster_format
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class SpaceMode(StrEnum):
|
|
241
281
|
"""Spatial matching mode when looking up items corresponding to a window."""
|
|
242
282
|
|
|
243
|
-
CONTAINS =
|
|
283
|
+
CONTAINS = "CONTAINS"
|
|
244
284
|
"""Items must contain the entire window."""
|
|
245
285
|
|
|
246
|
-
INTERSECTS =
|
|
286
|
+
INTERSECTS = "INTERSECTS"
|
|
247
287
|
"""Items must overlap any portion of the window."""
|
|
248
288
|
|
|
249
|
-
MOSAIC =
|
|
289
|
+
MOSAIC = "MOSAIC"
|
|
250
290
|
"""Groups of items should be computed that cover the entire window.
|
|
251
291
|
|
|
252
292
|
During materialization, items in each group are merged to form a mosaic in the
|
|
253
293
|
dataset.
|
|
254
294
|
"""
|
|
255
295
|
|
|
256
|
-
PER_PERIOD_MOSAIC =
|
|
296
|
+
PER_PERIOD_MOSAIC = "PER_PERIOD_MOSAIC"
|
|
257
297
|
"""Create one mosaic per sub-period of the time range.
|
|
258
298
|
|
|
259
299
|
The duration of the sub-periods is controlled by another option in QueryConfig.
|
|
260
300
|
"""
|
|
261
301
|
|
|
262
|
-
COMPOSITE =
|
|
302
|
+
COMPOSITE = "COMPOSITE"
|
|
263
303
|
"""Creates one composite covering the entire window.
|
|
264
304
|
|
|
265
305
|
During querying all items intersecting the window are placed in one group.
|
|
@@ -270,188 +310,219 @@ class SpaceMode(Enum):
|
|
|
270
310
|
# TODO add PER_PERIOD_COMPOSITE
|
|
271
311
|
|
|
272
312
|
|
|
273
|
-
class TimeMode(
|
|
313
|
+
class TimeMode(StrEnum):
|
|
274
314
|
"""Temporal matching mode when looking up items corresponding to a window."""
|
|
275
315
|
|
|
276
|
-
WITHIN =
|
|
316
|
+
WITHIN = "WITHIN"
|
|
277
317
|
"""Items must be within the window time range."""
|
|
278
318
|
|
|
279
|
-
NEAREST =
|
|
319
|
+
NEAREST = "NEAREST"
|
|
280
320
|
"""Select items closest to the window time range, up to max_matches."""
|
|
281
321
|
|
|
282
|
-
BEFORE =
|
|
322
|
+
BEFORE = "BEFORE"
|
|
283
323
|
"""Select items before the end of the window time range, up to max_matches."""
|
|
284
324
|
|
|
285
|
-
AFTER =
|
|
325
|
+
AFTER = "AFTER"
|
|
286
326
|
"""Select items after the start of the window time range, up to max_matches."""
|
|
287
327
|
|
|
288
328
|
|
|
289
|
-
class QueryConfig:
|
|
329
|
+
class QueryConfig(BaseModel):
|
|
290
330
|
"""A configuration for querying items in a data source."""
|
|
291
331
|
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
"
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
332
|
+
model_config = ConfigDict(frozen=True)
|
|
333
|
+
|
|
334
|
+
space_mode: SpaceMode = Field(
|
|
335
|
+
default=SpaceMode.MOSAIC,
|
|
336
|
+
description="Specifies how items should be matched with windows spatially.",
|
|
337
|
+
)
|
|
338
|
+
time_mode: TimeMode = Field(
|
|
339
|
+
default=TimeMode.WITHIN,
|
|
340
|
+
description="Specifies how items should be matched with windows temporally.",
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
# Minimum number of item groups. If there are fewer than this many matches, then no
|
|
344
|
+
# matches will be returned. This can be used to prevent unnecessary data ingestion
|
|
345
|
+
# if the user plans to discard windows that do not have a sufficient amount of data.
|
|
346
|
+
min_matches: int = Field(
|
|
347
|
+
default=0, description="The minimum number of item groups."
|
|
348
|
+
)
|
|
349
|
+
|
|
350
|
+
max_matches: int = Field(
|
|
351
|
+
default=1, description="The maximum number of item groups."
|
|
352
|
+
)
|
|
353
|
+
period_duration: Annotated[
|
|
354
|
+
timedelta,
|
|
355
|
+
BeforeValidator(ensure_timedelta),
|
|
356
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
357
|
+
] = Field(
|
|
358
|
+
default=timedelta(days=30),
|
|
359
|
+
description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
|
|
360
|
+
)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class DataSourceConfig(BaseModel):
|
|
364
|
+
"""Configuration for a DataSource in a dataset layer."""
|
|
304
365
|
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
366
|
+
model_config = ConfigDict(frozen=True)
|
|
367
|
+
|
|
368
|
+
class_path: str = Field(description="Class path for the data source.")
|
|
369
|
+
init_args: dict[str, Any] = Field(
|
|
370
|
+
default_factory=lambda: {},
|
|
371
|
+
description="jsonargparse init args for the data source.",
|
|
372
|
+
)
|
|
373
|
+
query_config: QueryConfig = Field(
|
|
374
|
+
default_factory=lambda: QueryConfig(),
|
|
375
|
+
description="QueryConfig specifying how to match items with windows.",
|
|
376
|
+
)
|
|
377
|
+
time_offset: Annotated[
|
|
378
|
+
timedelta | None,
|
|
379
|
+
BeforeValidator(ensure_optional_timedelta),
|
|
380
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
381
|
+
] = Field(
|
|
382
|
+
default=None,
|
|
383
|
+
description="Optional timedelta to add to the window's time range before matching.",
|
|
384
|
+
)
|
|
385
|
+
duration: Annotated[
|
|
386
|
+
timedelta | None,
|
|
387
|
+
BeforeValidator(ensure_optional_timedelta),
|
|
388
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
389
|
+
] = Field(
|
|
390
|
+
default=None,
|
|
391
|
+
description="Optional, if the window's time range is (t0, t1), then update to (t0, t0 + duration).",
|
|
392
|
+
)
|
|
393
|
+
ingest: bool = Field(
|
|
394
|
+
default=True,
|
|
395
|
+
description="Whether to ingest this layer (default True). If False, it will be directly materialized without ingestion.",
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
@model_validator(mode="before")
|
|
399
|
+
@classmethod
|
|
400
|
+
def convert_from_legacy(cls, d: dict[str, Any]) -> dict[str, Any]:
|
|
401
|
+
"""Support legacy format of the DataSourceConfig.
|
|
402
|
+
|
|
403
|
+
The legacy format sets 'name' instead of 'class_path', and mixes the arguments
|
|
404
|
+
for the DataSource in with the DataSourceConfig keys.
|
|
316
405
|
"""
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
"
|
|
328
|
-
|
|
329
|
-
"max_matches": self.max_matches,
|
|
330
|
-
"period_duration": f"{self.period_duration.total_seconds()}s",
|
|
331
|
-
}
|
|
332
|
-
|
|
333
|
-
@staticmethod
|
|
334
|
-
def from_config(config: dict[str, Any]) -> "QueryConfig":
|
|
335
|
-
"""Create a QueryConfig from config dict.
|
|
406
|
+
if "name" not in d:
|
|
407
|
+
# New version, it is all good.
|
|
408
|
+
return d
|
|
409
|
+
|
|
410
|
+
warnings.warn(
|
|
411
|
+
"`Data source configuration {'name': ...}` is deprecated; "
|
|
412
|
+
"use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
|
|
413
|
+
DeprecationWarning,
|
|
414
|
+
)
|
|
415
|
+
logger.warning(
|
|
416
|
+
"Data source configuration uses legacy format; support will be removed after 2026-03-01."
|
|
417
|
+
)
|
|
336
418
|
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
""
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
kwargs["time_mode"] = TimeMode[config["time_mode"]]
|
|
345
|
-
if "period_duration" in config:
|
|
346
|
-
kwargs["period_duration"] = timedelta(
|
|
347
|
-
seconds=pytimeparse.parse(config["period_duration"])
|
|
348
|
-
)
|
|
349
|
-
for k in ["min_matches", "max_matches"]:
|
|
350
|
-
if k not in config:
|
|
419
|
+
# Split the dict into the base config that is in the pydantic model, and the
|
|
420
|
+
# source-specific options that should be moved to init_args dict.
|
|
421
|
+
class_path = d["name"]
|
|
422
|
+
base_config: dict[str, Any] = {}
|
|
423
|
+
ds_init_args: dict[str, Any] = {}
|
|
424
|
+
for k, v in d.items():
|
|
425
|
+
if k == "name":
|
|
351
426
|
continue
|
|
352
|
-
|
|
353
|
-
|
|
427
|
+
if k in cls.model_fields:
|
|
428
|
+
base_config[k] = v
|
|
429
|
+
else:
|
|
430
|
+
ds_init_args[k] = v
|
|
431
|
+
|
|
432
|
+
# Some legacy configs erroneously specify these keys, which are now caught by
|
|
433
|
+
# validation. But we still want those specific legacy configs to work.
|
|
434
|
+
if (
|
|
435
|
+
class_path == "rslearn.data_sources.planetary_computer.Sentinel2"
|
|
436
|
+
and "max_cloud_cover" in ds_init_args
|
|
437
|
+
):
|
|
438
|
+
warnings.warn(
|
|
439
|
+
"Data source configuration specifies invalid 'max_cloud_cover' option.",
|
|
440
|
+
DeprecationWarning,
|
|
441
|
+
)
|
|
442
|
+
del ds_init_args["max_cloud_cover"]
|
|
354
443
|
|
|
444
|
+
base_config["class_path"] = class_path
|
|
445
|
+
base_config["init_args"] = ds_init_args
|
|
446
|
+
return base_config
|
|
355
447
|
|
|
356
|
-
class DataSourceConfig:
|
|
357
|
-
"""Configuration for a DataSource in a dataset layer."""
|
|
358
448
|
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
name: str,
|
|
362
|
-
query_config: QueryConfig,
|
|
363
|
-
config_dict: dict[str, Any],
|
|
364
|
-
time_offset: timedelta | None = None,
|
|
365
|
-
duration: timedelta | None = None,
|
|
366
|
-
ingest: bool = True,
|
|
367
|
-
) -> None:
|
|
368
|
-
"""Initializes a new DataSourceConfig.
|
|
369
|
-
|
|
370
|
-
Args:
|
|
371
|
-
name: the data source class name
|
|
372
|
-
query_config: the QueryConfig specifying how to match items with windows
|
|
373
|
-
config_dict: additional config passed to initialize the DataSource
|
|
374
|
-
time_offset: optional, add this timedelta to the window's time range before
|
|
375
|
-
matching
|
|
376
|
-
duration: optional, if window's time range is (t0, t1), then update to
|
|
377
|
-
(t0, t0 + duration)
|
|
378
|
-
ingest: whether to ingest this layer or directly materialize it
|
|
379
|
-
(default true)
|
|
380
|
-
"""
|
|
381
|
-
self.name = name
|
|
382
|
-
self.query_config = query_config
|
|
383
|
-
self.config_dict = config_dict
|
|
384
|
-
self.time_offset = time_offset
|
|
385
|
-
self.duration = duration
|
|
386
|
-
self.ingest = ingest
|
|
449
|
+
class LayerType(StrEnum):
|
|
450
|
+
"""The layer type (raster or vector)."""
|
|
387
451
|
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
return self.config_dict
|
|
452
|
+
RASTER = "raster"
|
|
453
|
+
VECTOR = "vector"
|
|
391
454
|
|
|
392
|
-
@staticmethod
|
|
393
|
-
def from_config(config: dict[str, Any]) -> "DataSourceConfig":
|
|
394
|
-
"""Create a DataSourceConfig from config dict.
|
|
395
455
|
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
"""
|
|
399
|
-
kwargs = dict(
|
|
400
|
-
name=config["name"],
|
|
401
|
-
query_config=QueryConfig.from_config(config.get("query_config", {})),
|
|
402
|
-
config_dict=config,
|
|
403
|
-
)
|
|
404
|
-
if "time_offset" in config:
|
|
405
|
-
kwargs["time_offset"] = timedelta(
|
|
406
|
-
seconds=pytimeparse.parse(config["time_offset"])
|
|
407
|
-
)
|
|
408
|
-
if "duration" in config:
|
|
409
|
-
kwargs["duration"] = timedelta(
|
|
410
|
-
seconds=pytimeparse.parse(config["duration"])
|
|
411
|
-
)
|
|
412
|
-
if "ingest" in config:
|
|
413
|
-
kwargs["ingest"] = config["ingest"]
|
|
414
|
-
return DataSourceConfig(**kwargs)
|
|
456
|
+
class CompositingMethod(StrEnum):
|
|
457
|
+
"""Method how to select pixels for the composite from corresponding items of a window."""
|
|
415
458
|
|
|
459
|
+
FIRST_VALID = "FIRST_VALID"
|
|
460
|
+
"""Select first valid pixel in order of corresponding items (might be sorted)"""
|
|
416
461
|
|
|
417
|
-
|
|
418
|
-
"""
|
|
462
|
+
MEAN = "MEAN"
|
|
463
|
+
"""Select per-pixel mean value of corresponding items of a window"""
|
|
419
464
|
|
|
420
|
-
|
|
421
|
-
|
|
465
|
+
MEDIAN = "MEDIAN"
|
|
466
|
+
"""Select per-pixel median value of corresponding items of a window"""
|
|
422
467
|
|
|
423
468
|
|
|
424
|
-
class LayerConfig:
|
|
469
|
+
class LayerConfig(BaseModel):
|
|
425
470
|
"""Configuration of a layer in a dataset."""
|
|
426
471
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
472
|
+
model_config = ConfigDict(frozen=True)
|
|
473
|
+
|
|
474
|
+
type: LayerType = Field(description="The LayerType (raster or vector).")
|
|
475
|
+
data_source: DataSourceConfig | None = Field(
|
|
476
|
+
default=None,
|
|
477
|
+
description="Optional DataSourceConfig if this layer is retrievable.",
|
|
478
|
+
)
|
|
479
|
+
alias: str | None = Field(
|
|
480
|
+
default=None, description="Alias for this layer to use in the tile store."
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
# Raster layer options.
|
|
484
|
+
band_sets: list[BandSetConfig] = Field(
|
|
485
|
+
default_factory=lambda: [],
|
|
486
|
+
description="For raster layers, the bands to store in this layer.",
|
|
487
|
+
)
|
|
488
|
+
resampling_method: ResamplingMethod = Field(
|
|
489
|
+
default=ResamplingMethod.BILINEAR,
|
|
490
|
+
description="For raster layers, how to resample rasters (if neeed), default bilinear resampling.",
|
|
491
|
+
)
|
|
492
|
+
compositing_method: CompositingMethod = Field(
|
|
493
|
+
default=CompositingMethod.FIRST_VALID,
|
|
494
|
+
description="For raster layers, how to compute pixel values in the composite of each window's items.",
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
# Vector layer options.
|
|
498
|
+
vector_format: dict[str, Any] = Field(
|
|
499
|
+
default_factory=lambda: {
|
|
500
|
+
"class_path": "rslearn.utils.vector_format.GeojsonVectorFormat"
|
|
501
|
+
},
|
|
502
|
+
description="For vector layers, the jsonargparse configuration for the VectorFormat.",
|
|
503
|
+
)
|
|
504
|
+
class_property_name: str | None = Field(
|
|
505
|
+
default=None,
|
|
506
|
+
description="Optional metadata field indicating that the GeoJSON features contain a property that corresponds to a class label, and this is the name of that property.",
|
|
507
|
+
)
|
|
508
|
+
class_names: list[str] | None = Field(
|
|
509
|
+
default=None,
|
|
510
|
+
description="The list of classes that the class_property_name property could be set to.",
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
@model_validator(mode="after")
|
|
514
|
+
def after_validator(self) -> "LayerConfig":
|
|
515
|
+
"""Ensure the LayerConfig is valid."""
|
|
516
|
+
if self.type == LayerType.RASTER and len(self.band_sets) == 0:
|
|
517
|
+
raise ValueError(
|
|
518
|
+
"band sets must be specified and non-empty for raster layers"
|
|
519
|
+
)
|
|
434
520
|
|
|
435
|
-
|
|
436
|
-
layer_type: the LayerType (raster or vector)
|
|
437
|
-
data_source: optional DataSourceConfig if this layer is retrievable
|
|
438
|
-
alias: alias for this layer to use in the tile store
|
|
439
|
-
"""
|
|
440
|
-
self.layer_type = layer_type
|
|
441
|
-
self.data_source = data_source
|
|
442
|
-
self.alias = alias
|
|
443
|
-
|
|
444
|
-
def serialize(self) -> dict[str, Any]:
|
|
445
|
-
"""Serialize this LayerConfig to a config dict."""
|
|
446
|
-
return {
|
|
447
|
-
"layer_type": str(self.layer_type),
|
|
448
|
-
"data_source": self.data_source.serialize() if self.data_source else None,
|
|
449
|
-
"alias": self.alias,
|
|
450
|
-
}
|
|
521
|
+
return self
|
|
451
522
|
|
|
452
523
|
def __hash__(self) -> int:
|
|
453
524
|
"""Return a hash of this LayerConfig."""
|
|
454
|
-
return hash(json.dumps(self.
|
|
525
|
+
return hash(json.dumps(self.model_dump(mode="json"), sort_keys=True))
|
|
455
526
|
|
|
456
527
|
def __eq__(self, other: Any) -> bool:
|
|
457
528
|
"""Returns whether other is the same as this LayerConfig.
|
|
@@ -461,142 +532,71 @@ class LayerConfig:
|
|
|
461
532
|
"""
|
|
462
533
|
if not isinstance(other, LayerConfig):
|
|
463
534
|
return False
|
|
464
|
-
return self.
|
|
465
|
-
|
|
535
|
+
return self.model_dump() == other.model_dump()
|
|
466
536
|
|
|
467
|
-
|
|
468
|
-
|
|
537
|
+
@functools.cache
|
|
538
|
+
def instantiate_data_source(self, ds_path: UPath | None = None) -> "DataSource":
|
|
539
|
+
"""Instantiate the data source specified by this config.
|
|
469
540
|
|
|
470
|
-
|
|
471
|
-
|
|
541
|
+
Args:
|
|
542
|
+
ds_path: optional dataset path to include in the DataSourceContext.
|
|
472
543
|
|
|
473
|
-
|
|
474
|
-
|
|
544
|
+
Returns:
|
|
545
|
+
the DataSource object.
|
|
546
|
+
"""
|
|
547
|
+
from rslearn.data_sources.data_source import DataSource, DataSourceContext
|
|
548
|
+
from rslearn.utils.jsonargparse import data_source_context_serializer
|
|
475
549
|
|
|
476
|
-
|
|
477
|
-
|
|
550
|
+
logger.debug("getting a data source for dataset at %s", ds_path)
|
|
551
|
+
if self.data_source is None:
|
|
552
|
+
raise ValueError("This layer does not specify a data source")
|
|
478
553
|
|
|
554
|
+
# Inject the DataSourceContext into the args.
|
|
555
|
+
context = DataSourceContext(
|
|
556
|
+
ds_path=ds_path,
|
|
557
|
+
layer_config=self,
|
|
558
|
+
)
|
|
559
|
+
ds_config: dict[str, Any] = {
|
|
560
|
+
"class_path": self.data_source.class_path,
|
|
561
|
+
"init_args": copy.deepcopy(self.data_source.init_args),
|
|
562
|
+
}
|
|
563
|
+
ds_config["init_args"]["context"] = data_source_context_serializer(context)
|
|
479
564
|
|
|
480
|
-
|
|
481
|
-
|
|
565
|
+
# Now we can parse with jsonargparse.
|
|
566
|
+
from rslearn.utils.jsonargparse import (
|
|
567
|
+
data_source_context_serializer,
|
|
568
|
+
init_jsonargparse,
|
|
569
|
+
)
|
|
482
570
|
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
data_source
|
|
488
|
-
|
|
489
|
-
alias: str | None = None,
|
|
490
|
-
compositing_method: CompositingMethod = CompositingMethod.FIRST_VALID,
|
|
491
|
-
):
|
|
492
|
-
"""Initialize a new RasterLayerConfig.
|
|
571
|
+
init_jsonargparse()
|
|
572
|
+
parser = jsonargparse.ArgumentParser()
|
|
573
|
+
parser.add_argument("--data_source", type=DataSource)
|
|
574
|
+
cfg = parser.parse_object({"data_source": ds_config})
|
|
575
|
+
data_source = parser.instantiate_classes(cfg).data_source
|
|
576
|
+
return data_source
|
|
493
577
|
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
compositing_method: how to compute pixel values in the composite of each windows items
|
|
501
|
-
"""
|
|
502
|
-
super().__init__(layer_type, data_source, alias)
|
|
503
|
-
self.band_sets = band_sets
|
|
504
|
-
self.resampling_method = resampling_method
|
|
505
|
-
self.compositing_method = compositing_method
|
|
578
|
+
def instantiate_vector_format(self) -> VectorFormat:
|
|
579
|
+
"""Instantiate the vector format specified by this config."""
|
|
580
|
+
if self.type != LayerType.VECTOR:
|
|
581
|
+
raise ValueError(
|
|
582
|
+
f"cannot instantiate vector format for layer with type {self.type}"
|
|
583
|
+
)
|
|
506
584
|
|
|
507
|
-
|
|
508
|
-
def from_config(config: dict[str, Any]) -> "RasterLayerConfig":
|
|
509
|
-
"""Create a RasterLayerConfig from config dict.
|
|
585
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
510
586
|
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
""
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
}
|
|
518
|
-
if "data_source" in config:
|
|
519
|
-
kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
|
|
520
|
-
if "resampling_method" in config:
|
|
521
|
-
kwargs["resampling_method"] = RESAMPLING_METHODS[
|
|
522
|
-
config["resampling_method"]
|
|
523
|
-
]
|
|
524
|
-
if "alias" in config:
|
|
525
|
-
kwargs["alias"] = config["alias"]
|
|
526
|
-
if "compositing_method" in config:
|
|
527
|
-
kwargs["compositing_method"] = CompositingMethod[
|
|
528
|
-
config["compositing_method"]
|
|
529
|
-
]
|
|
530
|
-
return RasterLayerConfig(**kwargs) # type: ignore
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
class VectorLayerConfig(LayerConfig):
|
|
534
|
-
"""Configuration of a vector layer."""
|
|
535
|
-
|
|
536
|
-
def __init__(
|
|
537
|
-
self,
|
|
538
|
-
layer_type: LayerType,
|
|
539
|
-
data_source: DataSourceConfig | None = None,
|
|
540
|
-
format: VectorFormatConfig = VectorFormatConfig("geojson"),
|
|
541
|
-
alias: str | None = None,
|
|
542
|
-
class_property_name: str | None = None,
|
|
543
|
-
class_names: list[str] | None = None,
|
|
544
|
-
):
|
|
545
|
-
"""Initialize a new VectorLayerConfig.
|
|
587
|
+
init_jsonargparse()
|
|
588
|
+
parser = jsonargparse.ArgumentParser()
|
|
589
|
+
parser.add_argument("--vector_format", type=VectorFormat)
|
|
590
|
+
cfg = parser.parse_object({"vector_format": self.vector_format})
|
|
591
|
+
vector_format = parser.instantiate_classes(cfg).vector_format
|
|
592
|
+
return vector_format
|
|
546
593
|
|
|
547
|
-
Args:
|
|
548
|
-
layer_type: the LayerType (must be vector)
|
|
549
|
-
data_source: optional DataSourceConfig if this layer is retrievable
|
|
550
|
-
format: the VectorFormatConfig, default storing as GeoJSON
|
|
551
|
-
alias: alias for this layer to use in the tile store
|
|
552
|
-
class_property_name: optional metadata field indicating that the GeoJSON
|
|
553
|
-
features contain a property that corresponds to a class label, and this
|
|
554
|
-
is the name of that property.
|
|
555
|
-
class_names: the list of classes that the class_property_name property
|
|
556
|
-
could be set to.
|
|
557
|
-
"""
|
|
558
|
-
super().__init__(layer_type, data_source, alias)
|
|
559
|
-
self.format = format
|
|
560
|
-
self.class_property_name = class_property_name
|
|
561
|
-
self.class_names = class_names
|
|
562
594
|
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
"""Create a VectorLayerConfig from config dict.
|
|
595
|
+
class DatasetConfig(BaseModel):
|
|
596
|
+
"""Overall dataset configuration."""
|
|
566
597
|
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
"""
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
|
|
573
|
-
if "format" in config:
|
|
574
|
-
kwargs["format"] = VectorFormatConfig.from_config(config["format"])
|
|
575
|
-
|
|
576
|
-
simple_optionals = [
|
|
577
|
-
"alias",
|
|
578
|
-
"class_property_name",
|
|
579
|
-
"class_names",
|
|
580
|
-
]
|
|
581
|
-
for k in simple_optionals:
|
|
582
|
-
if k in config:
|
|
583
|
-
kwargs[k] = config[k]
|
|
584
|
-
|
|
585
|
-
# The "zoom_offset" option was removed.
|
|
586
|
-
# We should change how we create configuration so we can error on all
|
|
587
|
-
# non-existing config options, but for now we make sure to raise error if
|
|
588
|
-
# zoom_offset is set since it is no longer supported.
|
|
589
|
-
if "zoom_offset" in config:
|
|
590
|
-
raise ValueError("unsupported zoom_offset option in vector layer config")
|
|
591
|
-
|
|
592
|
-
return VectorLayerConfig(**kwargs) # type: ignore
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
def load_layer_config(config: dict[str, Any]) -> LayerConfig:
|
|
596
|
-
"""Load a LayerConfig from a config dict."""
|
|
597
|
-
layer_type = LayerType(config.get("type"))
|
|
598
|
-
if layer_type == LayerType.RASTER:
|
|
599
|
-
return RasterLayerConfig.from_config(config)
|
|
600
|
-
elif layer_type == LayerType.VECTOR:
|
|
601
|
-
return VectorLayerConfig.from_config(config)
|
|
602
|
-
raise ValueError(f"Unknown layer type {layer_type}")
|
|
598
|
+
layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")
|
|
599
|
+
tile_store: dict[str, Any] = Field(
|
|
600
|
+
default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
|
|
601
|
+
description="jsonargparse configuration for the TileStore.",
|
|
602
|
+
)
|