rslearn 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. rslearn/config/__init__.py +2 -10
  2. rslearn/config/dataset.py +420 -420
  3. rslearn/data_sources/__init__.py +8 -31
  4. rslearn/data_sources/aws_landsat.py +13 -24
  5. rslearn/data_sources/aws_open_data.py +21 -46
  6. rslearn/data_sources/aws_sentinel1.py +3 -14
  7. rslearn/data_sources/climate_data_store.py +21 -40
  8. rslearn/data_sources/copernicus.py +30 -91
  9. rslearn/data_sources/data_source.py +26 -0
  10. rslearn/data_sources/earthdaily.py +13 -38
  11. rslearn/data_sources/earthdata_srtm.py +14 -32
  12. rslearn/data_sources/eurocrops.py +5 -9
  13. rslearn/data_sources/gcp_public_data.py +46 -43
  14. rslearn/data_sources/google_earth_engine.py +31 -44
  15. rslearn/data_sources/local_files.py +91 -100
  16. rslearn/data_sources/openstreetmap.py +21 -51
  17. rslearn/data_sources/planet.py +12 -30
  18. rslearn/data_sources/planet_basemap.py +4 -25
  19. rslearn/data_sources/planetary_computer.py +58 -141
  20. rslearn/data_sources/usda_cdl.py +15 -26
  21. rslearn/data_sources/usgs_landsat.py +4 -29
  22. rslearn/data_sources/utils.py +9 -0
  23. rslearn/data_sources/worldcereal.py +47 -54
  24. rslearn/data_sources/worldcover.py +16 -14
  25. rslearn/data_sources/worldpop.py +15 -18
  26. rslearn/data_sources/xyz_tiles.py +11 -30
  27. rslearn/dataset/dataset.py +6 -6
  28. rslearn/dataset/manage.py +14 -20
  29. rslearn/dataset/materialize.py +9 -45
  30. rslearn/lightning_cli.py +377 -1
  31. rslearn/main.py +3 -3
  32. rslearn/models/concatenate_features.py +93 -0
  33. rslearn/models/olmoearth_pretrain/model.py +2 -5
  34. rslearn/tile_stores/__init__.py +0 -11
  35. rslearn/train/dataset.py +4 -12
  36. rslearn/train/prediction_writer.py +16 -32
  37. rslearn/train/tasks/classification.py +2 -1
  38. rslearn/utils/fsspec.py +20 -0
  39. rslearn/utils/jsonargparse.py +79 -0
  40. rslearn/utils/raster_format.py +1 -41
  41. rslearn/utils/vector_format.py +1 -38
  42. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/METADATA +58 -25
  43. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/RECORD +48 -49
  44. rslearn/data_sources/geotiff.py +0 -1
  45. rslearn/data_sources/raster_source.py +0 -23
  46. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/WHEEL +0 -0
  47. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/entry_points.txt +0 -0
  48. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/LICENSE +0 -0
  49. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/NOTICE +0 -0
  50. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/top_level.txt +0 -0
@@ -10,40 +10,17 @@ Each source supports operations to lookup items that match with spatiotemporal
10
10
  geometries, and ingest those items.
