rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py CHANGED
@@ -1,25 +1,84 @@
1
1
  """Classes for storing configuration of a dataset."""
2
2
 
3
+ import copy
4
+ import functools
5
+ import json
6
+ import warnings
3
7
  from datetime import timedelta
4
- from enum import Enum
5
- from typing import Any
8
+ from enum import StrEnum
9
+ from typing import TYPE_CHECKING, Annotated, Any
6
10
 
11
+ import jsonargparse
7
12
  import numpy as np
8
13
  import numpy.typing as npt
9
14
  import pytimeparse
10
- import torch
15
+ from pydantic import (
16
+ BaseModel,
17
+ BeforeValidator,
18
+ ConfigDict,
19
+ Field,
20
+ PlainSerializer,
21
+ field_validator,
22
+ model_validator,
23
+ )
11
24
  from rasterio.enums import Resampling
25
+ from upath import UPath
12
26
 
13
- from rslearn.utils import PixelBounds, Projection
27
+ from rslearn.log_utils import get_logger
28
+ from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
29
+ from rslearn.utils.raster_format import RasterFormat
30
+ from rslearn.utils.vector_format import VectorFormat
14
31
 
32
+ if TYPE_CHECKING:
33
+ from rslearn.data_sources.data_source import DataSource
34
+ from rslearn.dataset.storage.storage import WindowStorageFactory
15
35
 
16
- class DType(Enum):
36
+ logger = get_logger("__name__")
37
+
38
+
39
+ def ensure_timedelta(v: Any) -> Any:
40
+ """Ensure the value is a timedelta.
41
+
42
+ If the value is a string, we try to parse it with pytimeparse.
43
+
44
+ This function is meant to be used like Annotated[timedelta, BeforeValidator(ensure_timedelta)].
45
+ """
46
+ if isinstance(v, timedelta):
47
+ return v
48
+ if isinstance(v, str):
49
+ return pytimeparse.parse(v)
50
+ raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
51
+
52
+
53
+ def ensure_optional_timedelta(v: Any) -> Any:
54
+ """Like ensure_timedelta, but allows None as a value."""
55
+ if v is None:
56
+ return None
57
+ if isinstance(v, timedelta):
58
+ return v
59
+ if isinstance(v, str):
60
+ return pytimeparse.parse(v)
61
+ raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
62
+
63
+
64
+ def serialize_optional_timedelta(v: timedelta | None) -> str | None:
65
+ """Serialize an optional timedelta for compatibility with pytimeparse."""
66
+ if v is None:
67
+ return None
68
+ return str(v.total_seconds()) + "s"
69
+
70
+
71
+ class DType(StrEnum):
17
72
  """Data type of a raster."""
18
73
 
19
74
  UINT8 = "uint8"
20
75
  UINT16 = "uint16"
21
76
  UINT32 = "uint32"
77
+ UINT64 = "uint64"
78
+ INT8 = "int8"
79
+ INT16 = "int16"
22
80
  INT32 = "int32"
81
+ INT64 = "int64"
23
82
  FLOAT32 = "float32"
24
83
 
25
84
  def get_numpy_dtype(self) -> npt.DTypeLike:
@@ -30,77 +89,43 @@ class DType(Enum):
30
89
  return np.uint16
31
90
  elif self == DType.UINT32:
32
91
  return np.uint32
92
+ elif self == DType.UINT64:
93
+ return np.uint64
94
+ elif self == DType.INT8:
95
+ return np.int8
96
+ elif self == DType.INT16:
97
+ return np.int16
33
98
  elif self == DType.INT32:
34
99
  return np.int32
100
+ elif self == DType.INT64:
101
+ return np.int64
35
102
  elif self == DType.FLOAT32:
36
103
  return np.float32
37
104
  raise ValueError(f"unable to handle numpy dtype {self}")
38
105
 
39
- def get_torch_dtype(self) -> torch.dtype:
40
- """Returns pytorch dtype object corresponding to this DType."""
41
- if self == DType.INT32:
42
- return torch.int32
43
- elif self == DType.FLOAT32:
44
- return torch.float32
45
- else:
46
- raise ValueError(f"unable to handle torch dtype {self}")
47
106
 
107
+ class ResamplingMethod(StrEnum):
108
+ """An enum representing the rasterio Resampling."""
48
109
 
49
- RESAMPLING_METHODS = {
50
- "nearest": Resampling.nearest,
51
- "bilinear": Resampling.bilinear,
52
- "cubic": Resampling.cubic,
53
- "cubic_spline": Resampling.cubic_spline,
54
- }
110
+ NEAREST = "nearest"
111
+ BILINEAR = "bilinear"
112
+ CUBIC = "cubic"
113
+ CUBIC_SPLINE = "cubic_spline"
55
114
 
115
+ def get_rasterio_resampling(self) -> Resampling:
116
+ """Get the rasterio Resampling corresponding to this ResamplingMethod."""
117
+ return RESAMPLING_METHODS[self]
56
118
 
