rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py
CHANGED
|
@@ -1,25 +1,84 @@
|
|
|
1
1
|
"""Classes for storing configuration of a dataset."""
|
|
2
2
|
|
|
3
|
+
import copy
|
|
4
|
+
import functools
|
|
5
|
+
import json
|
|
6
|
+
import warnings
|
|
3
7
|
from datetime import timedelta
|
|
4
|
-
from enum import
|
|
5
|
-
from typing import Any
|
|
8
|
+
from enum import StrEnum
|
|
9
|
+
from typing import TYPE_CHECKING, Annotated, Any
|
|
6
10
|
|
|
11
|
+
import jsonargparse
|
|
7
12
|
import numpy as np
|
|
8
13
|
import numpy.typing as npt
|
|
9
14
|
import pytimeparse
|
|
10
|
-
import
|
|
15
|
+
from pydantic import (
|
|
16
|
+
BaseModel,
|
|
17
|
+
BeforeValidator,
|
|
18
|
+
ConfigDict,
|
|
19
|
+
Field,
|
|
20
|
+
PlainSerializer,
|
|
21
|
+
field_validator,
|
|
22
|
+
model_validator,
|
|
23
|
+
)
|
|
11
24
|
from rasterio.enums import Resampling
|
|
25
|
+
from upath import UPath
|
|
12
26
|
|
|
13
|
-
from rslearn.
|
|
27
|
+
from rslearn.log_utils import get_logger
|
|
28
|
+
from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
|
|
29
|
+
from rslearn.utils.raster_format import RasterFormat
|
|
30
|
+
from rslearn.utils.vector_format import VectorFormat
|
|
14
31
|
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from rslearn.data_sources.data_source import DataSource
|
|
34
|
+
from rslearn.dataset.storage.storage import WindowStorageFactory
|
|
15
35
|
|
|
16
|
-
|
|
36
|
+
# Name the logger after this module (the dunder variable, not the literal string
# "__name__") so that its records can be attributed and filtered by module path.
logger = get_logger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def ensure_timedelta(v: Any) -> Any:
    """Ensure the value is a timedelta.

    If the value is a string, we try to parse it with pytimeparse.

    This function is meant to be used like Annotated[timedelta, BeforeValidator(ensure_timedelta)].

    Args:
        v: the raw value, either a timedelta or a duration string such as "1d".

    Returns:
        the corresponding timedelta.

    Raises:
        ValueError: if v is a string that pytimeparse cannot parse.
        TypeError: if v is neither a timedelta nor a string.
    """
    if isinstance(v, timedelta):
        return v
    if isinstance(v, str):
        # pytimeparse.parse returns the duration as a number of seconds, or None
        # when the string is not a recognized time expression. Convert explicitly
        # so that the function really returns a timedelta, and fail loudly on
        # unparsable input instead of passing None along.
        seconds = pytimeparse.parse(v)
        if seconds is None:
            raise ValueError(f"Could not parse timedelta from string: {v!r}")
        return timedelta(seconds=seconds)
    raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def ensure_optional_timedelta(v: Any) -> Any:
    """Like ensure_timedelta, but allows None as a value.

    Args:
        v: the raw value: None, a timedelta, or a duration string such as "1d".

    Returns:
        None if v is None, otherwise the corresponding timedelta.

    Raises:
        ValueError: if v is a string that pytimeparse cannot parse.
        TypeError: if v is not None, a timedelta, or a string.
    """
    if v is None:
        return None
    if isinstance(v, timedelta):
        return v
    if isinstance(v, str):
        # pytimeparse.parse returns None for strings it does not understand. For an
        # optional field None is also a legal value, so returning it unchecked would
        # silently turn a typo into "option not set". Reject it explicitly instead.
        seconds = pytimeparse.parse(v)
        if seconds is None:
            raise ValueError(f"Could not parse timedelta from string: {v!r}")
        return timedelta(seconds=seconds)
    raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def serialize_optional_timedelta(v: timedelta | None) -> str | None:
