rslearn 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. rslearn/config/__init__.py +2 -10
  2. rslearn/config/dataset.py +414 -420
  3. rslearn/data_sources/__init__.py +8 -31
  4. rslearn/data_sources/aws_landsat.py +13 -24
  5. rslearn/data_sources/aws_open_data.py +21 -46
  6. rslearn/data_sources/aws_sentinel1.py +3 -14
  7. rslearn/data_sources/climate_data_store.py +21 -40
  8. rslearn/data_sources/copernicus.py +30 -91
  9. rslearn/data_sources/data_source.py +26 -0
  10. rslearn/data_sources/earthdaily.py +13 -38
  11. rslearn/data_sources/earthdata_srtm.py +14 -32
  12. rslearn/data_sources/eurocrops.py +5 -9
  13. rslearn/data_sources/gcp_public_data.py +46 -43
  14. rslearn/data_sources/google_earth_engine.py +31 -44
  15. rslearn/data_sources/local_files.py +91 -100
  16. rslearn/data_sources/openstreetmap.py +21 -51
  17. rslearn/data_sources/planet.py +12 -30
  18. rslearn/data_sources/planet_basemap.py +4 -25
  19. rslearn/data_sources/planetary_computer.py +58 -141
  20. rslearn/data_sources/usda_cdl.py +15 -26
  21. rslearn/data_sources/usgs_landsat.py +4 -29
  22. rslearn/data_sources/utils.py +9 -0
  23. rslearn/data_sources/worldcereal.py +47 -54
  24. rslearn/data_sources/worldcover.py +16 -14
  25. rslearn/data_sources/worldpop.py +15 -18
  26. rslearn/data_sources/xyz_tiles.py +11 -30
  27. rslearn/dataset/dataset.py +6 -6
  28. rslearn/dataset/manage.py +14 -20
  29. rslearn/dataset/materialize.py +9 -45
  30. rslearn/lightning_cli.py +370 -1
  31. rslearn/main.py +3 -3
  32. rslearn/models/concatenate_features.py +93 -0
  33. rslearn/tile_stores/__init__.py +0 -11
  34. rslearn/train/dataset.py +4 -12
  35. rslearn/train/prediction_writer.py +16 -32
  36. rslearn/train/tasks/classification.py +2 -1
  37. rslearn/utils/fsspec.py +20 -0
  38. rslearn/utils/jsonargparse.py +79 -0
  39. rslearn/utils/raster_format.py +1 -41
  40. rslearn/utils/vector_format.py +1 -38
  41. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
  42. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/RECORD +47 -48
  43. rslearn/data_sources/geotiff.py +0 -1
  44. rslearn/data_sources/raster_source.py +0 -23
  45. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
  46. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
  47. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
  48. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
  49. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py CHANGED
@@ -1,19 +1,73 @@
1
1
  """Classes for storing configuration of a dataset."""
2
2
 
3
+ import copy
4
+ import functools
3
5
  import json
6
+ import warnings
4
7
  from datetime import timedelta
5
- from enum import Enum
6
- from typing import Any
8
+ from enum import StrEnum
9
+ from typing import TYPE_CHECKING, Annotated, Any
7
10
 
11
+ import jsonargparse
8
12
  import numpy as np
9
13
  import numpy.typing as npt
10
14
  import pytimeparse
15
+ from pydantic import (
16
+ BaseModel,
17
+ BeforeValidator,
18
+ ConfigDict,
19
+ Field,
20
+ PlainSerializer,
21
+ field_validator,
22
+ model_validator,
23
+ )
11
24
  from rasterio.enums import Resampling
25
+ from upath import UPath
12
26
 
27
+ from rslearn.log_utils import get_logger
13
28
  from rslearn.utils import PixelBounds, Projection
29
+ from rslearn.utils.raster_format import RasterFormat
30
+ from rslearn.utils.vector_format import VectorFormat
14
31
 
32
+ if TYPE_CHECKING:
33
+ from rslearn.data_sources.data_source import DataSource
15
34
 
16
- class DType(Enum):
35
+ logger = get_logger("__name__")
36
+
37
+
38
+ def ensure_timedelta(v: Any) -> Any:
39
+ """Ensure the value is a timedelta.
40
+
41
+ If the value is a string, we try to parse it with pytimeparse.
42
+
43
+ This function is meant to be used like Annotated[timedelta, BeforeValidator(ensure_timedelta)].
44
+ """
45
+ if isinstance(v, timedelta):
46
+ return v
47
+ if isinstance(v, str):
48
+ return pytimeparse.parse(v)
49
+ raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
50
+
51
+
52
+ def ensure_optional_timedelta(v: Any) -> Any:
53
+ """Like ensure_timedelta, but allows None as a value."""
54
+ if v is None:
55
+ return None
56
+ if isinstance(v, timedelta):
57
+ return v
58
+ if isinstance(v, str):
59
+ return pytimeparse.parse(v)
60
+ raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
61
+
62
+
63
+ def serialize_optional_timedelta(v: timedelta | None) -> str | None:
64
+ """Serialize an optional timedelta for compatibility with pytimeparse."""
65
+ if v is None:
66
+ return None
67
+ return str(v.total_seconds()) + "s"
68
+
69
+
70
+ class DType(StrEnum):
17
71
  """Data type of a raster."""