57
- class RasterFormatConfig:
58
- """A configuration specifying a RasterFormat."""
59
-
60
- def __init__(self, name: str, config_dict: dict[str, Any]) -> None:
61
- """Initialize a new RasterFormatConfig.
62
-
63
- Args:
64
- name: the name of the RasterFormat to use.
65
- config_dict: configuration to pass to the RasterFormat.
66
- """
67
- self.name = name
68
- self.config_dict = config_dict
69
-
70
- @staticmethod
71
- def from_config(config: dict[str, Any]) -> "RasterFormatConfig":
72
- """Create a RasterFormatConfig from config dict.
73
-
74
- Args:
75
- config: the config dict for this RasterFormatConfig
76
- """
77
- return RasterFormatConfig(name=config["name"], config_dict=config)
78
-
79
-
80
- class VectorFormatConfig:
81
- """A configuration specifying a VectorFormat."""
82
-
83
- def __init__(self, name: str, config_dict: dict[str, Any] = {}) -> None:
84
- """Initialize a new VectorFormatConfig.
85
-
86
- Args:
87
- name: the name of the VectorFormat to use.
88
- config_dict: configuration to pass to the VectorFormat.
89
- """
90
- self.name = name
91
- self.config_dict = config_dict
92
-
93
- @staticmethod
94
- def from_config(config: dict[str, Any]) -> "VectorFormatConfig":
95
- """Create a VectorFormatConfig from config dict.
96
119
 
97
- Args:
98
- config: the config dict for this VectorFormatConfig
99
- """
100
- return VectorFormatConfig(name=config["name"], config_dict=config)
120
+ RESAMPLING_METHODS = {
121
+ ResamplingMethod.NEAREST: Resampling.nearest,
122
+ ResamplingMethod.BILINEAR: Resampling.bilinear,
123
+ ResamplingMethod.CUBIC: Resampling.cubic,
124
+ ResamplingMethod.CUBIC_SPLINE: Resampling.cubic_spline,
125
+ }
101
126
 
102
127
 
103
- class BandSetConfig:
128
+ class BandSetConfig(BaseModel):
104
129
  """A configuration for a band set in a raster layer.
105
130
 
106
131
  Each band set specifies one or more bands that should be stored together.
@@ -108,65 +133,75 @@ class BandSetConfig:
108
133
  bands.
109
134
  """
110
135
 
111
- def __init__(
112
- self,
113
- config_dict: dict[str, Any],
114
- dtype: DType,
115
- bands: list[str] | None = None,
116
- format: dict[str, Any] | None = None,
117
- zoom_offset: int = 0,
118
- remap: dict[str, Any] | None = None,
119
- ) -> None:
120
- """Creates a new BandSetConfig instance.
121
-
122
- Args:
123
- config_dict: the config dict used to configure this BandSetConfig
124
- dtype: the pixel value type to store tiles in
125
- bands: list of band names in this BandSetConfig
126
- format: the format to store tiles in, defaults to geotiff
127
- zoom_offset: non-negative integer, store images at window resolution
128
- divided by 2^(zoom_offset).
129
- remap: config dict for Remapper to remap pixel values
130
- """
131
- self.config_dict = config_dict
132
- self.bands = bands
133
- self.format = format
134
- self.dtype = dtype
135
- self.zoom_offset = zoom_offset
136
- self.remap = remap
137
-
138
- if not self.format:
139
- self.format = {"name": "geotiff"}
140
-
141
- def serialize(self) -> dict[str, Any]:
142
- """Serialize this BandSetConfig to a config dict, currently unused."""
143
- return {
144
- "bands": self.bands,
145
- "format": self.format,
146
- "dtype": self.dtype,
147
- "zoom_offset": self.zoom_offset,
148
- "remap": self.remap,
149
- }
150
-
151
- @staticmethod
152
- def from_config(config: dict[str, Any]) -> "BandSetConfig":
153
- """Create a BandSetConfig from config dict.
154
-
155
- Args:
156
- config: the config dict for this BandSetConfig
157
- """
158
- kwargs = dict(
159
- config_dict=config,
160
- dtype=DType(config["dtype"]),
161
- )
162
- for k in ["bands", "format", "zoom_offset", "remap"]:
163
- if k in config:
164
- kwargs[k] = config[k]
165
- return BandSetConfig(**kwargs)
136
+ model_config = ConfigDict(extra="forbid")
137
+
138
+ dtype: DType = Field(
139
+ description="Pixel value type to store the data under. This is used during dataset materialize and model predict."
140
+ )
141
+ bands: list[str] = Field(
142
+ default_factory=lambda: [],
143
+ description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
144
+ )
145
+ num_bands: int | None = Field(
146
+ default=None,
147
+ description="The number of bands in this band set. The bands will be named B0, B1, B2, etc.",
148
+ )
149
+ format: dict[str, Any] = Field(
150
+ default_factory=lambda: {
151
+ "class_path": "rslearn.utils.raster_format.GeotiffRasterFormat"
152
+ },
153
+ description="jsonargparse configuration for the RasterFormat to store the tiles in.",
154
+ )
155
+
156
+ # Store images at a resolution higher or lower than the window resolution. This
157
+ # enables keeping source data at its native resolution, either to save storage
158
+ # space (for lower resolution data) or to retain details (for higher resolution
159
+ # data). If positive, store data at the window resolution divided by
160
+ # 2^(zoom_offset) (higher resolution). If negative, store data at the window
161
+ # resolution multiplied by 2^(-zoom_offset) (lower resolution).
162
+ zoom_offset: int = Field(
163
+ default=0,
164
+ description="Store data at the window resolution multiplied by 2^(-zoom_offset).",
165
+ )
166
+
167
+ remap: dict[str, Any] | None = Field(
168
+ default=None,
169
+ description="Optional jsonargparse configuration for a Remapper to remap pixel values.",
170
+ )
171
+
172
+ # Optional list of names for the different possible values of each band. The length
173
+ # of this list must equal the number of bands. For example, [["forest", "desert"]]
174
+ # means that it is a single-band raster where values can be 0 (forest) or 1
175
+ # (desert).
176
+ class_names: list[list[str]] | None = Field(
177
+ default=None,
178
+ description="Optional list of names for the different possible values of each band.",
179
+ )
180
+
181
+ # Optional list of nodata values for this band set. This is used during
182
+ # materialization when creating mosaics, to determine which parts of the source
183
+ # images should be copied.
184
+ nodata_vals: list[float] | None = Field(
185
+ default=None, description="Optional nodata value for each band."
186
+ )
187
+
188
+ @model_validator(mode="after")
189
+ def after_validator(self) -> "BandSetConfig":
190
+ """Ensure the BandSetConfig is valid, and handle the num_bands field."""
191
+ if (len(self.bands) == 0 and self.num_bands is None) or (
192
+ len(self.bands) != 0 and self.num_bands is not None
193
+ ):
194
+ raise ValueError("exactly one of bands and num_bands must be specified")
195
+
196
+ if self.num_bands is not None:
197
+ self.bands = [f"B{band_idx}" for band_idx in range(self.num_bands)]
198
+ self.num_bands = None
199
+
200
+ return self
166
201
 