|
|
65
|
+
"""Serialize an optional timedelta for compatibility with pytimeparse."""
|
|
66
|
+
if v is None:
|
|
67
|
+
return None
|
|
68
|
+
return str(v.total_seconds()) + "s"
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class DType(StrEnum):
|
|
17
72
|
"""Data type of a raster."""
|
|
18
73
|
|
|
19
74
|
UINT8 = "uint8"
|
|
20
75
|
UINT16 = "uint16"
|
|
21
76
|
UINT32 = "uint32"
|
|
77
|
+
UINT64 = "uint64"
|
|
78
|
+
INT8 = "int8"
|
|
79
|
+
INT16 = "int16"
|
|
22
80
|
INT32 = "int32"
|
|
81
|
+
INT64 = "int64"
|
|
23
82
|
FLOAT32 = "float32"
|
|
24
83
|
|
|
25
84
|
def get_numpy_dtype(self) -> npt.DTypeLike:
|
|
@@ -30,77 +89,43 @@ class DType(Enum):
|
|
|
30
89
|
return np.uint16
|
|
31
90
|
elif self == DType.UINT32:
|
|
32
91
|
return np.uint32
|
|
92
|
+
elif self == DType.UINT64:
|
|
93
|
+
return np.uint64
|
|
94
|
+
elif self == DType.INT8:
|
|
95
|
+
return np.int8
|
|
96
|
+
elif self == DType.INT16:
|
|
97
|
+
return np.int16
|
|
33
98
|
elif self == DType.INT32:
|
|
34
99
|
return np.int32
|
|
100
|
+
elif self == DType.INT64:
|
|
101
|
+
return np.int64
|
|
35
102
|
elif self == DType.FLOAT32:
|
|
36
103
|
return np.float32
|
|
37
104
|
raise ValueError(f"unable to handle numpy dtype {self}")
|
|
38
105
|
|
|
39
|
-
def get_torch_dtype(self) -> torch.dtype:
|
|
40
|
-
"""Returns pytorch dtype object corresponding to this DType."""
|
|
41
|
-
if self == DType.INT32:
|
|
42
|
-
return torch.int32
|
|
43
|
-
elif self == DType.FLOAT32:
|
|
44
|
-
return torch.float32
|
|
45
|
-
else:
|
|
46
|
-
raise ValueError(f"unable to handle torch dtype {self}")
|
|
47
106
|
|
|
107
|
+
class ResamplingMethod(StrEnum):
|
|
108
|
+
"""An enum representing the rasterio Resampling."""
|
|
48
109
|
|
|
49
|
-
|
|
50
|
-
"
|
|
51
|
-
"
|
|
52
|
-
"
|
|
53
|
-
"cubic_spline": Resampling.cubic_spline,
|
|
54
|
-
}
|
|
110
|
+
NEAREST = "nearest"
|
|
111
|
+
BILINEAR = "bilinear"
|
|
112
|
+
CUBIC = "cubic"
|
|
113
|
+
CUBIC_SPLINE = "cubic_spline"
|
|
55
114
|
|
|
115
|
+
def get_rasterio_resampling(self) -> Resampling:
|
|
116
|
+
"""Get the rasterio Resampling corresponding to this ResamplingMethod."""
|
|
117
|
+
return RESAMPLING_METHODS[self]
|
|
56
118
|
|
|
57
|
-
class RasterFormatConfig:
|
|
58
|
-
"""A configuration specifying a RasterFormat."""
|
|
59
|
-
|
|
60
|
-
def __init__(self, name: str, config_dict: dict[str, Any]) -> None:
|
|
61
|
-
"""Initialize a new RasterFormatConfig.
|
|
62
|
-
|
|
63
|
-
Args:
|
|
64
|
-
name: the name of the RasterFormat to use.
|
|
65
|
-
config_dict: configuration to pass to the RasterFormat.
|
|
66
|
-
"""
|
|
67
|
-
self.name = name
|
|
68
|
-
self.config_dict = config_dict
|
|
69
|
-
|
|
70
|
-
@staticmethod
|
|
71
|
-
def from_config(config: dict[str, Any]) -> "RasterFormatConfig":
|
|
72
|
-
"""Create a RasterFormatConfig from config dict.
|
|
73
|
-
|
|
74
|
-
Args:
|
|
75
|
-
config: the config dict for this RasterFormatConfig
|
|
76
|
-
"""
|
|
77
|
-
return RasterFormatConfig(name=config["name"], config_dict=config)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
class VectorFormatConfig:
|
|
81
|
-
"""A configuration specifying a VectorFormat."""
|
|
82
|
-
|
|
83
|
-
def __init__(self, name: str, config_dict: dict[str, Any] = {}) -> None:
|
|
84
|
-
"""Initialize a new VectorFormatConfig.
|
|
85
|
-
|
|
86
|
-
Args:
|
|
87
|
-
name: the name of the VectorFormat to use.
|
|
88
|
-
config_dict: configuration to pass to the VectorFormat.
|
|
89
|
-
"""
|
|
90
|
-
self.name = name
|
|
91
|
-
self.config_dict = config_dict
|
|
92
|
-
|
|
93
|
-
@staticmethod
|
|
94
|
-
def from_config(config: dict[str, Any]) -> "VectorFormatConfig":
|
|
95
|
-
"""Create a VectorFormatConfig from config dict.
|
|
96
119
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
120
|
+
RESAMPLING_METHODS = {
|
|
121
|
+
ResamplingMethod.NEAREST: Resampling.nearest,
|
|
122
|
+
ResamplingMethod.BILINEAR: Resampling.bilinear,
|
|
123
|
+
ResamplingMethod.CUBIC: Resampling.cubic,
|
|
124
|
+
ResamplingMethod.CUBIC_SPLINE: Resampling.cubic_spline,
|
|
125
|
+
}
|
|
101
126
|
|
|
102
127
|
|
|
103
|
-
class BandSetConfig:
|
|
128
|
+
class BandSetConfig(BaseModel):
|
|
104
129
|
"""A configuration for a band set in a raster layer.
|
|
105
130
|
|
|
106
131
|
Each band set specifies one or more bands that should be stored together.
|
|
@@ -108,65 +133,75 @@ class BandSetConfig:
|
|
|
108
133
|
bands.
|
|
109
134
|
"""
|
|
110
135
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
)
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
"
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
136
|
+
model_config = ConfigDict(extra="forbid")
|
|
137
|
+
|
|
138
|
+
dtype: DType = Field(
|
|
139
|
+
description="Pixel value type to store the data under. This is used during dataset materialize and model predict."
|
|
140
|
+
)
|
|
141
|
+
bands: list[str] = Field(
|
|
142
|
+
default_factory=lambda: [],
|
|
143
|
+
description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
|
|
144
|
+
)
|
|
145
|
+
num_bands: int | None = Field(
|
|
146
|
+
default=None,
|
|
147
|
+
description="The number of bands in this band set. The bands will be named B0, B1, B2, etc.",
|
|
148
|
+
)
|
|
149
|
+
format: dict[str, Any] = Field(
|
|
150
|
+
default_factory=lambda: {
|
|
151
|
+
"class_path": "rslearn.utils.raster_format.GeotiffRasterFormat"
|
|
152
|
+
},
|
|
153
|
+
description="jsonargparse configuration for the RasterFormat to store the tiles in.",
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
# Store images at a resolution higher or lower than the window resolution. This
|
|
157
|
+
# enables keeping source data at its native resolution, either to save storage
|
|
158
|
+
# space (for lower resolution data) or to retain details (for higher resolution
|
|
159
|
+
# data). If positive, store data at the window resolution divided by
|
|
160
|
+
# 2^(zoom_offset) (higher resolution). If negative, store data at the window
|
|
161
|
+
# resolution multiplied by 2^(-zoom_offset) (lower resolution).
|
|
162
|
+
zoom_offset: int = Field(
|
|
163
|
+
default=0,
|
|
164
|
+
description="Store data at the window resolution multiplied by 2^(-zoom_offset).",
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
remap: dict[str, Any] | None = Field(
|
|
168
|
+
default=None,
|
|
169
|
+
description="Optional jsonargparse configuration for a Remapper to remap pixel values.",
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
# Optional list of names for the different possible values of each band. The length
|
|
173
|
+
# of this list must equal the number of bands. For example, [["forest", "desert"]]
|
|
174
|
+
# means that it is a single-band raster where values can be 0 (forest) or 1
|
|
175
|
+
# (desert).
|
|
176
|
+
class_names: list[list[str]] | None = Field(
|
|
177
|
+
default=None,
|
|
178
|
+
description="Optional list of names for the different possible values of each band.",
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
# Optional list of nodata values for this band set. This is used during
|
|
182
|
+
# materialization when creating mosaics, to determine which parts of the source
|
|
183
|
+
# images should be copied.
|
|
184
|
+
nodata_vals: list[float] | None = Field(
|
|
185
|
+
default=None, description="Optional nodata value for each band."
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
@model_validator(mode="after")
def after_validator(self) -> "BandSetConfig":
    """Ensure the BandSetConfig is valid, and handle the num_bands field.

    Exactly one of bands and num_bands may be provided. When num_bands is given,
    it is expanded into the band names B0, B1, ... and then cleared, so that after
    validation only bands is populated.

    Returns:
        this BandSetConfig, with bands populated and num_bands reset to None.

    Raises:
        ValueError: if neither or both of bands and num_bands are set, or if
            num_bands is less than one.
    """
    has_bands = len(self.bands) != 0
    has_num_bands = self.num_bands is not None
    if has_bands == has_num_bands:
        raise ValueError("exactly one of bands and num_bands must be specified")

    if self.num_bands is not None:
        # range() over a non-positive count is empty, which would leave the band
        # set without any bands even though the check above passed.
        if self.num_bands < 1:
            raise ValueError(f"num_bands must be at least 1, got {self.num_bands}")
        self.bands = [f"B{band_idx}" for band_idx in range(self.num_bands)]
        self.num_bands = None

    return self
|
|
166
201
|
|
|
167
202
|
def get_final_projection_and_bounds(
|
|
168
|
-
self, projection: Projection, bounds: PixelBounds
|
|
169
|
-
) -> tuple[Projection, PixelBounds
|
|
203
|
+
self, projection: Projection, bounds: PixelBounds
|
|
204
|
+
) -> tuple[Projection, PixelBounds]:
|
|
170
205
|
"""Gets the final projection/bounds based on band set config.
|
|
171
206
|
|
|
172
207
|
The band set config may apply a non-zero zoom offset that modifies the window's
|
|
@@ -180,348 +215,432 @@ class BandSetConfig:
|
|
|
180
215
|
Returns:
|
|
181
216
|
tuple of updated projection and bounds with zoom offset applied
|
|
182
217
|
"""
|
|
183
|
-
if self.zoom_offset
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
218
|
+
if self.zoom_offset >= 0:
|
|
219
|
+
factor = ResolutionFactor(numerator=2**self.zoom_offset)
|
|
220
|
+
else:
|
|
221
|
+
factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
|
|
222
|
+
|
|
223
|
+
return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
|
|
224
|
+
|
|
225
|
+
@field_validator("format", mode="before")
@classmethod
def convert_format_from_legacy(cls, v: dict[str, Any]) -> dict[str, Any]:
    """Translate a legacy RasterFormat config into the jsonargparse layout.

    Legacy configs identify the format with a 'name' key and keep its options
    alongside it in the same dict. They are rewritten to a 'class_path' plus
    'init_args' dict. Configs without a 'name' key are returned untouched.

    Args:
        v: the raw value of the format field.

    Returns:
        the config in class_path/init_args form.

    Raises:
        ValueError: if the legacy name is not one of the known raster formats.
    """
    if "name" in v:
        # Announce the deprecation both through the warnings machinery and the
        # module logger before attempting the conversion.
        warnings.warn(
            "`format = {'name': ...}` is deprecated; "
            "use `{'class_path': '...', 'init_args': {...}}` instead.",
            DeprecationWarning,
        )
        logger.warning(
            "BandSet.format uses legacy format; support will be removed after 2026-03-01."
        )

        known_formats = {
            "image_tile": "rslearn.utils.raster_format.ImageTileRasterFormat",
            "geotiff": "rslearn.utils.raster_format.GeotiffRasterFormat",
            "single_image": "rslearn.utils.raster_format.SingleImageRasterFormat",
        }
        legacy_name = v["name"]
        class_path = known_formats.get(legacy_name)
        if class_path is None:
            raise ValueError(
                f"could not parse legacy format with unknown raster format {legacy_name}"
            )

        # Every key other than 'name' is an option for the raster format itself.
        remaining = {key: value for key, value in v.items() if key != "name"}
        return {"class_path": class_path, "init_args": remaining}

    # No 'name' key: already in the new layout.
    return v
|
|
190
|
-
if bounds:
|
|
191
|
-
if self.zoom_offset > 0:
|
|
192
|
-
bounds = tuple(x * (2**self.zoom_offset) for x in bounds)
|
|
193
|
-
else:
|
|
194
|
-
bounds = tuple(x // (2 ** (-self.zoom_offset)) for x in bounds)
|
|
195
|
-
return projection, bounds
|
|
196
262
|
|
|
263
|
+
def instantiate_raster_format(self) -> RasterFormat:
|
|
264
|
+
"""Instantiate the RasterFormat specified by this BandSetConfig."""
|
|
265
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
197
266
|
|
|
198
|
-
|
|
267
|
+
init_jsonargparse()
|
|
268
|
+
parser = jsonargparse.ArgumentParser()
|
|
269
|
+
parser.add_argument("--raster_format", type=RasterFormat)
|
|
270
|
+
cfg = parser.parse_object({"raster_format": self.format})
|
|
271
|
+
raster_format = parser.instantiate_classes(cfg).raster_format
|
|
272
|
+
return raster_format
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class SpaceMode(StrEnum):
|
|
199
276
|
"""Spatial matching mode when looking up items corresponding to a window."""
|
|
200
277
|
|
|
201
|
-
CONTAINS =
|
|
278
|
+
CONTAINS = "CONTAINS"
|
|
202
279
|
"""Items must contain the entire window."""
|
|
203
280
|
|
|
204
|
-
INTERSECTS =
|
|
281
|
+
INTERSECTS = "INTERSECTS"
|
|
205
282
|
"""Items must overlap any portion of the window."""
|
|
206
283
|
|
|
207
|
-
MOSAIC =
|
|
284
|
+
MOSAIC = "MOSAIC"
|
|
208
285
|
"""Groups of items should be computed that cover the entire window.
|
|
209
286
|
|
|
210
287
|
During materialization, items in each group are merged to form a mosaic in the
|
|
211
288
|
dataset.
|
|
212
289
|
"""
|
|
213
290
|
|
|
291
|
+
PER_PERIOD_MOSAIC = "PER_PERIOD_MOSAIC"
|
|
292
|
+
"""Create one mosaic per sub-period of the time range.
|
|
214
293
|
|
|
215
|
-
|
|
216
|
-
"""
|
|
217
|
-
|
|
218
|
-
WITHIN = 1
|
|
219
|
-
"""Items must be within the window time range."""
|
|
294
|
+
The duration of the sub-periods is controlled by another option in QueryConfig.
|
|
295
|
+
"""
|
|
220
296
|
|
|
221
|
-
|
|
222
|
-
"""
|
|
297
|
+
COMPOSITE = "COMPOSITE"
|
|
298
|
+
"""Creates one composite covering the entire window.
|
|
223
299
|
|
|
224
|
-
|
|
225
|
-
|
|
300
|
+
During querying all items intersecting the window are placed in one group.
|
|
301
|
+
The compositing_method in the rasterlayer config specifies how these items are reduced
|
|
302
|
+
to a single item (e.g MEAN/MEDIAN/FIRST_VALID) during materialization.
|
|
303
|
+
"""
|
|
226
304
|
|
|
227
|
-
|
|
228
|
-
"""Select items after the end of the window time range, up to max_matches."""
|
|
305
|
+
# TODO add PER_PERIOD_COMPOSITE
|
|
229
306
|
|
|
230
307
|
|
|
231
|
-
class
|
|
232
|
-
"""
|
|
308
|
+
class TimeMode(StrEnum):
|
|
309
|
+
"""Temporal matching mode when looking up items corresponding to a window."""
|
|
233
310
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
space_mode: SpaceMode = SpaceMode.MOSAIC,
|
|
237
|
-
time_mode: TimeMode = TimeMode.WITHIN,
|
|
238
|
-
max_matches: int = 1,
|
|
239
|
-
):
|
|
240
|
-
"""Creates a new query configuration.
|
|
311
|
+
WITHIN = "WITHIN"
|
|
312
|
+
"""Items must be within the window time range."""
|
|
241
313
|
|
|
242
|
-
|
|
243
|
-
|
|
314
|
+
NEAREST = "NEAREST"
|
|
315
|
+
"""Select items closest to the window time range, up to max_matches."""
|
|
244
316
|
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
time_mode: specifies how items should be matched with windows temporally
|
|
248
|
-
max_matches: the maximum number of items (or groups of items, if space_mode
|
|
249
|
-
is MOSAIC) to match
|
|
250
|
-
"""
|
|
251
|
-
self.space_mode = space_mode
|
|
252
|
-
self.time_mode = time_mode
|
|
253
|
-
self.max_matches = max_matches
|
|
254
|
-
|
|
255
|
-
def serialize(self) -> dict[str, Any]:
|
|
256
|
-
"""Serialize this QueryConfig to a config dict, currently unused."""
|
|
257
|
-
return {
|
|
258
|
-
"space_mode": str(self.space_mode),
|
|
259
|
-
"time_mode": str(self.time_mode),
|
|
260
|
-
"max_matches": self.max_matches,
|
|
261
|
-
}
|
|
317
|
+
BEFORE = "BEFORE"
|
|
318
|
+
"""Select items before the end of the window time range, up to max_matches."""
|
|
262
319
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
"""Create a QueryConfig from config dict.
|
|
320
|
+
AFTER = "AFTER"
|
|
321
|
+
"""Select items after the start of the window time range, up to max_matches."""
|
|
266
322
|
|
|
267
|
-
Args:
|
|
268
|
-
config: the config dict for this QueryConfig
|
|
269
|
-
"""
|
|
270
|
-
return QueryConfig(
|
|
271
|
-
space_mode=SpaceMode[config.get("space_mode", "MOSAIC")],
|
|
272
|
-
time_mode=TimeMode[config.get("time_mode", "WITHIN")],
|
|
273
|
-
max_matches=config.get("max_matches", 1),
|
|
274
|
-
)
|
|
275
323
|
|
|
324
|
+
class QueryConfig(BaseModel):
|
|
325
|
+
"""A configuration for querying items in a data source."""
|
|
276
326
|
|
|
277
|
-
|
|
327
|
+
model_config = ConfigDict(frozen=True, extra="forbid")
|
|
328
|
+
|
|
329
|
+
space_mode: SpaceMode = Field(
|
|
330
|
+
default=SpaceMode.MOSAIC,
|
|
331
|
+
description="Specifies how items should be matched with windows spatially.",
|
|
332
|
+
)
|
|
333
|
+
time_mode: TimeMode = Field(
|
|
334
|
+
default=TimeMode.WITHIN,
|
|
335
|
+
description="Specifies how items should be matched with windows temporally.",
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
# Minimum number of item groups. If there are fewer than this many matches, then no
|
|
339
|
+
# matches will be returned. This can be used to prevent unnecessary data ingestion
|
|
340
|
+
# if the user plans to discard windows that do not have a sufficient amount of data.
|
|
341
|
+
min_matches: int = Field(
|
|
342
|
+
default=0, description="The minimum number of item groups."
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
max_matches: int = Field(
|
|
346
|
+
default=1, description="The maximum number of item groups."
|
|
347
|
+
)
|
|
348
|
+
period_duration: Annotated[
|
|
349
|
+
timedelta,
|
|
350
|
+
BeforeValidator(ensure_timedelta),
|
|
351
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
352
|
+
] = Field(
|
|
353
|
+
default=timedelta(days=30),
|
|
354
|
+
description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class DataSourceConfig(BaseModel):
|
|
278
359
|
"""Configuration for a DataSource in a dataset layer."""
|
|
279
360
|
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
"
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
361
|
+
model_config = ConfigDict(frozen=True, extra="forbid")
|
|
362
|
+
|
|
363
|
+
class_path: str = Field(description="Class path for the data source.")
|
|
364
|
+
init_args: dict[str, Any] = Field(
|
|
365
|
+
default_factory=lambda: {},
|
|
366
|
+
description="jsonargparse init args for the data source.",
|
|
367
|
+
)
|
|
368
|
+
query_config: QueryConfig = Field(
|
|
369
|
+
default_factory=lambda: QueryConfig(),
|
|
370
|
+
description="QueryConfig specifying how to match items with windows.",
|
|
371
|
+
)
|
|
372
|
+
time_offset: Annotated[
|
|
373
|
+
timedelta | None,
|
|
374
|
+
BeforeValidator(ensure_optional_timedelta),
|
|
375
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
376
|
+
] = Field(
|
|
377
|
+
default=None,
|
|
378
|
+
description="Optional timedelta to add to the window's time range before matching.",
|
|
379
|
+
)
|
|
380
|
+
duration: Annotated[
|
|
381
|
+
timedelta | None,
|
|
382
|
+
BeforeValidator(ensure_optional_timedelta),
|
|
383
|
+
PlainSerializer(serialize_optional_timedelta),
|
|
384
|
+
] = Field(
|
|
385
|
+
default=None,
|
|
386
|
+
description="Optional, if the window's time range is (t0, t1), then update to (t0, t0 + duration).",
|
|
387
|
+
)
|
|
388
|
+
ingest: bool = Field(
|
|
389
|
+
default=True,
|
|
390
|
+
description="Whether to ingest this layer (default True). If False, it will be directly materialized without ingestion.",
|
|
391
|
+
)
|
|
392
|
+
|
|
393
|
+
@model_validator(mode="before")
|
|
394
|
+
@classmethod
|
|
395
|
+
def convert_from_legacy(cls, d: dict[str, Any]) -> dict[str, Any]:
|
|
396
|
+
"""Support legacy format of the DataSourceConfig.
|
|
397
|
+
|
|
398
|
+
The legacy format sets 'name' instead of 'class_path', and mixes the arguments
|
|
399
|
+
for the DataSource in with the DataSourceConfig keys.
|
|
301
400
|
"""
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
"""Serialize this DataSourceConfig to a config dict, currently unused."""
|
|
311
|
-
config_dict = self.config_dict.copy()
|
|
312
|
-
config_dict["name"] = self.name
|
|
313
|
-
config_dict["query_config"] = self.query_config.serialize()
|
|
314
|
-
config_dict["ingest"] = self.ingest
|
|
315
|
-
if self.time_offset:
|
|
316
|
-
config_dict["time_offset"] = str(self.time_offset)
|
|
317
|
-
if self.duration:
|
|
318
|
-
config_dict["duration"] = str(self.duration)
|
|
319
|
-
return config_dict
|
|
320
|
-
|
|
321
|
-
@staticmethod
|
|
322
|
-
def from_config(config: dict[str, Any]) -> "DataSourceConfig":
|
|
323
|
-
"""Create a DataSourceConfig from config dict.
|
|
324
|
-
|
|
325
|
-
Args:
|
|
326
|
-
config: the config dict for this DataSourceConfig
|
|
327
|
-
"""
|
|
328
|
-
kwargs = dict(
|
|
329
|
-
name=config["name"],
|
|
330
|
-
query_config=QueryConfig.from_config(config.get("query_config", {})),
|
|
331
|
-
config_dict=config,
|
|
401
|
+
if "name" not in d:
|
|
402
|
+
# New version, it is all good.
|
|
403
|
+
return d
|
|
404
|
+
|
|
405
|
+
warnings.warn(
|
|
406
|
+
"`Data source configuration {'name': ...}` is deprecated; "
|
|
407
|
+
"use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
|
|
408
|
+
DeprecationWarning,
|
|
332
409
|
)
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
410
|
+
logger.warning(
|
|
411
|
+
"Data source configuration uses legacy format; support will be removed after 2026-03-01."
|
|
412
|
+
)
|
|
413
|
+
|
|
414
|
+
# Split the dict into the base config that is in the pydantic model, and the
|
|
415
|
+
# source-specific options that should be moved to init_args dict.
|
|
416
|
+
class_path = d["name"]
|
|
417
|
+
base_config: dict[str, Any] = {}
|
|
418
|
+
ds_init_args: dict[str, Any] = {}
|
|
419
|
+
for k, v in d.items():
|
|
420
|
+
if k == "name":
|
|
421
|
+
continue
|
|
422
|
+
if k in cls.model_fields:
|
|
423
|
+
base_config[k] = v
|
|
424
|
+
else:
|
|
425
|
+
ds_init_args[k] = v
|
|
426
|
+
|
|
427
|
+
# Some legacy configs erroneously specify these keys, which are now caught by
|
|
428
|
+
# validation. But we still want those specific legacy configs to work.
|
|
429
|
+
if (
|
|
430
|
+
class_path == "rslearn.data_sources.planetary_computer.Sentinel2"
|
|
431
|
+
and "max_cloud_cover" in ds_init_args
|
|
432
|
+
):
|
|
433
|
+
warnings.warn(
|
|
434
|
+
"Data source configuration specifies invalid 'max_cloud_cover' option.",
|
|
435
|
+
DeprecationWarning,
|
|
340
436
|
)
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
437
|
+
del ds_init_args["max_cloud_cover"]
|
|
438
|
+
|
|
439
|
+
base_config["class_path"] = class_path
|
|
440
|
+
base_config["init_args"] = ds_init_args
|
|
441
|
+
return base_config
|
|
344
442
|
|
|
345
443
|
|
|
346
|
-
class LayerType(
|
|
444
|
+
class LayerType(StrEnum):
|
|
347
445
|
"""The layer type (raster or vector)."""
|
|
348
446
|
|
|
349
447
|
RASTER = "raster"
|
|
350
448
|
VECTOR = "vector"
|
|
351
449
|
|
|
352
450
|
|
|
353
|
-
class
|
|
354
|
-
"""
|
|
355
|
-
|
|
356
|
-
def __init__(
|
|
357
|
-
self,
|
|
358
|
-
layer_type: LayerType,
|
|
359
|
-
data_source: DataSourceConfig | None = None,
|
|
360
|
-
alias: str | None = None,
|
|
361
|
-
):
|
|
362
|
-
"""Initialize a new LayerConfig.
|
|
451
|
+
class CompositingMethod(StrEnum):
|
|
452
|
+
"""Method how to select pixels for the composite from corresponding items of a window."""
|
|
363
453
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
data_source: optional DataSourceConfig if this layer is retrievable
|
|
367
|
-
alias: alias for this layer to use in the tile store
|
|
368
|
-
"""
|
|
369
|
-
self.layer_type = layer_type
|
|
370
|
-
self.data_source = data_source
|
|
371
|
-
self.alias = alias
|
|
372
|
-
|
|
373
|
-
def serialize(self) -> dict[str, Any]:
|
|
374
|
-
"""Serialize this LayerConfig to a config dict, currently unused."""
|
|
375
|
-
return {
|
|
376
|
-
"layer_type": str(self.layer_type),
|
|
377
|
-
"data_source": self.data_source,
|
|
378
|
-
"alias": self.alias,
|
|
379
|
-
}
|
|
454
|
+
FIRST_VALID = "FIRST_VALID"
|
|
455
|
+
"""Select first valid pixel in order of corresponding items (might be sorted)"""
|
|
380
456
|
|
|
457
|
+
MEAN = "MEAN"
|
|
458
|
+
"""Select per-pixel mean value of corresponding items of a window"""
|
|
381
459
|
|
|
382
|
-
|
|
383
|
-
"""
|
|
460
|
+
MEDIAN = "MEDIAN"
|
|
461
|
+
"""Select per-pixel median value of corresponding items of a window"""
|
|
384
462
|
|
|
385
|
-
def __init__(
|
|
386
|
-
self,
|
|
387
|
-
layer_type: LayerType,
|
|
388
|
-
band_sets: list[BandSetConfig],
|
|
389
|
-
data_source: DataSourceConfig | None = None,
|
|
390
|
-
resampling_method: Resampling = Resampling.bilinear,
|
|
391
|
-
alias: str | None = None,
|
|
392
|
-
):
|
|
393
|
-
"""Initialize a new RasterLayerConfig.
|
|
394
463
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
band_sets: the bands to store in this layer
|
|
398
|
-
data_source: optional DataSourceConfig if this layer is retrievable
|
|
399
|
-
resampling_method: how to resample rasters (if needed), default bilinear resampling
|
|
400
|
-
alias: alias for this layer to use in the tile store
|
|
401
|
-
"""
|
|
402
|
-
super().__init__(layer_type, data_source, alias)
|
|
403
|
-
self.band_sets = band_sets
|
|
404
|
-
self.resampling_method = resampling_method
|
|
464
|
+
class LayerConfig(BaseModel):
|
|
465
|
+
"""Configuration of a layer in a dataset."""
|
|
405
466
|
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
467
|
+
model_config = ConfigDict(frozen=True, extra="forbid")
|
|
468
|
+
|
|
469
|
+
type: LayerType = Field(description="The LayerType (raster or vector).")
|
|
470
|
+
data_source: DataSourceConfig | None = Field(
|
|
471
|
+
default=None,
|
|
472
|
+
description="Optional DataSourceConfig if this layer is retrievable.",
|
|
473
|
+
)
|
|
474
|
+
alias: str | None = Field(
|
|
475
|
+
default=None, description="Alias for this layer to use in the tile store."
|
|
476
|
+
)
|
|
477
|
+
|
|
478
|
+
# Raster layer options.
|
|
479
|
+
band_sets: list[BandSetConfig] = Field(
|
|
480
|
+
default_factory=lambda: [],
|
|
481
|
+
description="For raster layers, the bands to store in this layer.",
|
|
482
|
+
)
|
|
483
|
+
resampling_method: ResamplingMethod = Field(
|
|
484
|
+
default=ResamplingMethod.BILINEAR,
|
|
485
|
+
description="For raster layers, how to resample rasters (if neeed), default bilinear resampling.",
|
|
486
|
+
)
|
|
487
|
+
compositing_method: CompositingMethod = Field(
|
|
488
|
+
default=CompositingMethod.FIRST_VALID,
|
|
489
|
+
description="For raster layers, how to compute pixel values in the composite of each window's items.",
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
# Vector layer options.
|
|
493
|
+
vector_format: dict[str, Any] = Field(
|
|
494
|
+
default_factory=lambda: {
|
|
495
|
+
"class_path": "rslearn.utils.vector_format.GeojsonVectorFormat"
|
|
496
|
+
},
|
|
497
|
+
description="For vector layers, the jsonargparse configuration for the VectorFormat.",
|
|
498
|
+
)
|
|
499
|
+
class_property_name: str | None = Field(
|
|
500
|
+
default=None,
|
|
501
|
+
description="Optional metadata field indicating that the GeoJSON features contain a property that corresponds to a class label, and this is the name of that property.",
|
|
502
|
+
)
|
|
503
|
+
class_names: list[str] | None = Field(
|
|
504
|
+
default=None,
|
|
505
|
+
description="The list of classes that the class_property_name property could be set to.",
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
@model_validator(mode="after")
|
|
509
|
+
def after_validator(self) -> "LayerConfig":
|
|
510
|
+
"""Ensure the LayerConfig is valid."""
|
|
511
|
+
if self.type == LayerType.RASTER and len(self.band_sets) == 0:
|
|
512
|
+
raise ValueError(
|
|
513
|
+
"band sets must be specified and non-empty for raster layers"
|
|
514
|
+
)
|
|
409
515
|
|
|
410
|
-
|
|
411
|
-
config: the config dict for this RasterLayerConfig
|
|
412
|
-
"""
|
|
413
|
-
kwargs = {
|
|
414
|
-
"layer_type": LayerType(config["type"]),
|
|
415
|
-
"band_sets": [BandSetConfig.from_config(el) for el in config["band_sets"]],
|
|
416
|
-
}
|
|
417
|
-
if "data_source" in config:
|
|
418
|
-
kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
|
|
419
|
-
if "resampling_method" in config:
|
|
420
|
-
kwargs["resampling_method"] = RESAMPLING_METHODS[
|
|
421
|
-
config["resampling_method"]
|
|
422
|
-
]
|
|
423
|
-
if "alias" in config:
|
|
424
|
-
kwargs["alias"] = config["alias"]
|
|
425
|
-
return RasterLayerConfig(**kwargs)
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
class VectorLayerConfig(LayerConfig):
|
|
429
|
-
"""Configuration of a vector layer."""
|
|
430
|
-
|
|
431
|
-
def __init__(
|
|
432
|
-
self,
|
|
433
|
-
layer_type: LayerType,
|
|
434
|
-
data_source: DataSourceConfig | None = None,
|
|
435
|
-
zoom_offset: int = 0,
|
|
436
|
-
format: VectorFormatConfig = VectorFormatConfig("geojson"),
|
|
437
|
-
alias: str | None = None,
|
|
438
|
-
):
|
|
439
|
-
"""Initialize a new VectorLayerConfig.
|
|
516
|
+
return self
|
|
440
517
|
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
zoom_offset: zoom offset at which to store the vector data
|
|
445
|
-
format: the VectorFormatConfig, default storing as GeoJSON
|
|
446
|
-
alias: alias for this layer to use in the tile store
|
|
447
|
-
"""
|
|
448
|
-
super().__init__(layer_type, data_source, alias)
|
|
449
|
-
self.zoom_offset = zoom_offset
|
|
450
|
-
self.format = format
|
|
518
|
+
def __hash__(self) -> int:
|
|
519
|
+
"""Return a hash of this LayerConfig."""
|
|
520
|
+
return hash(json.dumps(self.model_dump(mode="json"), sort_keys=True))
|
|
451
521
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
"""Create a VectorLayerConfig from config dict.
|
|
522
|
+
def __eq__(self, other: Any) -> bool:
|
|
523
|
+
"""Returns whether other is the same as this LayerConfig.
|
|
455
524
|
|
|
456
525
|
Args:
|
|
457
|
-
|
|
526
|
+
other: the other object to compare.
|
|
458
527
|
"""
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
if "zoom_offset" in config:
|
|
463
|
-
kwargs["zoom_offset"] = config["zoom_offset"]
|
|
464
|
-
if "format" in config:
|
|
465
|
-
kwargs["format"] = VectorFormatConfig.from_config(config["format"])
|
|
466
|
-
if "alias" in config:
|
|
467
|
-
kwargs["alias"] = config["alias"]
|
|
468
|
-
return VectorLayerConfig(**kwargs)
|
|
528
|
+
if not isinstance(other, LayerConfig):
|
|
529
|
+
return False
|
|
530
|
+
return self.model_dump() == other.model_dump()
|
|
469
531
|
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
"""Gets the final projection/bounds based on zoom offset.
|
|
532
|
+
@functools.cache
|
|
533
|
+
def instantiate_data_source(self, ds_path: UPath | None = None) -> "DataSource":
|
|
534
|
+
"""Instantiate the data source specified by this config.
|
|
474
535
|
|
|
475
536
|
Args:
|
|
476
|
-
|
|
477
|
-
bounds: the window's bounds (optional)
|
|
537
|
+
ds_path: optional dataset path to include in the DataSourceContext.
|
|
478
538
|
|
|
479
539
|
Returns:
|
|
480
|
-
|
|
540
|
+
the DataSource object.
|
|
481
541
|
"""
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
projection = Projection(
|
|
485
|
-
projection.crs,
|
|
486
|
-
projection.x_resolution / (2**self.zoom_offset),
|
|
487
|
-
projection.y_resolution / (2**self.zoom_offset),
|
|
488
|
-
)
|
|
489
|
-
if bounds:
|
|
490
|
-
if self.zoom_offset > 0:
|
|
491
|
-
bounds = tuple(x * (2**self.zoom_offset) for x in bounds)
|
|
492
|
-
else:
|
|
493
|
-
bounds = tuple(x // (2 ** (-self.zoom_offset)) for x in bounds)
|
|
494
|
-
return projection, bounds
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
def load_layer_config(config: dict[str, Any]) -> LayerConfig:
|
|
498
|
-
"""Load a LayerConfig from a config dict."""
|
|
499
|
-
layer_type = LayerType(config.get("type"))
|
|
500
|
-
if layer_type == LayerType.RASTER:
|
|
501
|
-
return RasterLayerConfig.from_config(config)
|
|
502
|
-
elif layer_type == LayerType.VECTOR:
|
|
503
|
-
return VectorLayerConfig.from_config(config)
|
|
504
|
-
raise ValueError(f"Unknown layer type {layer_type}")
|
|
542
|
+
from rslearn.data_sources.data_source import DataSource, DataSourceContext
|
|
543
|
+
from rslearn.utils.jsonargparse import data_source_context_serializer
|
|
505
544
|
|
|
545
|
+
logger.debug("getting a data source for dataset at %s", ds_path)
|
|
546
|
+
if self.data_source is None:
|
|
547
|
+
raise ValueError("This layer does not specify a data source")
|
|
506
548
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
549
|
+
# Inject the DataSourceContext into the args.
|
|
550
|
+
context = DataSourceContext(
|
|
551
|
+
ds_path=ds_path,
|
|
552
|
+
layer_config=self,
|
|
553
|
+
)
|
|
554
|
+
ds_config: dict[str, Any] = {
|
|
555
|
+
"class_path": self.data_source.class_path,
|
|
556
|
+
"init_args": copy.deepcopy(self.data_source.init_args),
|
|
557
|
+
}
|
|
558
|
+
ds_config["init_args"]["context"] = data_source_context_serializer(context)
|
|
512
559
|
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
self.config_dict = config_dict
|
|
560
|
+
# Now we can parse with jsonargparse.
|
|
561
|
+
from rslearn.utils.jsonargparse import (
|
|
562
|
+
data_source_context_serializer,
|
|
563
|
+
init_jsonargparse,
|
|
564
|
+
)
|
|
519
565
|
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
""
|
|
566
|
+
init_jsonargparse()
|
|
567
|
+
parser = jsonargparse.ArgumentParser()
|
|
568
|
+
parser.add_argument("--data_source", type=DataSource)
|
|
569
|
+
cfg = parser.parse_object({"data_source": ds_config})
|
|
570
|
+
data_source = parser.instantiate_classes(cfg).data_source
|
|
571
|
+
return data_source
|
|
572
|
+
|
|
573
|
+
def instantiate_vector_format(self) -> VectorFormat:
|
|
574
|
+
"""Instantiate the vector format specified by this config."""
|
|
575
|
+
if self.type != LayerType.VECTOR:
|
|
576
|
+
raise ValueError(
|
|
577
|
+
f"cannot instantiate vector format for layer with type {self.type}"
|
|
578
|
+
)
|
|
523
579
|
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
580
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
581
|
+
|
|
582
|
+
init_jsonargparse()
|
|
583
|
+
parser = jsonargparse.ArgumentParser()
|
|
584
|
+
parser.add_argument("--vector_format", type=VectorFormat)
|
|
585
|
+
cfg = parser.parse_object({"vector_format": self.vector_format})
|
|
586
|
+
vector_format = parser.instantiate_classes(cfg).vector_format
|
|
587
|
+
return vector_format
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
class StorageConfig(BaseModel):
|
|
591
|
+
"""Configuration for the WindowStorageFactory (window metadata storage backend)."""
|
|
592
|
+
|
|
593
|
+
model_config = ConfigDict(frozen=True, extra="forbid")
|
|
594
|
+
|
|
595
|
+
class_path: str = Field(
|
|
596
|
+
default="rslearn.dataset.storage.file.FileWindowStorageFactory",
|
|
597
|
+
description="Class path for the WindowStorageFactory.",
|
|
598
|
+
)
|
|
599
|
+
init_args: dict[str, Any] = Field(
|
|
600
|
+
default_factory=lambda: {},
|
|
601
|
+
description="jsonargparse init args for the WindowStorageFactory.",
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
def instantiate_window_storage_factory(self) -> "WindowStorageFactory":
|
|
605
|
+
"""Instantiate the WindowStorageFactory specified by this config."""
|
|
606
|
+
from rslearn.dataset.storage.storage import WindowStorageFactory
|
|
607
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
608
|
+
|
|
609
|
+
init_jsonargparse()
|
|
610
|
+
parser = jsonargparse.ArgumentParser()
|
|
611
|
+
parser.add_argument("--wsf", type=WindowStorageFactory)
|
|
612
|
+
cfg = parser.parse_object(
|
|
613
|
+
{
|
|
614
|
+
"wsf": dict(
|
|
615
|
+
class_path=self.class_path,
|
|
616
|
+
init_args=self.init_args,
|
|
617
|
+
)
|
|
618
|
+
}
|
|
619
|
+
)
|
|
620
|
+
wsf = parser.instantiate_classes(cfg).wsf
|
|
621
|
+
return wsf
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
class DatasetConfig(BaseModel):
|
|
625
|
+
"""Overall dataset configuration."""
|
|
626
|
+
|
|
627
|
+
model_config = ConfigDict(extra="forbid")
|
|
628
|
+
|
|
629
|
+
layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")
|
|
630
|
+
tile_store: dict[str, Any] = Field(
|
|
631
|
+
default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
|
|
632
|
+
description="jsonargparse configuration for the TileStore.",
|
|
633
|
+
)
|
|
634
|
+
storage: StorageConfig = Field(
|
|
635
|
+
default_factory=lambda: StorageConfig(),
|
|
636
|
+
description="jsonargparse configuration for the WindowStorageFactory.",
|
|
637
|
+
)
|
|
638
|
+
|
|
639
|
+
@field_validator("layers", mode="after")
|
|
640
|
+
@classmethod
|
|
641
|
+
def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
|
|
642
|
+
"""Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
|
|
643
|
+
for layer_name in v.keys():
|
|
644
|
+
if "." in layer_name:
|
|
645
|
+
raise ValueError(f"layer names must not contain periods: {layer_name}")
|
|
646
|
+
return v
|