11
11
  """
12
12
 
13
- import functools
14
- import importlib
15
-
16
- from upath import UPath
17
-
18
- from rslearn.config import LayerConfig
19
- from rslearn.log_utils import get_logger
20
-
21
- from .data_source import DataSource, Item, ItemLookupDataSource, RetrieveItemDataSource
22
-
23
- logger = get_logger(__name__)
24
-
25
-
26
- @functools.cache
27
- def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource:
28
- """Loads a data source from config dict.
29
-
30
- Args:
31
- config: the LayerConfig containing this data source.
32
- ds_path: the dataset root directory.
33
- """
34
- logger.debug("getting a data source for dataset at %s", ds_path)
35
- if config.data_source is None:
36
- raise ValueError("No data source specified")
37
- name = config.data_source.name
38
- module_name = ".".join(name.split(".")[:-1])
39
- class_name = name.split(".")[-1]
40
- module = importlib.import_module(module_name)
41
- class_ = getattr(module, class_name)
42
- return class_.from_config(config, ds_path)
43
-
13
+ from .data_source import (
14
+ DataSource,
15
+ DataSourceContext,
16
+ Item,
17
+ ItemLookupDataSource,
18
+ RetrieveItemDataSource,
19
+ )
44
20
 
45
21
  __all__ = (
46
22
  "DataSource",
23
+ "DataSourceContext",
47
24
  "Item",
48
25
  "ItemLookupDataSource",
49
26
  "RetrieveItemDataSource",
@@ -25,7 +25,7 @@ from rasterio.enums import Resampling
25
25
  from upath import UPath
26
26
 
27
27
  import rslearn.data_sources.utils
28
- from rslearn.config import LayerConfig, RasterLayerConfig
28
+ from rslearn.config import LayerConfig
29
29
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
30
30
  from rslearn.dataset import Window
31
31
  from rslearn.dataset.materialize import RasterMaterializer
@@ -34,7 +34,7 @@ from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
34
34
  from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
35
35
  from rslearn.utils.grid_index import GridIndex
36
36
 
37
- from .data_source import DataSource, Item, QueryConfig
37
+ from .data_source import DataSource, DataSourceContext, Item, QueryConfig
38
38
 
39
39
  WRS2_GRID_SIZE = 1.0
40
40
 
@@ -98,20 +98,25 @@ class LandsatOliTirs(DataSource, TileStore):
98
98
 
99
99
  def __init__(
100
100
  self,
101
- config: RasterLayerConfig,
102
- metadata_cache_dir: UPath,
101
+ metadata_cache_dir: str,
103
102
  sort_by: str | None = None,
103
+ context: DataSourceContext = DataSourceContext(),
104
104
  ) -> None:
105
105
  """Initialize a new LandsatOliTirs instance.
106
106
 
107
107
  Args:
108
- config: configuration of this layer
109
- metadata_cache_dir: directory to cache product metadata files.
108
+ metadata_cache_dir: directory to cache produtc metadata files.
110
109
  sort_by: can be "cloud_cover", default arbitrary order; only has effect for
111
110
  SpaceMode.WITHIN.
111
+ context: the data source context.
112
112
  """
113
- self.config = config
114
- self.metadata_cache_dir = metadata_cache_dir
113
+ # If context is provided, we join the directory with the dataset path,
114
+ # otherwise we treat it directly as UPath.
115
+ if context.ds_path is not None:
116
+ self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
117
+ else:
118
+ self.metadata_cache_dir = UPath(metadata_cache_dir)
119
+
115
120
  self.sort_by = sort_by
116
121
 
117
122
  self.client = boto3.client("s3")
@@ -120,21 +125,6 @@ class LandsatOliTirs(DataSource, TileStore):
120
125
 
121
126
  self.wrs2_index: GridIndex | None = None
122
127
 
123
- @staticmethod
124
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "LandsatOliTirs":
125
- """Creates a new LandsatOliTirs instance from a configuration dictionary."""
126
- if config.data_source is None:
127
- raise ValueError(f"data_source is required for config dict {config}")
128
- d = config.data_source.config_dict
129
- kwargs = dict(
130
- config=config,
131
- metadata_cache_dir=join_upath(ds_path, d["metadata_cache_dir"]),
132
- )
133
- if "sort_by" in d:
134
- kwargs["sort_by"] = d["sort_by"]
135
-
136
- return LandsatOliTirs(**kwargs)
137
-
138
128
  def _read_products(
139
129
  self, needed_year_pathrows: set[tuple[int, str, str]]
140
130
  ) -> Generator[LandsatOliTirsItem, None, None]:
@@ -536,7 +526,6 @@ class LandsatOliTirs(DataSource, TileStore):
536
526
  layer_name: the name of this layer
537
527
  layer_cfg: the config of this layer
538
528
  """