167
202
  def get_final_projection_and_bounds(
168
- self, projection: Projection, bounds: PixelBounds | None
169
- ) -> tuple[Projection, PixelBounds | None]:
203
+ self, projection: Projection, bounds: PixelBounds
204
+ ) -> tuple[Projection, PixelBounds]:
170
205
  """Gets the final projection/bounds based on band set config.
171
206
 
172
207
  The band set config may apply a non-zero zoom offset that modifies the window's
@@ -180,348 +215,432 @@ class BandSetConfig:
180
215
  Returns:
181
216
  tuple of updated projection and bounds with zoom offset applied
182
217
  """
183
- if self.zoom_offset == 0:
184
- return projection, bounds
185
- projection = Projection(
186
- projection.crs,
187
- projection.x_resolution / (2**self.zoom_offset),
188
- projection.y_resolution / (2**self.zoom_offset),
218
+ if self.zoom_offset >= 0:
219
+ factor = ResolutionFactor(numerator=2**self.zoom_offset)
220
+ else:
221
+ factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
222
+
223
+ return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
224
+
225
+ @field_validator("format", mode="before")
226
+ @classmethod
227
+ def convert_format_from_legacy(cls, v: dict[str, Any]) -> dict[str, Any]:
228
+ """Support legacy format of the RasterFormat.
229
+
230
+ The legacy format sets 'name' instead of 'class_path', and uses custom parsing
231
+ for the init_args.
232
+ """
233
+ if "name" not in v:
234
+ # New version, it is all good.
235
+ return v
236
+
237
+ warnings.warn(
238
+ "`format = {'name': ...}` is deprecated; "
239
+ "use `{'class_path': '...', 'init_args': {...}}` instead.",
240
+ DeprecationWarning,
241
+ )
242
+ logger.warning(
243
+ "BandSet.format uses legacy format; support will be removed after 2026-03-01."
244
+ )
245
+
246
+ legacy_name_to_class_path = {
247
+ "image_tile": "rslearn.utils.raster_format.ImageTileRasterFormat",
248
+ "geotiff": "rslearn.utils.raster_format.GeotiffRasterFormat",
249
+ "single_image": "rslearn.utils.raster_format.SingleImageRasterFormat",
250
+ }
251
+ if v["name"] not in legacy_name_to_class_path:
252
+ raise ValueError(
253
+ f"could not parse legacy format with unknown raster format {v['name']}"
254
+ )
255
+ init_args = dict(v)
256
+ class_path = legacy_name_to_class_path[init_args.pop("name")]
257
+
258
+ return dict(
259
+ class_path=class_path,
260
+ init_args=init_args,
189
261
  )
190
- if bounds:
191
- if self.zoom_offset > 0:
192
- bounds = tuple(x * (2**self.zoom_offset) for x in bounds)
193
- else:
194
- bounds = tuple(x // (2 ** (-self.zoom_offset)) for x in bounds)
195
- return projection, bounds
196
262
 
263
+ def instantiate_raster_format(self) -> RasterFormat:
264
+ """Instantiate the RasterFormat specified by this BandSetConfig."""
265
+ from rslearn.utils.jsonargparse import init_jsonargparse
197
266
 
198
- class SpaceMode(Enum):
267
+ init_jsonargparse()
268
+ parser = jsonargparse.ArgumentParser()
269
+ parser.add_argument("--raster_format", type=RasterFormat)
270
+ cfg = parser.parse_object({"raster_format": self.format})
271
+ raster_format = parser.instantiate_classes(cfg).raster_format
272
+ return raster_format
273
+
274
+
275
+ class SpaceMode(StrEnum):
199
276
  """Spatial matching mode when looking up items corresponding to a window."""
200
277
 
201
- CONTAINS = 1
278
+ CONTAINS = "CONTAINS"
202
279
  """Items must contain the entire window."""
203
280
 
204
- INTERSECTS = 2
281
+ INTERSECTS = "INTERSECTS"
205
282
  """Items must overlap any portion of the window."""
206
283
 
207
- MOSAIC = 3
284
+ MOSAIC = "MOSAIC"
208
285
  """Groups of items should be computed that cover the entire window.
209
286
 
210
287
  During materialization, items in each group are merged to form a mosaic in the
211
288
  dataset.
212
289
  """
213
290
 
291
+ PER_PERIOD_MOSAIC = "PER_PERIOD_MOSAIC"
292
+ """Create one mosaic per sub-period of the time range.
214
293
 
215
- class TimeMode(Enum):
216
- """Temporal matching mode when looking up items corresponding to a window."""
217
-
218
- WITHIN = 1
219
- """Items must be within the window time range."""
294
+ The duration of the sub-periods is controlled by another option in QueryConfig.
295
+ """
220
296
 
221
- NEAREST = 2
222
- """Select items closest to the window time range, up to max_matches."""
297
+ COMPOSITE = "COMPOSITE"
298
+ """Creates one composite covering the entire window.
223
299
 
224
- BEFORE = 3
225
- """Select items before the start of the window time range, up to max_matches."""
300
+ During querying all items intersecting the window are placed in one group.
301
+ The compositing_method in the rasterlayer config specifies how these items are reduced
302
+ to a single item (e.g MEAN/MEDIAN/FIRST_VALID) during materialization.
303
+ """
226
304
 
227
- AFTER = 4
228
- """Select items after the end of the window time range, up to max_matches."""
305
+ # TODO add PER_PERIOD_COMPOSITE
229
306
 
230
307
 
231
- class QueryConfig:
232
- """A configuration for querying items in a data source."""
308
+ class TimeMode(StrEnum):
309
+ """Temporal matching mode when looking up items corresponding to a window."""
233
310
 
234
- def __init__(
235
- self,
236
- space_mode: SpaceMode = SpaceMode.MOSAIC,
237
- time_mode: TimeMode = TimeMode.WITHIN,
238
- max_matches: int = 1,
239
- ):
240
- """Creates a new query configuration.
311
+ WITHIN = "WITHIN"
312
+ """Items must be within the window time range."""
241
313
 
242
- The provided options determine how a DataSource should lookup items that match a
243
- spatiotemporal window.
314
+ NEAREST = "NEAREST"
315
+ """Select items closest to the window time range, up to max_matches."""
244
316
 
245
- Args:
246
- space_mode: specifies how items should be matched with windows spatially
247
- time_mode: specifies how items should be matched with windows temporally
248
- max_matches: the maximum number of items (or groups of items, if space_mode
249
- is MOSAIC) to match
250
- """
251
- self.space_mode = space_mode
252
- self.time_mode = time_mode
253
- self.max_matches = max_matches
254
-
255
- def serialize(self) -> dict[str, Any]:
256
- """Serialize this QueryConfig to a config dict, currently unused."""
257
- return {
258
- "space_mode": str(self.space_mode),
259
- "time_mode": str(self.time_mode),
260
- "max_matches": self.max_matches,
261
- }
317
+ BEFORE = "BEFORE"
318
+ """Select items before the end of the window time range, up to max_matches."""
262
319
 
263
- @staticmethod
264
- def from_config(config: dict[str, Any]) -> "QueryConfig":
265
- """Create a QueryConfig from config dict.
320
+ AFTER = "AFTER"
321
+ """Select items after the start of the window time range, up to max_matches."""
266
322
 
267
- Args:
268
- config: the config dict for this QueryConfig
269
- """
270
- return QueryConfig(
271
- space_mode=SpaceMode[config.get("space_mode", "MOSAIC")],
272
- time_mode=TimeMode[config.get("time_mode", "WITHIN")],
273
- max_matches=config.get("max_matches", 1),
274
- )
275
323
 
324
+ class QueryConfig(BaseModel):
325
+ """A configuration for querying items in a data source."""
276
326
 
277
- class DataSourceConfig:
327
+ model_config = ConfigDict(frozen=True, extra="forbid")
328
+
329
+ space_mode: SpaceMode = Field(
330
+ default=SpaceMode.MOSAIC,
331
+ description="Specifies how items should be matched with windows spatially.",
332
+ )
333
+ time_mode: TimeMode = Field(
334
+ default=TimeMode.WITHIN,
335
+ description="Specifies how items should be matched with windows temporally.",
336
+ )
337
+
338
+ # Minimum number of item groups. If there are fewer than this many matches, then no
339
+ # matches will be returned. This can be used to prevent unnecessary data ingestion
340
+ # if the user plans to discard windows that do not have a sufficient amount of data.
341
+ min_matches: int = Field(
342
+ default=0, description="The minimum number of item groups."
343
+ )
344
+
345
+ max_matches: int = Field(
346
+ default=1, description="The maximum number of item groups."
347
+ )
348
+ period_duration: Annotated[
349
+ timedelta,
350
+ BeforeValidator(ensure_timedelta),
351
+ PlainSerializer(serialize_optional_timedelta),
352
+ ] = Field(
353
+ default=timedelta(days=30),
354
+ description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
355
+ )
356
+
357
+
358
+ class DataSourceConfig(BaseModel):
278
359
  """Configuration for a DataSource in a dataset layer."""
279
360
 
280
- def __init__(
281
- self,
282
- name: str,
283
- query_config: QueryConfig,
284
- config_dict: dict[str, Any],
285
- time_offset: timedelta | None = None,
286
- duration: timedelta | None = None,
287
- ingest: bool = True,
288
- ) -> None:
289
- """Initializes a new DataSourceConfig.
290
-
291
- Args:
292
- name: the data source class name
293
- query_config: the QueryConfig specifying how to match items with windows
294
- config_dict: additional config passed to initialize the DataSource
295
- time_offset: optional, add this timedelta to the window's time range before
296
- matching
297
- duration: optional, if window's time range is (t0, t1), then update to
298
- (t0, t0 + duration)
299
- ingest: whether to ingest this layer or directly materialize it
300
- (default true)
361
+ model_config = ConfigDict(frozen=True, extra="forbid")
362
+
363
+ class_path: str = Field(description="Class path for the data source.")
364
+ init_args: dict[str, Any] = Field(
365
+ default_factory=lambda: {},
366
+ description="jsonargparse init args for the data source.",
367
+ )
368
+ query_config: QueryConfig = Field(
369
+ default_factory=lambda: QueryConfig(),
370
+ description="QueryConfig specifying how to match items with windows.",
371
+ )
372
+ time_offset: Annotated[
373
+ timedelta | None,
374
+ BeforeValidator(ensure_optional_timedelta),
375
+ PlainSerializer(serialize_optional_timedelta),
376
+ ] = Field(
377
+ default=None,
378
+ description="Optional timedelta to add to the window's time range before matching.",
379
+ )
380
+ duration: Annotated[
381
+ timedelta | None,
382
+ BeforeValidator(ensure_optional_timedelta),
383
+ PlainSerializer(serialize_optional_timedelta),
384
+ ] = Field(
385
+ default=None,
386
+ description="Optional, if the window's time range is (t0, t1), then update to (t0, t0 + duration).",
387
+ )
388
+ ingest: bool = Field(
389
+ default=True,
390
+ description="Whether to ingest this layer (default True). If False, it will be directly materialized without ingestion.",
391
+ )
392
+
393
+ @model_validator(mode="before")
394
+ @classmethod
395
+ def convert_from_legacy(cls, d: dict[str, Any]) -> dict[str, Any]:
396
+ """Support legacy format of the DataSourceConfig.
397
+
398
+ The legacy format sets 'name' instead of 'class_path', and mixes the arguments
399
+ for the DataSource in with the DataSourceConfig keys.
301
400
  """
302
- self.name = name
303
- self.query_config = query_config
304
- self.config_dict = config_dict
305
- self.time_offset = time_offset
306
- self.duration = duration
307
- self.ingest = ingest
308
-
309
- def serialize(self) -> dict[str, Any]:
310
- """Serialize this DataSourceConfig to a config dict, currently unused."""
311
- config_dict = self.config_dict.copy()
312
- config_dict["name"] = self.name
313
- config_dict["query_config"] = self.query_config.serialize()
314
- config_dict["ingest"] = self.ingest
315
- if self.time_offset:
316
- config_dict["time_offset"] = str(self.time_offset)
317
- if self.duration:
318
- config_dict["duration"] = str(self.duration)
319
- return config_dict
320
-
321
- @staticmethod
322
- def from_config(config: dict[str, Any]) -> "DataSourceConfig":
323
- """Create a DataSourceConfig from config dict.
324
-
325
- Args:
326
- config: the config dict for this DataSourceConfig
327
- """
328
- kwargs = dict(
329
- name=config["name"],
330
- query_config=QueryConfig.from_config(config.get("query_config", {})),
331
- config_dict=config,
401
+ if "name" not in d:
402
+ # New version, it is all good.
403
+ return d
404
+
405
+ warnings.warn(
406
+ "`Data source configuration {'name': ...}` is deprecated; "
407
+ "use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
408
+ DeprecationWarning,
332
409
  )
333
- if "time_offset" in config:
334
- kwargs["time_offset"] = timedelta(
335
- seconds=pytimeparse.parse(config["time_offset"])
336
- )
337
- if "duration" in config:
338
- kwargs["duration"] = timedelta(
339
- seconds=pytimeparse.parse(config["duration"])
410
+ logger.warning(
411
+ "Data source configuration uses legacy format; support will be removed after 2026-03-01."
412
+ )
413
+
414
+ # Split the dict into the base config that is in the pydantic model, and the
415
+ # source-specific options that should be moved to init_args dict.
416
+ class_path = d["name"]
417
+ base_config: dict[str, Any] = {}
418
+ ds_init_args: dict[str, Any] = {}
419
+ for k, v in d.items():
420
+ if k == "name":
421
+ continue
422
+ if k in cls.model_fields:
423
+ base_config[k] = v
424
+ else:
425
+ ds_init_args[k] = v
426
+
427
+ # Some legacy configs erroneously specify these keys, which are now caught by
428
+ # validation. But we still want those specific legacy configs to work.
429
+ if (
430
+ class_path == "rslearn.data_sources.planetary_computer.Sentinel2"
431
+ and "max_cloud_cover" in ds_init_args
432
+ ):
433
+ warnings.warn(
434
+ "Data source configuration specifies invalid 'max_cloud_cover' option.",
435
+ DeprecationWarning,
340
436
  )
341
- if "ingest" in config:
342
- kwargs["ingest"] = config["ingest"]
343
- return DataSourceConfig(**kwargs)
437
+ del ds_init_args["max_cloud_cover"]
438
+
439
+ base_config["class_path"] = class_path
440
+ base_config["init_args"] = ds_init_args
441
+ return base_config
344
442
 
345
443
 
346
- class LayerType(Enum):
444
+ class LayerType(StrEnum):
347
445
  """The layer type (raster or vector)."""
348
446
 
349
447
  RASTER = "raster"
350
448
  VECTOR = "vector"
351
449
 
352
450
 
353
- class LayerConfig:
354
- """Configuration of a layer in a dataset."""
355
-
356
- def __init__(
357
- self,
358
- layer_type: LayerType,
359
- data_source: DataSourceConfig | None = None,
360
- alias: str | None = None,
361
- ):
362
- """Initialize a new LayerConfig.
451
+ class CompositingMethod(StrEnum):
452
+ """Method for selecting the pixels of a composite from the corresponding items of a window."""
363
453
 
364
- Args:
365
- layer_type: the LayerType (raster or vector)
366
- data_source: optional DataSourceConfig if this layer is retrievable
367
- alias: alias for this layer to use in the tile store
368
- """
369
- self.layer_type = layer_type
370
- self.data_source = data_source
371
- self.alias = alias
372
-
373
- def serialize(self) -> dict[str, Any]:
374
- """Serialize this LayerConfig to a config dict, currently unused."""
375
- return {
376
- "layer_type": str(self.layer_type),
377
- "data_source": self.data_source,
378
- "alias": self.alias,
379
- }
454
+ FIRST_VALID = "FIRST_VALID"
455
+ """Select first valid pixel in order of corresponding items (might be sorted)"""
380
456
 
457
+ MEAN = "MEAN"
458
+ """Select per-pixel mean value of corresponding items of a window"""
381
459
 
382
- class RasterLayerConfig(LayerConfig):
383
- """Configuration of a raster layer."""
460
+ MEDIAN = "MEDIAN"
461
+ """Select per-pixel median value of corresponding items of a window"""
384
462
 
385
- def __init__(
386
- self,
387
- layer_type: LayerType,
388
- band_sets: list[BandSetConfig],
389
- data_source: DataSourceConfig | None = None,
390
- resampling_method: Resampling = Resampling.bilinear,
391
- alias: str | None = None,
392
- ):
393
- """Initialize a new RasterLayerConfig.
394
463
 
395
- Args:
396
- layer_type: the LayerType (must be raster)
397
- band_sets: the bands to store in this layer
398
- data_source: optional DataSourceConfig if this layer is retrievable
399
- resampling_method: how to resample rasters (if needed), default bilinear resampling
400
- alias: alias for this layer to use in the tile store
401
- """
402
- super().__init__(layer_type, data_source, alias)
403
- self.band_sets = band_sets
404
- self.resampling_method = resampling_method
464
+ class LayerConfig(BaseModel):
465
+ """Configuration of a layer in a dataset."""
405
466
 
406
- @staticmethod
407
- def from_config(config: dict[str, Any]) -> "RasterLayerConfig":
408
- """Create a RasterLayerConfig from config dict.
467
+ model_config = ConfigDict(frozen=True, extra="forbid")
468
+
469
+ type: LayerType = Field(description="The LayerType (raster or vector).")
470
+ data_source: DataSourceConfig | None = Field(
471
+ default=None,
472
+ description="Optional DataSourceConfig if this layer is retrievable.",
473
+ )
474
+ alias: str | None = Field(
475
+ default=None, description="Alias for this layer to use in the tile store."
476
+ )
477
+
478
+ # Raster layer options.
479
+ band_sets: list[BandSetConfig] = Field(
480
+ default_factory=lambda: [],
481
+ description="For raster layers, the bands to store in this layer.",
482
+ )
483
+ resampling_method: ResamplingMethod = Field(
484
+ default=ResamplingMethod.BILINEAR,
485
+ description="For raster layers, how to resample rasters (if needed), default bilinear resampling.",
486
+ )
487
+ compositing_method: CompositingMethod = Field(
488
+ default=CompositingMethod.FIRST_VALID,
489
+ description="For raster layers, how to compute pixel values in the composite of each window's items.",
490
+ )
491
+
492
+ # Vector layer options.
493
+ vector_format: dict[str, Any] = Field(
494
+ default_factory=lambda: {
495
+ "class_path": "rslearn.utils.vector_format.GeojsonVectorFormat"
496
+ },
497
+ description="For vector layers, the jsonargparse configuration for the VectorFormat.",
498
+ )
499
+ class_property_name: str | None = Field(
500
+ default=None,
501
+ description="Optional metadata field indicating that the GeoJSON features contain a property that corresponds to a class label, and this is the name of that property.",
502
+ )
503
+ class_names: list[str] | None = Field(
504
+ default=None,
505
+ description="The list of classes that the class_property_name property could be set to.",
506
+ )
507
+
508
+ @model_validator(mode="after")
509
+ def after_validator(self) -> "LayerConfig":
510
+ """Ensure the LayerConfig is valid."""
511
+ if self.type == LayerType.RASTER and len(self.band_sets) == 0:
512
+ raise ValueError(
513
+ "band sets must be specified and non-empty for raster layers"
514
+ )
409
515
 
410
- Args:
411
- config: the config dict for this RasterLayerConfig
412
- """
413
- kwargs = {
414
- "layer_type": LayerType(config["type"]),
415
- "band_sets": [BandSetConfig.from_config(el) for el in config["band_sets"]],
416
- }
417
- if "data_source" in config:
418
- kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
419
- if "resampling_method" in config:
420
- kwargs["resampling_method"] = RESAMPLING_METHODS[
421
- config["resampling_method"]
422
- ]
423
- if "alias" in config:
424
- kwargs["alias"] = config["alias"]
425
- return RasterLayerConfig(**kwargs)
426
-
427
-
428
- class VectorLayerConfig(LayerConfig):
429
- """Configuration of a vector layer."""
430
-
431
- def __init__(
432
- self,
433
- layer_type: LayerType,
434
- data_source: DataSourceConfig | None = None,
435
- zoom_offset: int = 0,
436
- format: VectorFormatConfig = VectorFormatConfig("geojson"),
437
- alias: str | None = None,
438
- ):
439
- """Initialize a new VectorLayerConfig.
516
+ return self
440
517
 
441
- Args:
442
- layer_type: the LayerType (must be vector)
443
- data_source: optional DataSourceConfig if this layer is retrievable
444
- zoom_offset: zoom offset at which to store the vector data
445
- format: the VectorFormatConfig, default storing as GeoJSON
446
- alias: alias for this layer to use in the tile store
447
- """
448
- super().__init__(layer_type, data_source, alias)
449
- self.zoom_offset = zoom_offset
450
- self.format = format
518
+ def __hash__(self) -> int:
519
+ """Return a hash of this LayerConfig."""
520
+ return hash(json.dumps(self.model_dump(mode="json"), sort_keys=True))
451
521
 
452
- @staticmethod
453
- def from_config(config: dict[str, Any]) -> "VectorLayerConfig":
454
- """Create a VectorLayerConfig from config dict.
522
+ def __eq__(self, other: Any) -> bool:
523
+ """Returns whether other is the same as this LayerConfig.
455
524
 
456
525
  Args:
457
- config: the config dict for this VectorLayerConfig
526
+ other: the other object to compare.
458
527
  """
459
- kwargs = {"layer_type": LayerType(config["type"])}
460
- if "data_source" in config:
461
- kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
462
- if "zoom_offset" in config:
463
- kwargs["zoom_offset"] = config["zoom_offset"]
464
- if "format" in config:
465
- kwargs["format"] = VectorFormatConfig.from_config(config["format"])
466
- if "alias" in config:
467
- kwargs["alias"] = config["alias"]
468
- return VectorLayerConfig(**kwargs)
528
+ if not isinstance(other, LayerConfig):
529
+ return False
530
+ return self.model_dump() == other.model_dump()
469
531
 
470
- def get_final_projection_and_bounds(
471
- self, projection: Projection, bounds: PixelBounds | None
472
- ) -> tuple[Projection, PixelBounds | None]:
473
- """Gets the final projection/bounds based on zoom offset.
532
+ @functools.cache
533
+ def instantiate_data_source(self, ds_path: UPath | None = None) -> "DataSource":
534
+ """Instantiate the data source specified by this config.
474
535
 
475
536
  Args:
476
- projection: the window's projection
477
- bounds: the window's bounds (optional)
537
+ ds_path: optional dataset path to include in the DataSourceContext.
478
538
 
479
539
  Returns:
480
- tuple of updated projection and bounds with zoom offset applied
540
+ the DataSource object.
481
541
  """
482
- if self.zoom_offset == 0:
483
- return projection, bounds
484
- projection = Projection(
485
- projection.crs,
486
- projection.x_resolution / (2**self.zoom_offset),
487
- projection.y_resolution / (2**self.zoom_offset),
488
- )
489
- if bounds:
490
- if self.zoom_offset > 0:
491
- bounds = tuple(x * (2**self.zoom_offset) for x in bounds)
492
- else:
493
- bounds = tuple(x // (2 ** (-self.zoom_offset)) for x in bounds)
494
- return projection, bounds
495
-
496
-
497
- def load_layer_config(config: dict[str, Any]) -> LayerConfig:
498
- """Load a LayerConfig from a config dict."""
499
- layer_type = LayerType(config.get("type"))
500
- if layer_type == LayerType.RASTER:
501
- return RasterLayerConfig.from_config(config)
502
- elif layer_type == LayerType.VECTOR:
503
- return VectorLayerConfig.from_config(config)
504
- raise ValueError(f"Unknown layer type {layer_type}")
542
+ from rslearn.data_sources.data_source import DataSource, DataSourceContext
543
+ from rslearn.utils.jsonargparse import data_source_context_serializer
505
544
 
545
+ logger.debug("getting a data source for dataset at %s", ds_path)
546
+ if self.data_source is None:
547
+ raise ValueError("This layer does not specify a data source")
506
548
 
507
- class TileStoreConfig:
508
- """A configuration specifying a TileStore."""
509
-
510
- def __init__(self, name: str, config_dict: dict[str, Any]) -> None:
511
- """Create a new TileStoreConfig.
549
+ # Inject the DataSourceContext into the args.
550
+ context = DataSourceContext(
551
+ ds_path=ds_path,
552
+ layer_config=self,
553
+ )
554
+ ds_config: dict[str, Any] = {
555
+ "class_path": self.data_source.class_path,
556
+ "init_args": copy.deepcopy(self.data_source.init_args),
557
+ }
558
+ ds_config["init_args"]["context"] = data_source_context_serializer(context)
512
559
 
513
- Args:
514
- name: the tile store implementation name to use
515
- config_dict: configuration options
516
- """
517
- self.name = name
518
- self.config_dict = config_dict
560
+ # Now we can parse with jsonargparse.
561
+ from rslearn.utils.jsonargparse import (
562
+ data_source_context_serializer,
563
+ init_jsonargparse,
564
+ )
519
565
 
520
- @staticmethod
521
- def from_config(config: dict[str, Any]) -> "TileStoreConfig":
522
- """Create a TileStoreConfig from config dict.
566
+ init_jsonargparse()
567
+ parser = jsonargparse.ArgumentParser()
568
+ parser.add_argument("--data_source", type=DataSource)
569
+ cfg = parser.parse_object({"data_source": ds_config})
570
+ data_source = parser.instantiate_classes(cfg).data_source
571
+ return data_source
572
+
573
+ def instantiate_vector_format(self) -> VectorFormat:
574
+ """Instantiate the vector format specified by this config."""
575
+ if self.type != LayerType.VECTOR:
576
+ raise ValueError(
577
+ f"cannot instantiate vector format for layer with type {self.type}"
578
+ )
523
579
 
524
- Args:
525
- config: the config dict for this TileStoreConfig
526
- """
527
- return TileStoreConfig(name=config["name"], config_dict=config)
580
+ from rslearn.utils.jsonargparse import init_jsonargparse
581
+
582
+ init_jsonargparse()
583
+ parser = jsonargparse.ArgumentParser()
584
+ parser.add_argument("--vector_format", type=VectorFormat)
585
+ cfg = parser.parse_object({"vector_format": self.vector_format})
586
+ vector_format = parser.instantiate_classes(cfg).vector_format
587
+ return vector_format
588
+
589
+
590
+ class StorageConfig(BaseModel):
591
+ """Configuration for the WindowStorageFactory (window metadata storage backend)."""
592
+
593
+ model_config = ConfigDict(frozen=True, extra="forbid")
594
+
595
+ class_path: str = Field(
596
+ default="rslearn.dataset.storage.file.FileWindowStorageFactory",
597
+ description="Class path for the WindowStorageFactory.",
598
+ )
599
+ init_args: dict[str, Any] = Field(
600
+ default_factory=lambda: {},
601
+ description="jsonargparse init args for the WindowStorageFactory.",
602
+ )
603
+
604
+ def instantiate_window_storage_factory(self) -> "WindowStorageFactory":
605
+ """Instantiate the WindowStorageFactory specified by this config."""
606
+ from rslearn.dataset.storage.storage import WindowStorageFactory
607
+ from rslearn.utils.jsonargparse import init_jsonargparse
608
+
609
+ init_jsonargparse()
610
+ parser = jsonargparse.ArgumentParser()
611
+ parser.add_argument("--wsf", type=WindowStorageFactory)
612
+ cfg = parser.parse_object(
613
+ {
614
+ "wsf": dict(
615
+ class_path=self.class_path,
616
+ init_args=self.init_args,
617
+ )
618
+ }
619
+ )
620
+ wsf = parser.instantiate_classes(cfg).wsf
621
+ return wsf
622
+
623
+
624
+ class DatasetConfig(BaseModel):
625
+ """Overall dataset configuration."""
626
+
627
+ model_config = ConfigDict(extra="forbid")
628
+
629
+ layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")
630
+ tile_store: dict[str, Any] = Field(
631
+ default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
632
+ description="jsonargparse configuration for the TileStore.",
633
+ )
634
+ storage: StorageConfig = Field(
635
+ default_factory=lambda: StorageConfig(),
636
+ description="jsonargparse configuration for the WindowStorageFactory.",
637
+ )
638
+
639
+ @field_validator("layers", mode="after")
640
+ @classmethod
641
+ def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
642
+ """Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
643
+ for layer_name in v.keys():
644
+ if "." in layer_name:
645
+ raise ValueError(f"layer names must not contain periods: {layer_name}")
646
+ return v