18
72
 
19
73
  UINT8 = "uint8"
@@ -49,61 +103,28 @@ class DType(Enum):
49
103
  raise ValueError(f"unable to handle numpy dtype {self}")
50
104
 
51
105
 
52
- RESAMPLING_METHODS = {
53
- "nearest": Resampling.nearest,
54
- "bilinear": Resampling.bilinear,
55
- "cubic": Resampling.cubic,
56
- "cubic_spline": Resampling.cubic_spline,
57
- }
106
+ class ResamplingMethod(StrEnum):
107
+ """An enum representing the rasterio Resampling."""
58
108
 
109
+ NEAREST = "nearest"
110
+ BILINEAR = "bilinear"
111
+ CUBIC = "cubic"
112
+ CUBIC_SPLINE = "cubic_spline"
59
113
 
60
- class RasterFormatConfig:
61
- """A configuration specifying a RasterFormat."""
114
+ def get_rasterio_resampling(self) -> Resampling:
115
+ """Get the rasterio Resampling corresponding to this ResamplingMethod."""
116
+ return RESAMPLING_METHODS[self]
62
117
 
63
- def __init__(self, name: str, config_dict: dict[str, Any]) -> None:
64
- """Initialize a new RasterFormatConfig.
65
118
 
66
- Args:
67
- name: the name of the RasterFormat to use.
68
- config_dict: configuration to pass to the RasterFormat.
69
- """
70
- self.name = name
71
- self.config_dict = config_dict
72
-
73
- @staticmethod
74
- def from_config(config: dict[str, Any]) -> "RasterFormatConfig":
75
- """Create a RasterFormatConfig from config dict.
76
-
77
- Args:
78
- config: the config dict for this RasterFormatConfig
79
- """
80
- return RasterFormatConfig(name=config["name"], config_dict=config)
81
-
82
-
83
- class VectorFormatConfig:
84
- """A configuration specifying a VectorFormat."""
85
-
86
- def __init__(self, name: str, config_dict: dict[str, Any] = {}) -> None:
87
- """Initialize a new VectorFormatConfig.
88
-
89
- Args:
90
- name: the name of the VectorFormat to use.
91
- config_dict: configuration to pass to the VectorFormat.
92
- """
93
- self.name = name
94
- self.config_dict = config_dict
95
-
96
- @staticmethod
97
- def from_config(config: dict[str, Any]) -> "VectorFormatConfig":
98
- """Create a VectorFormatConfig from config dict.
99
-
100
- Args:
101
- config: the config dict for this VectorFormatConfig
102
- """
103
- return VectorFormatConfig(name=config["name"], config_dict=config)
119
+ RESAMPLING_METHODS = {
120
+ ResamplingMethod.NEAREST: Resampling.nearest,
121
+ ResamplingMethod.BILINEAR: Resampling.bilinear,
122
+ ResamplingMethod.CUBIC: Resampling.cubic,
123
+ ResamplingMethod.CUBIC_SPLINE: Resampling.cubic_spline,
124
+ }
104
125
 
105
126
 
106
- class BandSetConfig:
127
+ class BandSetConfig(BaseModel):
107
128
  """A configuration for a band set in a raster layer.
108
129
 
109
130
  Each band set specifies one or more bands that should be stored together.
@@ -111,97 +132,67 @@ class BandSetConfig:
111
132
  bands.
112
133
  """
113
134
 