539
- assert isinstance(layer_cfg, RasterLayerConfig)
540
529
  RasterMaterializer().materialize(
541
530
  TileStoreWithLayer(self, layer_name),
542
531
  window,
@@ -22,7 +22,6 @@ from rasterio.crs import CRS
22
22
  from upath import UPath
23
23
 
24
24
  import rslearn.data_sources.utils
25
- from rslearn.config import RasterLayerConfig
26
25
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_EPSG, WGS84_PROJECTION
27
26
  from rslearn.tile_stores import TileStoreWithLayer
28
27
  from rslearn.utils import GridIndex, Projection, STGeometry, daterange
@@ -32,6 +31,7 @@ from rslearn.utils.raster_format import get_raster_projection_and_bounds
32
31
  from .copernicus import get_harmonize_callback, get_sentinel2_tiles
33
32
  from .data_source import (
34
33
  DataSource,
34
+ DataSourceContext,
35
35
  Item,
36
36
  ItemLookupDataSource,
37
37
  QueryConfig,
@@ -83,16 +83,15 @@ class Naip(DataSource):
83
83
 
84
84
  def __init__(
85
85
  self,
86
- config: RasterLayerConfig,
87
- index_cache_dir: UPath,
86
+ index_cache_dir: str,
88
87
  use_rtree_index: bool = False,
89
88
  states: list[str] | None = None,
90
89
  years: list[int] | None = None,
90
+ context: DataSourceContext = DataSourceContext(),
91
91
  ) -> None:
92
92
  """Initialize a new Naip instance.
93
93
 
94
94
  Args:
95
- config: the LayerConfig of the layer containing this data source.
96
95
  index_cache_dir: directory to cache index shapefiles.
97
96
  use_rtree_index: whether to create an rtree index to enable faster lookups
98
97
  (default false)
@@ -100,9 +99,15 @@ class Naip(DataSource):
100
99
  the search. If use_rtree_index is enabled, the rtree will only be
101
100
  populated with data from these states.
102
101
  years: optional list of years to restrict the search
102
+ context: the data source context.
103
103
  """
104
- self.config = config
105
- self.index_cache_dir = index_cache_dir
104
+ # If context is provided, we join the directory with the dataset path,
105
+ # otherwise we treat it directly as UPath.
106
+ if context.ds_path is not None:
107
+ self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
108
+ else:
109
+ self.index_cache_dir = UPath(index_cache_dir)
110
+
106
111
  self.states = states
107
112
  self.years = years
108
113
 
@@ -119,22 +124,6 @@ class Naip(DataSource):
119
124
 
120
125
  self.rtree_index = get_cached_rtree(self.index_cache_dir, build_fn)
121
126
 
122
- @staticmethod
123
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
124
- """Creates a new Naip instance from a configuration dictionary."""
125
- if config.data_source is None:
126
- raise ValueError(f"data_source is required for config dict {config}")
127
- d = config.data_source.config_dict
128
- kwargs = dict(
129
- config=config,
130
- index_cache_dir=join_upath(ds_path, d["index_cache_dir"]),
131
- )
132
- simple_optionals = ["use_rtree_index", "states", "years"]
133
- for k in simple_optionals:
134
- if k in d:
135
- kwargs[k] = d[k]
136
- return Naip(**kwargs)
137
-
138
127
  def _download_manifest(self) -> UPath:
139
128
  """Download the manifest that enumerates files in the bucket.
140
129
 
@@ -460,51 +449,37 @@ class Sentinel2(
460
449
 
461
450
  def __init__(
462
451
  self,
463
- config: RasterLayerConfig,
464
452
  modality: Sentinel2Modality,
465
- metadata_cache_dir: UPath,
453
+ metadata_cache_dir: str,
466
454
  sort_by: str | None = None,
467
455
  harmonize: bool = False,
456
+ context: DataSourceContext = DataSourceContext(),
468
457
  ) -> None:
469
458
  """Initialize a new Sentinel2 instance.
470
459
 
471
460
  Args:
472
- config: the LayerConfig of the layer containing this data source.
473
461
  modality: L1C or L2A.
474
462
  metadata_cache_dir: directory to cache product metadata files.
475
463
  sort_by: can be "cloud_cover", default arbitrary order; only has effect for
476
464
  SpaceMode.WITHIN.
477
465
  harmonize: harmonize pixel values across different processing baselines,
478
466
  see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
467
+ context: the data source context.
479
468
  """ # noqa: E501
480
- self.config = config
469
+ # If context is provided, we join the directory with the dataset path,
470
+ # otherwise we treat it directly as UPath.
471
+ if context.ds_path is not None:
472
+ self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
473
+ else:
474
+ self.metadata_cache_dir = UPath(metadata_cache_dir)
475
+
481
476
  self.modality = modality
482
- self.metadata_cache_dir = metadata_cache_dir
483
477
  self.sort_by = sort_by
484
478
  self.harmonize = harmonize
485
479
 
486
480
  bucket_name = self.bucket_names[modality]
487
481
  self.bucket = boto3.resource("s3").Bucket(bucket_name)
488
482
 
489
- @staticmethod
490
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
491
- """Creates a new Sentinel2 instance from a configuration dictionary."""
492
- if config.data_source is None:
493
- raise ValueError("Sentinel2 data source requires a data source config")
494
- d = config.data_source.config_dict
495
- kwargs = dict(
496
- config=config,
497
- modality=Sentinel2Modality(d["modality"]),
498
- metadata_cache_dir=join_upath(ds_path, d["metadata_cache_dir"]),
499
- )
500
-
501
- simple_optionals = ["sort_by", "harmonize"]
502
- for k in simple_optionals:
503
- if k in d:
504
- kwargs[k] = d[k]
505
-
506
- return Sentinel2(**kwargs)
507
-
508
483
  def _read_products(
509
484
  self, needed_cell_months: set[tuple[str, int, int]]
510
485
  ) -> Generator[Sentinel2Item, None, None]:
@@ -7,7 +7,6 @@ from typing import Any
7
7
  import boto3
8
8
  from upath import UPath
9
9
 
10
- from rslearn.config import RasterLayerConfig
11
10
  from rslearn.data_sources.copernicus import (
12
11
  CopernicusItem,
13
12
  Sentinel1OrbitDirection,
@@ -19,7 +18,7 @@ from rslearn.log_utils import get_logger
19
18
  from rslearn.tile_stores import TileStore, TileStoreWithLayer
20
19
  from rslearn.utils.geometry import STGeometry
21
20
 
22
- from .data_source import DataSource, QueryConfig
21
+ from .data_source import DataSource, DataSourceContext, QueryConfig
23
22
 
24
23
  WRS2_GRID_SIZE = 1.0
25
24
 
@@ -45,11 +44,13 @@ class Sentinel1(DataSource, TileStore):
45
44
  def __init__(
46
45
  self,
47
46
  orbit_direction: Sentinel1OrbitDirection | None = None,
47
+ context: DataSourceContext = DataSourceContext(),
48
48
  ) -> None:
49
49
  """Initialize a new Sentinel1 instance.
50
50
 
51
51
  Args:
52
52
  orbit_direction: optional orbit direction to filter by.
53
+ context: the data source context.
53
54
  """
54
55
  self.client = boto3.client("s3")
55
56
  self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
@@ -59,18 +60,6 @@ class Sentinel1(DataSource, TileStore):
59
60
  orbit_direction=orbit_direction,
60
61
  )
61
62
 
62
- @staticmethod
63
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1":
64
- """Creates a new Sentinel1 instance from a configuration dictionary."""
65
- if config.data_source is None:
66
- raise ValueError(f"data_source is required for config dict {config}")
67
- d = config.data_source.config_dict
68
- kwargs: dict[str, Any] = {}
69
- if "orbit_direction" in d:
70
- d["orbit_direction"] = Sentinel1OrbitDirection[d["orbit_direction"]]
71
-
72
- return Sentinel1(**kwargs)
73
-
74
63
  def get_items(
75
64
  self, geometries: list[STGeometry], query_config: QueryConfig
76
65
  ) -> list[list[list[CopernicusItem]]]:
@@ -14,9 +14,9 @@ from dateutil.relativedelta import relativedelta
14
14
  from rasterio.transform import from_origin
15
15
  from upath import UPath
16
16
 
17
- from rslearn.config import QueryConfig, RasterLayerConfig, SpaceMode
17
+ from rslearn.config import QueryConfig, SpaceMode
18
18
  from rslearn.const import WGS84_EPSG, WGS84_PROJECTION
19
- from rslearn.data_sources import DataSource, Item
19
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
20
20
  from rslearn.log_utils import get_logger
21
21
  from rslearn.tile_stores import TileStoreWithLayer
22
22
  from rslearn.utils.geometry import STGeometry
@@ -55,59 +55,40 @@ class ERA5LandMonthlyMeans(DataSource):
55
55
 
56
56
  def __init__(
57
57
  self,
58
- band_names: list[str],
58
+ band_names: list[str] | None = None,
59
59
  api_key: str | None = None,
60
+ context: DataSourceContext = DataSourceContext(),
60
61
  ):
61
62
  """Initialize a new ERA5LandMonthlyMeans instance.
62
63
 
63
64
  Args:
64
65
  band_names: list of band names to acquire. These should correspond to CDS
65
- variable names but with "_" replaced with "-".
66
+ variable names but with "_" replaced with "-". This will only be used
67
+ if the layer config is missing from the context.
66
68
  api_key: the API key. If not set, it should be set via the CDSAPI_KEY
67
69
  environment variable.
70
+ context: the data source context.
68
71
  """
69
- self.band_names = band_names
72
+ self.band_names: list[str]
73
+ if context.layer_config is not None:
74
+ self.band_names = []
75
+ for band_set in context.layer_config.band_sets:
76
+ for band in band_set.bands:
77
+ if band in self.band_names:
78
+ continue
79
+ self.band_names.append(band)
80
+ elif band_names is not None:
81
+ self.band_names = band_names
82
+ else:
83
+ raise ValueError(
84
+ "band_names must be set if layer_config is not in the context"
85
+ )
70
86
 
71
87
  self.client = cdsapi.Client(
72
88
  url=self.api_url,
73
89
  key=api_key,
74
90
  )
75
91
 
76
- @staticmethod
77
- def from_config(
78
- config: RasterLayerConfig, ds_path: UPath
79
- ) -> "ERA5LandMonthlyMeans":
80
- """Creates a new ERA5LandMonthlyMeans instance from a configuration dictionary.
81
-
82
- Args:
83
- config: the LayerConfig of the layer containing this data source
84
- ds_path: the path to the data source
85
-
86
- Returns:
87
- A new ERA5LandMonthlyMeans instance
88
- """
89
- if config.data_source is None:
90
- raise ValueError("data_source is required")
91
- d = config.data_source.config_dict
92
-
93
- # Determine band names based on the configured band sets.
94
- band_names = []
95
- for band_set in config.band_sets:
96
- for band in band_set.bands:
97
- if band in band_names:
98
- continue
99
- band_names.append(band)
100
- kwargs: dict[str, Any] = dict(
101
- band_names=band_names,
102
- )
103
-
104
- simple_optionals = ["api_key"]
105
- for k in simple_optionals:
106
- if k in d:
107
- kwargs[k] = d[k]
108
-
109
- return ERA5LandMonthlyMeans(**kwargs)
110
-
111
92
  def get_items(
112
93
  self, geometries: list[STGeometry], query_config: QueryConfig
113
94
  ) -> list[list[list[Item]]]:
@@ -23,9 +23,9 @@ import requests
23
23
  import shapely
24
24
  from upath import UPath
25
25
 
26
- from rslearn.config import QueryConfig, RasterLayerConfig
26
+ from rslearn.config import QueryConfig
27
27
  from rslearn.const import WGS84_PROJECTION
28
- from rslearn.data_sources.data_source import DataSource, Item
28
+ from rslearn.data_sources.data_source import DataSource, DataSourceContext, Item
29
29
  from rslearn.data_sources.utils import match_candidate_items_to_window
30
30
  from rslearn.log_utils import get_logger
31
31
  from rslearn.tile_stores import TileStoreWithLayer
@@ -306,6 +306,7 @@ class Copernicus(DataSource):
306
306
  sort_by: str | None = None,
307
307
  sort_desc: bool = False,
308
308
  timeout: float = 10,
309
+ context: DataSourceContext = DataSourceContext(),
309
310
  ):
310
311
  """Create a new Copernicus.
311
312
 
@@ -332,6 +333,7 @@ class Copernicus(DataSource):
332
333
  sort_desc: for sort_by, sort in descending order instead of ascending
333
334
  order.
334
335
  timeout: timeout for requests.
336
+ context: the data source context.
335
337
  """
336
338
  self.glob_to_bands = glob_to_bands
337
339
  self.query_filter = query_filter
@@ -351,30 +353,6 @@ class Copernicus(DataSource):
351
353
  self.username = os.environ["COPERNICUS_USERNAME"]
352
354
  self.password = os.environ["COPERNICUS_PASSWORD"]
353
355
 
354
- @staticmethod
355
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Copernicus":
356
- """Creates a new Copernicus instance from a configuration dictionary."""
357
- if config.data_source is None:
358
- raise ValueError("config.data_source is required")
359
- d = config.data_source.config_dict
360
- kwargs: dict[str, Any] = dict(
361
- glob_to_bands=d["glob_to_bands"],
362
- )
363
-
364
- simple_optionals = [
365
- "access_token",
366
- "query_filter",
367
- "order_by",
368
- "sort_by",
369
- "sort_desc",
370
- "timeout",
371
- ]
372
- for k in simple_optionals:
373
- if k in d:
374
- kwargs[k] = d[k]
375
-
376
- return Copernicus(**kwargs)
377
-
378
356
  def deserialize_item(self, serialized_item: Any) -> CopernicusItem:
379
357
  """Deserializes an item from JSON-decoded data."""
380
358
  assert isinstance(serialized_item, dict)
@@ -763,23 +741,43 @@ class Sentinel2(Copernicus):
763
741
 
764
742
  def __init__(
765
743
  self,
766
- assets: list[str],
767
744
  product_type: Sentinel2ProductType,
768
745
  harmonize: bool = False,
746
+ assets: list[str] | None = None,
747
+ context: DataSourceContext = DataSourceContext(),
769
748
  **kwargs: Any,
770
749
  ):
771
750
  """Create a new Sentinel2.
772
751
 
773
752
  Args:
774
- assets: list of assets corresponding to keys in BANDS, e.g. ["TCI", "B08"].
775
753
  product_type: desired product type, L1C or L2A.
776
754
  harmonize: harmonize pixel values across different processing baselines,
777
755
  see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
756
+ assets: the assets to download, or None to download all assets. This is
757
+ only used if the layer config is not in the context.
758
+ context: the data source context.
778
759
  kwargs: additional arguments to pass to Copernicus.
779
760
  """
780
761
  # Create glob to bands map.
762
+ # If the context is provided, we limit to needed assets based on the configured
763
+ # band sets.
764
+ if context.layer_config is not None:
765
+ needed_assets = []
766
+ for asset_key, asset_bands in Sentinel2.BANDS.items():
767
+ # See if the bands provided by this asset intersect with the bands in
768
+ # at least one configured band set.
769
+ for band_set in context.layer_config.band_sets:
770
+ if not set(band_set.bands).intersection(set(asset_bands)):
771
+ continue
772
+ needed_assets.append(asset_key)
773
+ break
774
+ elif assets is not None:
775
+ needed_assets = assets
776
+ else:
777
+ needed_assets = list(Sentinel2.BANDS.keys())
778
+
781
779
  glob_to_bands = {}
782
- for asset_key in assets:
780
+ for asset_key in needed_assets:
783
781
  band_names = self.BANDS[asset_key]
784
782
  glob_pattern = self.GLOB_PATTERNS[product_type][asset_key]
785
783
  glob_to_bands[glob_pattern] = band_names
@@ -788,46 +786,13 @@ class Sentinel2(Copernicus):
788
786
  query_filter = f"Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType' and att/OData.CSC.StringAttribute/Value eq '{quote(product_type.value)}')"
789
787
 
790
788
  super().__init__(
789
+ context=context,
791
790
  glob_to_bands=glob_to_bands,
792
791
  query_filter=query_filter,
793
792
  **kwargs,
794
793
  )
795
794
  self.harmonize = harmonize
796
795
 
797
- @staticmethod
798
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
799
- """Creates a new Sentinel2 instance from a configuration dictionary."""
800
- if config.data_source is None:
801
- raise ValueError("config.data_source is required")
802
- d = config.data_source.config_dict
803
-
804
- # Determine needed assets based on the configured band sets.
805
- needed_assets: set[str] = set()
806
- for asset_key, asset_bands in Sentinel2.BANDS.items():
807
- for band_set in config.band_sets:
808
- if not set(band_set.bands).intersection(set(asset_bands)):
809
- continue
810
- needed_assets.add(asset_key)
811
-
812
- kwargs: dict[str, Any] = dict(
813
- assets=list(needed_assets),
814
- product_type=Sentinel2ProductType[d["product_type"]],
815
- )
816
-
817
- simple_optionals = [
818
- "harmonize",
819
- "access_token",
820
- "order_by",
821
- "sort_by",
822
- "sort_desc",
823
- "timeout",
824
- ]
825
- for k in simple_optionals:
826
- if k in d:
827
- kwargs[k] = d[k]
828
-
829
- return Sentinel2(**kwargs)
830
-
831
796
  # Override to support harmonization step.
832
797
  def _process_product_zip(
833
798
  self, tile_store: TileStoreWithLayer, item: CopernicusItem, local_zip_fname: str
@@ -922,6 +887,7 @@ class Sentinel1(Copernicus):
922
887
  product_type: Sentinel1ProductType,
923
888
  polarisation: Sentinel1Polarisation,
924
889
  orbit_direction: Sentinel1OrbitDirection | None = None,
890
+ context: DataSourceContext = DataSourceContext(),
925
891
  **kwargs: Any,
926
892
  ):
927
893
  """Create a new Sentinel1.
@@ -930,6 +896,7 @@ class Sentinel1(Copernicus):
930
896
  product_type: desired product type.
931
897
  polarisation: desired polarisation(s).
932
898
  orbit_direction: optional orbit direction to filter by.
899
+ context: the data source context.
933
900
  kwargs: additional arguments to pass to Copernicus.
934
901
  """
935
902
  # Create query filter based on the product type.
@@ -945,31 +912,3 @@ class Sentinel1(Copernicus):
945
912
  query_filter=query_filter,
946
913
  **kwargs,
947
914
  )
948
-
949
- @staticmethod
950
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1":
951
- """Creates a new Sentinel1 instance from a configuration dictionary."""
952
- if config.data_source is None:
953
- raise ValueError("config.data_source is required")
954
- d = config.data_source.config_dict
955
-
956
- kwargs: dict[str, Any] = dict(
957
- product_type=Sentinel1ProductType[d["product_type"]],
958
- polarisation=Sentinel1Polarisation[d["polarisation"]],
959
- )
960
-
961
- if "orbit_direction" in d:
962
- kwargs["orbit_direction"] = Sentinel1OrbitDirection[d["orbit_direction"]]
963
-
964
- simple_optionals = [
965
- "access_token",
966
- "order_by",
967
- "sort_by",
968
- "sort_desc",
969
- "timeout",
970
- ]
971
- for k in simple_optionals:
972
- if k in d:
973
- kwargs[k] = d[k]
974
-
975
- return Sentinel1(**kwargs)
@@ -3,6 +3,8 @@
3
3
  from collections.abc import Generator
4
4
  from typing import Any, BinaryIO, Generic, TypeVar
5
5
 
6
+ from upath import UPath
7
+
6
8
  from rslearn.config import LayerConfig, QueryConfig
7
9
  from rslearn.dataset import Window
8
10
  from rslearn.tile_stores import TileStoreWithLayer
@@ -127,3 +129,27 @@ class RetrieveItemDataSource(DataSource[ItemType]):
127
129
  ) -> Generator[tuple[str, BinaryIO], None, None]:
128
130
  """Retrieves the rasters corresponding to an item as file streams."""
129
131
  raise NotImplementedError
132
+
133
+
134
+ class DataSourceContext:
135
+ """This context is passed to every data source.
136
+
137
+ When initializing data sources within rslearn, we always set the ds_path and
138
+ layer_config. However, for convenience (for users directly initializing the data
139
+ sources externally), each data source should allow for initialization when one or
140
+ both are missing.
141
+ """
142
+
143
+ def __init__(
144
+ self, ds_path: UPath | None = None, layer_config: LayerConfig | None = None
145
+ ):
146
+ """Create a new DataSourceContext.
147
+
148
+ Args:
149
+ ds_path: the path of the underlying dataset.
150
+ layer_config: the LayerConfig for the layer that the data source is for.
151
+ """
152
+ # We don't use dataclass here because otherwise jsonargparse will ignore our
153
+ # custom serializer/deserializer defined in rslearn.utils.jsonargparse.
154
+ self.ds_path = ds_path
155
+ self.layer_config = layer_config