114
- def __init__(
115
- self,
116
- config_dict: dict[str, Any],
117
- dtype: DType,
118
- bands: list[str] | None = None,
119
- num_bands: int | None = None,
120
- format: dict[str, Any] | None = None,
121
- zoom_offset: int = 0,
122
- remap: dict[str, Any] | None = None,
123
- class_names: list[list[str]] | None = None,
124
- nodata_vals: list[float] | None = None,
125
- ) -> None:
126
- """Creates a new BandSetConfig instance.
127
-
128
- Args:
129
- config_dict: the config dict used to configure this BandSetConfig
130
- dtype: the pixel value type to store tiles in
131
- bands: list of band names in this BandSetConfig. One of bands or num_bands
132
- must be set.
133
- num_bands: the number of bands in this band set. The bands will be named
134
- B00, B01, B02, etc.
135
- format: the format to store tiles in, defaults to geotiff
136
- zoom_offset: store images at a resolution higher or lower than the window
137
- resolution. This enables keeping source data at its native resolution,
138
- either to save storage space (for lower resolution data) or to retain
139
- details (for higher resolution data). If positive, store data at the
140
- window resolution divided by 2^(zoom_offset) (higher resolution). If
141
- negative, store data at the window resolution multiplied by
142
- 2^(-zoom_offset) (lower resolution).
143
- remap: config dict for Remapper to remap pixel values
144
- class_names: optional list of names for the different possible values of
145
- each band. The length of this list must equal the number of bands. For
146
- example, [["forest", "desert"]] means that it is a single-band raster
147
- where values can be 0 (forest) or 1 (desert).
148
- nodata_vals: the nodata values for this band set. This is used during
149
- materialization when creating mosaics, to determine which parts of the
150
- source images should be copied.
151
- """
152
- if (bands is None and num_bands is None) or (
153
- bands is not None and num_bands is not None
135
+ dtype: DType = Field(description="Pixel value type to store the data under")
136
+ bands: list[str] = Field(
137
+ default_factory=lambda: [],
138
+ description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
139
+ )
140
+ num_bands: int | None = Field(
141
+ default=None,
142
+ description="The number of bands in this band set. The bands will be named B0, B1, B2, etc.",
143
+ )
144
+ format: dict[str, Any] = Field(
145
+ default_factory=lambda: {
146
+ "class_path": "rslearn.utils.raster_format.GeotiffRasterFormat"
147
+ },
148
+ description="jsonargparse configuration for the RasterFormat to store the tiles in.",
149
+ )
150
+
151
+ # Store images at a resolution higher or lower than the window resolution. This
152
+ # enables keeping source data at its native resolution, either to save storage
153
+ # space (for lower resolution data) or to retain details (for higher resolution
154
+ # data). If positive, store data at the window resolution divided by
155
+ # 2^(zoom_offset) (higher resolution). If negative, store data at the window
156
+ # resolution multiplied by 2^(-zoom_offset) (lower resolution).
157
+ zoom_offset: int = Field(
158
+ default=0,
159
+ description="Store data at the window resolution multiplied by 2^(-zoom_offset).",
160
+ )
161
+
162
+ remap: dict[str, Any] | None = Field(
163
+ default=None,
164
+ description="Optional jsonargparse configuration for a Remapper to remap pixel values.",
165
+ )
166
+
167
+ # Optional list of names for the different possible values of each band. The length
168
+ # of this list must equal the number of bands. For example, [["forest", "desert"]]
169
+ # means that it is a single-band raster where values can be 0 (forest) or 1
170
+ # (desert).
171
+ class_names: list[list[str]] | None = Field(
172
+ default=None,
173
+ description="Optional list of names for the different possible values of each band.",
174
+ )
175
+
176
+ # Optional list of nodata values for this band set. This is used during
177
+ # materialization when creating mosaics, to determine which parts of the source
178
+ # images should be copied.
179
+ nodata_vals: list[float] | None = Field(
180
+ default=None, description="Optional nodata value for each band."
181
+ )
182
+
183
+ @model_validator(mode="after")
184
+ def after_validator(self) -> "BandSetConfig":
185
+ """Ensure the BandSetConfig is valid, and handle the num_bands field."""
186
+ if (len(self.bands) == 0 and self.num_bands is None) or (
187
+ len(self.bands) != 0 and self.num_bands is not None
154
188
  ):
155
- raise ValueError("exactly one of bands and num_bands must be set")
156
- if bands is None:
157
- assert num_bands is not None
158
- bands = [f"B{idx}" for idx in range(num_bands)]
159
-
160
- if class_names is not None and len(bands) != len(class_names):
161
- raise ValueError(
162
- f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
163
- )
164
-
165
- self.config_dict = config_dict
166
- self.bands = bands
167
- self.dtype = dtype
168
- self.zoom_offset = zoom_offset
169
- self.remap = remap
170
- self.class_names = class_names
171
- self.nodata_vals = nodata_vals
172
-
173
- if format is None:
174
- self.format = {"name": "geotiff"}
175
- else:
176
- self.format = format
189
+ raise ValueError("exactly one of bands and num_bands must be specified")
177
190
 
178
- def serialize(self) -> dict[str, Any]:
179
- """Serialize this BandSetConfig to a config dict."""
180
- return self.config_dict
191
+ if self.num_bands is not None:
192
+ self.bands = [f"B{band_idx}" for band_idx in range(self.num_bands)]
193
+ self.num_bands = None
181
194
 
182
- @staticmethod
183
- def from_config(config: dict[str, Any]) -> "BandSetConfig":
184
- """Create a BandSetConfig from config dict.
185
-
186
- Args:
187
- config: the config dict for this BandSetConfig
188
- """
189
- kwargs = dict(
190
- config_dict=config,
191
- dtype=DType(config["dtype"]),
192
- )
193
- for k in [
194
- "bands",
195
- "num_bands",
196
- "format",
197
- "zoom_offset",
198
- "remap",
199
- "class_names",
200
- "nodata_vals",
201
- ]:
202
- if k in config:
203
- kwargs[k] = config[k]
204
- return BandSetConfig(**kwargs) # type: ignore
195
+ return self
205
196
 
206
197
  def get_final_projection_and_bounds(
207
198
  self, projection: Projection, bounds: PixelBounds
@@ -236,30 +227,76 @@ class BandSetConfig:
236
227
  )
237
228
  return projection, bounds
238
229
 
230
+ @field_validator("format", mode="before")
231
+ @classmethod
232
+ def convert_format_from_legacy(cls, v: dict[str, Any]) -> dict[str, Any]:
233
+ """Support legacy format of the RasterFormat.
234
+
235
+ The legacy format sets 'name' instead of 'class_path', and uses custom parsing
236
+ for the init_args.
237
+ """
238
+ if "name" not in v:
239
+ # New version, it is all good.
240
+ return v
241
+
242
+ warnings.warn(
243
+ "`format = {'name': ...}` is deprecated; "
244
+ "use `{'class_path': '...', 'init_args': {...}}` instead.",
245
+ DeprecationWarning,
246
+ )
247
+
248
+ legacy_name_to_class_path = {
249
+ "image_tile": "rslearn.utils.raster_format.ImageTileRasterFormat",
250
+ "geotiff": "rslearn.utils.raster_format.GeotiffRasterFormat",
251
+ "single_image": "rslearn.utils.raster_format.SingleImageRasterFormat",
252
+ }
253
+ if v["name"] not in legacy_name_to_class_path:
254
+ raise ValueError(
255
+ f"could not parse legacy format with unknown raster format {v['name']}"
256
+ )
257
+ init_args = dict(v)
258
+ class_path = legacy_name_to_class_path[init_args.pop("name")]
259
+
260
+ return dict(
261
+ class_path=class_path,
262
+ init_args=init_args,
263
+ )
264
+
265
+ def instantiate_raster_format(self) -> RasterFormat:
266
+ """Instantiate the RasterFormat specified by this BandSetConfig."""
267
+ from rslearn.utils.jsonargparse import init_jsonargparse
239
268
 
240
- class SpaceMode(Enum):
269
+ init_jsonargparse()
270
+ parser = jsonargparse.ArgumentParser()
271
+ parser.add_argument("--raster_format", type=RasterFormat)
272
+ cfg = parser.parse_object({"raster_format": self.format})
273
+ raster_format = parser.instantiate_classes(cfg).raster_format
274
+ return raster_format
275
+
276
+
277
+ class SpaceMode(StrEnum):
241
278
  """Spatial matching mode when looking up items corresponding to a window."""
242
279
 
243
- CONTAINS = 1
280
+ CONTAINS = "CONTAINS"
244
281
  """Items must contain the entire window."""
245
282
 
246
- INTERSECTS = 2
283
+ INTERSECTS = "INTERSECTS"
247
284
  """Items must overlap any portion of the window."""
248
285
 
249
- MOSAIC = 3
286
+ MOSAIC = "MOSAIC"
250
287
  """Groups of items should be computed that cover the entire window.
251
288
 
252
289
  During materialization, items in each group are merged to form a mosaic in the
253
290
  dataset.
254
291
  """
255
292
 
256
- PER_PERIOD_MOSAIC = 4
293
+ PER_PERIOD_MOSAIC = "PER_PERIOD_MOSAIC"
257
294
  """Create one mosaic per sub-period of the time range.
258
295
 
259
296
  The duration of the sub-periods is controlled by another option in QueryConfig.
260
297
  """
261
298
 
262
- COMPOSITE = 5
299
+ COMPOSITE = "COMPOSITE"
263
300
  """Creates one composite covering the entire window.
264
301
 
265
302
  During querying all items intersecting the window are placed in one group.
@@ -270,188 +307,216 @@ class SpaceMode(Enum):
270
307
  # TODO add PER_PERIOD_COMPOSITE
271
308
 
272
309
 
273
- class TimeMode(Enum):
310
+ class TimeMode(StrEnum):
274
311
  """Temporal matching mode when looking up items corresponding to a window."""
275
312
 
276
- WITHIN = 1
313
+ WITHIN = "WITHIN"
277
314
  """Items must be within the window time range."""
278
315
 
279
- NEAREST = 2
316
+ NEAREST = "NEAREST"
280
317
  """Select items closest to the window time range, up to max_matches."""
281
318
 
282
- BEFORE = 3
319
+ BEFORE = "BEFORE"
283
320
  """Select items before the end of the window time range, up to max_matches."""
284
321
 
285
- AFTER = 4
322
+ AFTER = "AFTER"
286
323
  """Select items after the start of the window time range, up to max_matches."""
287
324
 
288
325
 
289
- class QueryConfig:
326
+ class QueryConfig(BaseModel):
290
327
  """A configuration for querying items in a data source."""
291
328
 
292
- def __init__(
293
- self,
294
- space_mode: SpaceMode = SpaceMode.MOSAIC,
295
- time_mode: TimeMode = TimeMode.WITHIN,
296
- min_matches: int = 0,
297
- max_matches: int = 1,
298
- period_duration: timedelta = timedelta(days=30),
299
- ):
300
- """Creates a new query configuration.
301
-
302
- The provided options determine how a DataSource should lookup items that match a
303
- spatiotemporal window.
329
+ model_config = ConfigDict(frozen=True)
330
+
331
+ space_mode: SpaceMode = Field(
332
+ default=SpaceMode.MOSAIC,
333
+ description="Specifies how items should be matched with windows spatially.",
334
+ )
335
+ time_mode: TimeMode = Field(
336
+ default=TimeMode.WITHIN,
337
+ description="Specifies how items should be matched with windows temporally.",
338
+ )
339
+
340
+ # Minimum number of item groups. If there are fewer than this many matches, then no
341
+ # matches will be returned. This can be used to prevent unnecessary data ingestion
342
+ # if the user plans to discard windows that do not have a sufficient amount of data.
343
+ min_matches: int = Field(
344
+ default=0, description="The minimum number of item groups."
345
+ )
346
+
347
+ max_matches: int = Field(
348
+ default=1, description="The maximum number of item groups."
349
+ )
350
+ period_duration: Annotated[
351
+ timedelta,
352
+ BeforeValidator(ensure_timedelta),
353
+ PlainSerializer(serialize_optional_timedelta),
354
+ ] = Field(
355
+ default=timedelta(days=30),
356
+ description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
357
+ )
358
+
359
+
360
+ class DataSourceConfig(BaseModel):
361
+ """Configuration for a DataSource in a dataset layer."""
304
362
 
305
- Args:
306
- space_mode: specifies how items should be matched with windows spatially
307
- time_mode: specifies how items should be matched with windows temporally
308
- min_matches: the minimum number of item groups. If there are fewer than
309
- this many matches, then no matches will be returned. This can be used
310
- to prevent unnecessary data ingestion if the user plans to discard
311
- windows that do not have a sufficient amount of data.
312
- max_matches: the maximum number of items (or groups of items, if space_mode
313
- is MOSAIC) to match
314
- period_duration: the duration of the periods, if the space mode is
315
- PER_PERIOD_MOSAIC.
363
+ model_config = ConfigDict(frozen=True)
364
+
365
+ class_path: str = Field(description="Class path for the data source.")
366
+ init_args: dict[str, Any] = Field(
367
+ default_factory=lambda: {},
368
+ description="jsonargparse init args for the data source.",
369
+ )
370
+ query_config: QueryConfig = Field(
371
+ default_factory=lambda: QueryConfig(),
372
+ description="QueryConfig specifying how to match items with windows.",
373
+ )
374
+ time_offset: Annotated[
375
+ timedelta | None,
376
+ BeforeValidator(ensure_optional_timedelta),
377
+ PlainSerializer(serialize_optional_timedelta),
378
+ ] = Field(
379
+ default=None,
380
+ description="Optional timedelta to add to the window's time range before matching.",
381
+ )
382
+ duration: Annotated[
383
+ timedelta | None,
384
+ BeforeValidator(ensure_optional_timedelta),
385
+ PlainSerializer(serialize_optional_timedelta),
386
+ ] = Field(
387
+ default=None,
388
+ description="Optional, if the window's time range is (t0, t1), then update to (t0, t0 + duration).",
389
+ )
390
+ ingest: bool = Field(
391
+ default=True,
392
+ description="Whether to ingest this layer (default True). If False, it will be directly materialized without ingestion.",
393
+ )
394
+
395
+ @model_validator(mode="before")
396
+ @classmethod
397
+ def convert_from_legacy(cls, d: dict[str, Any]) -> dict[str, Any]:
398
+ """Support legacy format of the DataSourceConfig.
399
+
400
+ The legacy format sets 'name' instead of 'class_path', and mixes the arguments
401
+ for the DataSource in with the DataSourceConfig keys.
316
402
  """
317
- self.space_mode = space_mode
318
- self.time_mode = time_mode
319
- self.min_matches = min_matches
320
- self.max_matches = max_matches
321
- self.period_duration = period_duration
322
-
323
- def serialize(self) -> dict[str, Any]:
324
- """Serialize this QueryConfig to a config dict."""
325
- return {
326
- "space_mode": str(self.space_mode),
327
- "time_mode": str(self.time_mode),
328
- "min_matches": self.min_matches,
329
- "max_matches": self.max_matches,
330
- "period_duration": f"{self.period_duration.total_seconds()}s",
331
- }
332
-
333
- @staticmethod
334
- def from_config(config: dict[str, Any]) -> "QueryConfig":
335
- """Create a QueryConfig from config dict.
403
+ if "name" not in d:
404
+ # New version, it is all good.
405
+ return d
406
+
407
+ warnings.warn(
408
+ "`Data source configuration {'name': ...}` is deprecated; "
409
+ "use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
410
+ DeprecationWarning,
411
+ )
336
412
 
337
- Args:
338
- config: the config dict for this QueryConfig
339
- """
340
- kwargs: dict[str, Any] = dict()
341
- if "space_mode" in config:
342
- kwargs["space_mode"] = SpaceMode[config["space_mode"]]
343
- if "time_mode" in config:
344
- kwargs["time_mode"] = TimeMode[config["time_mode"]]
345
- if "period_duration" in config:
346
- kwargs["period_duration"] = timedelta(
347
- seconds=pytimeparse.parse(config["period_duration"])
348
- )
349
- for k in ["min_matches", "max_matches"]:
350
- if k not in config:
413
+ # Split the dict into the base config that is in the pydantic model, and the
414
+ # source-specific options that should be moved to init_args dict.
415
+ class_path = d["name"]
416
+ base_config: dict[str, Any] = {}
417
+ ds_init_args: dict[str, Any] = {}
418
+ for k, v in d.items():
419
+ if k == "name":
351
420
  continue
352
- kwargs[k] = config[k]
353
- return QueryConfig(**kwargs)
421
+ if k in cls.model_fields:
422
+ base_config[k] = v
423
+ else:
424
+ ds_init_args[k] = v
425
+
426
+ # Some legacy configs erroneously specify these keys, which are now caught by
427
+ # validation. But we still want those specific legacy configs to work.
428
+ if (
429
+ class_path == "rslearn.data_sources.planetary_computer.Sentinel2"
430
+ and "max_cloud_cover" in ds_init_args
431
+ ):
432
+ warnings.warn(
433
+ "Data source configuration specifies invalid 'max_cloud_cover' option.",
434
+ DeprecationWarning,
435
+ )
436
+ del ds_init_args["max_cloud_cover"]
354
437
 
438
+ base_config["class_path"] = class_path
439
+ base_config["init_args"] = ds_init_args
440
+ return base_config
355
441
 
356
- class DataSourceConfig:
357
- """Configuration for a DataSource in a dataset layer."""
358
442
 
359
- def __init__(
360
- self,
361
- name: str,
362
- query_config: QueryConfig,
363
- config_dict: dict[str, Any],
364
- time_offset: timedelta | None = None,
365
- duration: timedelta | None = None,
366
- ingest: bool = True,
367
- ) -> None:
368
- """Initializes a new DataSourceConfig.
369
-
370
- Args:
371
- name: the data source class name
372
- query_config: the QueryConfig specifying how to match items with windows
373
- config_dict: additional config passed to initialize the DataSource
374
- time_offset: optional, add this timedelta to the window's time range before
375
- matching
376
- duration: optional, if window's time range is (t0, t1), then update to
377
- (t0, t0 + duration)
378
- ingest: whether to ingest this layer or directly materialize it
379
- (default true)
380
- """
381
- self.name = name
382
- self.query_config = query_config
383
- self.config_dict = config_dict
384
- self.time_offset = time_offset
385
- self.duration = duration
386
- self.ingest = ingest
443
+ class LayerType(StrEnum):
444
+ """The layer type (raster or vector)."""
387
445
 
388
- def serialize(self) -> dict[str, Any]:
389
- """Serialize this DataSourceConfig to a config dict."""
390
- return self.config_dict
446
+ RASTER = "raster"
447
+ VECTOR = "vector"
391
448
 
392
- @staticmethod
393
- def from_config(config: dict[str, Any]) -> "DataSourceConfig":
394
- """Create a DataSourceConfig from config dict.
395
449
 
396
- Args:
397
- config: the config dict for this DataSourceConfig
398
- """
399
- kwargs = dict(
400
- name=config["name"],
401
- query_config=QueryConfig.from_config(config.get("query_config", {})),
402
- config_dict=config,
403
- )
404
- if "time_offset" in config:
405
- kwargs["time_offset"] = timedelta(
406
- seconds=pytimeparse.parse(config["time_offset"])
407
- )
408
- if "duration" in config:
409
- kwargs["duration"] = timedelta(
410
- seconds=pytimeparse.parse(config["duration"])
411
- )
412
- if "ingest" in config:
413
- kwargs["ingest"] = config["ingest"]
414
- return DataSourceConfig(**kwargs)
450
+ class CompositingMethod(StrEnum):
451
+ """Method how to select pixels for the composite from corresponding items of a window."""
415
452
 
453
+ FIRST_VALID = "FIRST_VALID"
454
+ """Select first valid pixel in order of corresponding items (might be sorted)"""
416
455
 
417
- class LayerType(Enum):
418
- """The layer type (raster or vector)."""
456
+ MEAN = "MEAN"
457
+ """Select per-pixel mean value of corresponding items of a window"""
419
458
 
420
- RASTER = "raster"
421
- VECTOR = "vector"
459
+ MEDIAN = "MEDIAN"
460
+ """Select per-pixel median value of corresponding items of a window"""
422
461
 
423
462
 
424
- class LayerConfig:
463
+ class LayerConfig(BaseModel):
425
464
  """Configuration of a layer in a dataset."""
426
465
 
427
- def __init__(
428
- self,
429
- layer_type: LayerType,
430
- data_source: DataSourceConfig | None = None,
431
- alias: str | None = None,
432
- ):
433
- """Initialize a new LayerConfig.
466
+ model_config = ConfigDict(frozen=True)
467
+
468
+ type: LayerType = Field(description="The LayerType (raster or vector).")
469
+ data_source: DataSourceConfig | None = Field(
470
+ default=None,
471
+ description="Optional DataSourceConfig if this layer is retrievable.",
472
+ )
473
+ alias: str | None = Field(
474
+ default=None, description="Alias for this layer to use in the tile store."
475
+ )
476
+
477
+ # Raster layer options.
478
+ band_sets: list[BandSetConfig] = Field(
479
+ default_factory=lambda: [],
480
+ description="For raster layers, the bands to store in this layer.",
481
+ )
482
+ resampling_method: ResamplingMethod = Field(
483
+ default=ResamplingMethod.BILINEAR,
484
+ description="For raster layers, how to resample rasters (if neeed), default bilinear resampling.",
485
+ )
486
+ compositing_method: CompositingMethod = Field(
487
+ default=CompositingMethod.FIRST_VALID,
488
+ description="For raster layers, how to compute pixel values in the composite of each window's items.",
489
+ )
490
+
491
+ # Vector layer options.
492
+ vector_format: dict[str, Any] = Field(
493
+ default_factory=lambda: {
494
+ "class_path": "rslearn.utils.vector_format.GeojsonVectorFormat"
495
+ },
496
+ description="For vector layers, the jsonargparse configuration for the VectorFormat.",
497
+ )
498
+ class_property_name: str | None = Field(
499
+ default=None,
500
+ description="Optional metadata field indicating that the GeoJSON features contain a property that corresponds to a class label, and this is the name of that property.",
501
+ )
502
+ class_names: list[str] | None = Field(
503
+ default=None,
504
+ description="The list of classes that the class_property_name property could be set to.",
505
+ )
506
+
507
+ @model_validator(mode="after")
508
+ def after_validator(self) -> "LayerConfig":
509
+ """Ensure the LayerConfig is valid."""
510
+ if self.type == LayerType.RASTER and len(self.band_sets) == 0:
511
+ raise ValueError(
512
+ "band sets must be specified and non-empty for raster layers"
513
+ )
434
514
 
435
- Args:
436
- layer_type: the LayerType (raster or vector)
437
- data_source: optional DataSourceConfig if this layer is retrievable
438
- alias: alias for this layer to use in the tile store
439
- """
440
- self.layer_type = layer_type
441
- self.data_source = data_source
442
- self.alias = alias
443
-
444
- def serialize(self) -> dict[str, Any]:
445
- """Serialize this LayerConfig to a config dict."""
446
- return {
447
- "layer_type": str(self.layer_type),
448
- "data_source": self.data_source.serialize() if self.data_source else None,
449
- "alias": self.alias,
450
- }
515
+ return self
451
516
 
452
517
  def __hash__(self) -> int:
453
518
  """Return a hash of this LayerConfig."""
454
- return hash(json.dumps(self.serialize(), sort_keys=True))
519
+ return hash(json.dumps(self.model_dump(mode="json"), sort_keys=True))
455
520
 
456
521
  def __eq__(self, other: Any) -> bool:
457
522
  """Returns whether other is the same as this LayerConfig.
@@ -461,142 +526,71 @@ class LayerConfig:
461
526
  """
462
527
  if not isinstance(other, LayerConfig):
463
528
  return False
464
- return self.serialize() == other.serialize()
465
-
529
+ return self.model_dump() == other.model_dump()
466
530
 
467
- class CompositingMethod(Enum):
468
- """Method how to select pixels for the composite from corresponding items of a window."""
531
+ # NOTE(review): functools.cache on a method keys the cache on (self, ds_path),
+ # so it relies on __hash__/__eq__ of this class, and the cache keeps every
+ # cached LayerConfig and returned DataSource alive for as long as it exists.
+ @functools.cache
532
+ def instantiate_data_source(self, ds_path: UPath | None = None) -> "DataSource":
533
+ """Instantiate the data source specified by this config.
469
534
 
470
- FIRST_VALID = 1
471
- """Select first valid pixel in order of corresponding items (might be sorted)"""
535
+ Args:
536
+ ds_path: optional dataset path to include in the DataSourceContext.
472
537
 
473
- MEAN = 2
474
- """Select per-pixel mean value of corresponding items of a window"""
538
+ Returns:
539
+ the DataSource object.
+
+ Raises:
+ ValueError: if this layer does not specify a data source.
540
+ """
541
+ from rslearn.data_sources.data_source import DataSource, DataSourceContext
542
+ from rslearn.utils.jsonargparse import data_source_context_serializer
475
543
 
476
- MEDIAN = 3
477
- """Select per-pixel median value of corresponding items of a window"""
544
+ logger.debug("getting a data source for dataset at %s", ds_path)
545
+ if self.data_source is None:
546
+ raise ValueError("This layer does not specify a data source")
478
547
 
548
+ # Inject the DataSourceContext into the args.
+ # init_args is deep-copied so the injection does not modify this config.
549
+ context = DataSourceContext(
550
+ ds_path=ds_path,
551
+ layer_config=self,
552
+ )
553
+ ds_config: dict[str, Any] = {
554
+ "class_path": self.data_source.class_path,
555
+ "init_args": copy.deepcopy(self.data_source.init_args),
556
+ }
557
+ ds_config["init_args"]["context"] = data_source_context_serializer(context)
479
558
 
480
- class RasterLayerConfig(LayerConfig):
481
- """Configuration of a raster layer."""
559
+ # Now we can parse with jsonargparse.
560
+ from rslearn.utils.jsonargparse import init_jsonargparse
482
564
 
483
- def __init__(
484
- self,
485
- layer_type: LayerType,
486
- band_sets: list[BandSetConfig],
487
- data_source: DataSourceConfig | None = None,
488
- resampling_method: Resampling = Resampling.bilinear,
489
- alias: str | None = None,
490
- compositing_method: CompositingMethod = CompositingMethod.FIRST_VALID,
491
- ):
492
- """Initialize a new RasterLayerConfig.
565
+ init_jsonargparse()
566
+ parser = jsonargparse.ArgumentParser()
567
+ parser.add_argument("--data_source", type=DataSource)
568
+ cfg = parser.parse_object({"data_source": ds_config})
569
+ data_source = parser.instantiate_classes(cfg).data_source
570
+ return data_source
493
571
 
494
- Args:
495
- layer_type: the LayerType (must be raster)
496
- band_sets: the bands to store in this layer
497
- data_source: optional DataSourceConfig if this layer is retrievable
498
- resampling_method: how to resample rasters (if needed), default bilinear resampling
499
- alias: alias for this layer to use in the tile store
500
- compositing_method: how to compute pixel values in the composite of each windows items
501
- """
502
- super().__init__(layer_type, data_source, alias)
503
- self.band_sets = band_sets
504
- self.resampling_method = resampling_method
505
- self.compositing_method = compositing_method
572
+ def instantiate_vector_format(self) -> VectorFormat:
573
+ """Instantiate the vector format specified by this config.
+
+ Returns:
+ the VectorFormat object built from the vector_format configuration.
+
+ Raises:
+ ValueError: if this layer's type is not LayerType.VECTOR.
+ """
574
+ if self.type != LayerType.VECTOR:
575
+ raise ValueError(
576
+ f"cannot instantiate vector format for layer with type {self.type}"
577
+ )
506
578
 
507
- @staticmethod
508
- def from_config(config: dict[str, Any]) -> "RasterLayerConfig":
509
- """Create a RasterLayerConfig from config dict.
579
+ from rslearn.utils.jsonargparse import init_jsonargparse
510
580
 
511
- Args:
512
- config: the config dict for this RasterLayerConfig
513
- """
514
- kwargs = {
515
- "layer_type": LayerType(config["type"]),
516
- "band_sets": [BandSetConfig.from_config(el) for el in config["band_sets"]],
517
- }
518
- if "data_source" in config:
519
- kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
520
- if "resampling_method" in config:
521
- kwargs["resampling_method"] = RESAMPLING_METHODS[
522
- config["resampling_method"]
523
- ]
524
- if "alias" in config:
525
- kwargs["alias"] = config["alias"]
526
- if "compositing_method" in config:
527
- kwargs["compositing_method"] = CompositingMethod[
528
- config["compositing_method"]
529
- ]
530
- return RasterLayerConfig(**kwargs) # type: ignore
531
-
532
-
533
- class VectorLayerConfig(LayerConfig):
534
- """Configuration of a vector layer."""
535
-
536
- def __init__(
537
- self,
538
- layer_type: LayerType,
539
- data_source: DataSourceConfig | None = None,
540
- format: VectorFormatConfig = VectorFormatConfig("geojson"),
541
- alias: str | None = None,
542
- class_property_name: str | None = None,
543
- class_names: list[str] | None = None,
544
- ):
545
- """Initialize a new VectorLayerConfig.
581
+ init_jsonargparse()
582
+ parser = jsonargparse.ArgumentParser()
583
+ parser.add_argument("--vector_format", type=VectorFormat)
584
+ cfg = parser.parse_object({"vector_format": self.vector_format})
585
+ vector_format = parser.instantiate_classes(cfg).vector_format
586
+ return vector_format
546
587
 
547
- Args:
548
- layer_type: the LayerType (must be vector)
549
- data_source: optional DataSourceConfig if this layer is retrievable
550
- format: the VectorFormatConfig, default storing as GeoJSON
551
- alias: alias for this layer to use in the tile store
552
- class_property_name: optional metadata field indicating that the GeoJSON
553
- features contain a property that corresponds to a class label, and this
554
- is the name of that property.
555
- class_names: the list of classes that the class_property_name property
556
- could be set to.
557
- """
558
- super().__init__(layer_type, data_source, alias)
559
- self.format = format
560
- self.class_property_name = class_property_name
561
- self.class_names = class_names
562
588
 
563
- @staticmethod
564
- def from_config(config: dict[str, Any]) -> "VectorLayerConfig":
565
- """Create a VectorLayerConfig from config dict.
589
+ class DatasetConfig(BaseModel):
590
+ """Overall dataset configuration."""
566
591
 
567
- Args:
568
- config: the config dict for this VectorLayerConfig
569
- """
570
- kwargs: dict[str, Any] = {"layer_type": LayerType(config["type"])}
571
- if "data_source" in config:
572
- kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
573
- if "format" in config:
574
- kwargs["format"] = VectorFormatConfig.from_config(config["format"])
575
-
576
- simple_optionals = [
577
- "alias",
578
- "class_property_name",
579
- "class_names",
580
- ]
581
- for k in simple_optionals:
582
- if k in config:
583
- kwargs[k] = config[k]
584
-
585
- # The "zoom_offset" option was removed.
586
- # We should change how we create configuration so we can error on all
587
- # non-existing config options, but for now we make sure to raise error if
588
- # zoom_offset is set since it is no longer supported.
589
- if "zoom_offset" in config:
590
- raise ValueError("unsupported zoom_offset option in vector layer config")
591
-
592
- return VectorLayerConfig(**kwargs) # type: ignore
593
-
594
-
595
- def load_layer_config(config: dict[str, Any]) -> LayerConfig:
596
- """Load a LayerConfig from a config dict."""
597
- layer_type = LayerType(config.get("type"))
598
- if layer_type == LayerType.RASTER:
599
- return RasterLayerConfig.from_config(config)
600
- elif layer_type == LayerType.VECTOR:
601
- return VectorLayerConfig.from_config(config)
602
- raise ValueError(f"Unknown layer type {layer_type}")
592
+ layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")
593
+ tile_store: dict[str, Any] = Field(
594
+ default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
595
+ description="jsonargparse configuration for the TileStore.",
596
+ )