rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. rslearn/config/__init__.py +2 -10
  2. rslearn/config/dataset.py +414 -420
  3. rslearn/data_sources/__init__.py +8 -31
  4. rslearn/data_sources/aws_landsat.py +13 -24
  5. rslearn/data_sources/aws_open_data.py +21 -46
  6. rslearn/data_sources/aws_sentinel1.py +3 -14
  7. rslearn/data_sources/climate_data_store.py +21 -40
  8. rslearn/data_sources/copernicus.py +30 -91
  9. rslearn/data_sources/data_source.py +26 -0
  10. rslearn/data_sources/earthdaily.py +13 -38
  11. rslearn/data_sources/earthdata_srtm.py +14 -32
  12. rslearn/data_sources/eurocrops.py +5 -9
  13. rslearn/data_sources/gcp_public_data.py +46 -43
  14. rslearn/data_sources/google_earth_engine.py +31 -44
  15. rslearn/data_sources/local_files.py +91 -100
  16. rslearn/data_sources/openstreetmap.py +21 -51
  17. rslearn/data_sources/planet.py +12 -30
  18. rslearn/data_sources/planet_basemap.py +4 -25
  19. rslearn/data_sources/planetary_computer.py +58 -141
  20. rslearn/data_sources/usda_cdl.py +15 -26
  21. rslearn/data_sources/usgs_landsat.py +4 -29
  22. rslearn/data_sources/utils.py +9 -0
  23. rslearn/data_sources/worldcereal.py +47 -54
  24. rslearn/data_sources/worldcover.py +16 -14
  25. rslearn/data_sources/worldpop.py +15 -18
  26. rslearn/data_sources/xyz_tiles.py +11 -30
  27. rslearn/dataset/dataset.py +6 -6
  28. rslearn/dataset/manage.py +28 -26
  29. rslearn/dataset/materialize.py +9 -45
  30. rslearn/lightning_cli.py +370 -1
  31. rslearn/main.py +3 -3
  32. rslearn/models/clay/clay.py +14 -1
  33. rslearn/models/concatenate_features.py +93 -0
  34. rslearn/models/croma.py +26 -3
  35. rslearn/models/satlaspretrain.py +18 -4
  36. rslearn/models/terramind.py +19 -0
  37. rslearn/tile_stores/__init__.py +0 -11
  38. rslearn/train/dataset.py +4 -12
  39. rslearn/train/prediction_writer.py +16 -32
  40. rslearn/train/tasks/classification.py +2 -1
  41. rslearn/utils/fsspec.py +20 -0
  42. rslearn/utils/jsonargparse.py +79 -0
  43. rslearn/utils/raster_format.py +1 -41
  44. rslearn/utils/vector_format.py +1 -38
  45. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
  46. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
  47. rslearn/data_sources/geotiff.py +0 -1
  48. rslearn/data_sources/raster_source.py +0 -23
  49. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
  50. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
  51. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
  52. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
  53. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
@@ -9,9 +9,9 @@ import requests
9
9
  import shapely
10
10
  from upath import UPath
11
11
 
12
- from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
12
+ from rslearn.config import QueryConfig
13
13
  from rslearn.const import WGS84_PROJECTION
14
- from rslearn.data_sources import DataSource, Item
14
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
15
15
  from rslearn.data_sources.utils import match_candidate_items_to_window
16
16
  from rslearn.log_utils import get_logger
17
17
  from rslearn.tile_stores import TileStoreWithLayer
@@ -68,21 +68,20 @@ class PlanetBasemap(DataSource):
68
68
 
69
69
  def __init__(
70
70
  self,
71
- config: RasterLayerConfig,
72
71
  series_id: str,
73
72
  bands: list[str],
74
73
  api_key: str | None = None,
74
+ context: DataSourceContext = DataSourceContext(),
75
75
  ):
76
76
  """Initialize a new PlanetBasemap instance.
77
77
 
78
78
  Args:
79
- config: the LayerConfig of the layer containing this data source
80
79
  series_id: the series of mosaics to use.
81
80
  bands: list of band names to use.
82
81
  api_key: optional Planet API key (it can also be provided via PL_API_KEY
83
82
  environment variable).
83
+ context: the data source context
84
84
  """
85
- self.config = config
86
85
  self.series_id = series_id
87
86
  self.bands = bands
88
87
 
@@ -94,26 +93,6 @@ class PlanetBasemap(DataSource):
94
93
  # Lazily load mosaics.
95
94
  self.mosaics: dict | None = None
96
95
 
97
- @staticmethod
98
- def from_config(config: LayerConfig, ds_path: UPath) -> "PlanetBasemap":
99
- """Creates a new PlanetBasemap instance from a configuration dictionary."""
100
- assert isinstance(config, RasterLayerConfig)
101
- if config.data_source is None:
102
- raise ValueError("data_source is required")
103
- d = config.data_source.config_dict
104
- kwargs = dict(
105
- config=config,
106
- series_id=d["series_id"],
107
- bands=d["bands"],
108
- )
109
- optional_keys = [
110
- "api_key",
111
- ]
112
- for optional_key in optional_keys:
113
- if optional_key in d:
114
- kwargs[optional_key] = d[optional_key]
115
- return PlanetBasemap(**kwargs)
116
-
117
96
  def _load_mosaics(self) -> dict[str, STGeometry]:
118
97
  """Lazily load mosaics in the configured series_id from Planet API.
119
98
 
@@ -18,9 +18,9 @@ import shapely
18
18
  from rasterio.enums import Resampling
19
19
  from upath import UPath
20
20
 
21
- from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
21
+ from rslearn.config import LayerConfig, QueryConfig
22
22
  from rslearn.const import WGS84_PROJECTION
23
- from rslearn.data_sources import DataSource, Item
23
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
24
24
  from rslearn.data_sources.utils import match_candidate_items_to_window
25
25
  from rslearn.dataset import Window
26
26
  from rslearn.dataset.materialize import RasterMaterializer
@@ -96,8 +96,9 @@ class PlanetaryComputer(DataSource, TileStore):
96
96
  sort_ascending: bool = True,
97
97
  timeout: timedelta = timedelta(seconds=10),
98
98
  skip_items_missing_assets: bool = False,
99
- cache_dir: UPath | None = None,
99
+ cache_dir: str | None = None,
100
100
  max_items_per_client: int | None = None,
101
+ context: DataSourceContext = DataSourceContext(),
101
102
  ):
102
103
  """Initialize a new PlanetaryComputer instance.
103
104
 
@@ -117,6 +118,7 @@ class PlanetaryComputer(DataSource, TileStore):
117
118
  max_items_per_client: number of STAC items to process before recreating
118
119
  the client to prevent memory leaks from the resolved objects cache.
119
120
  Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
121
+ context: the data source context.
120
122
  """
121
123
  self.collection_name = collection_name
122
124
  self.asset_bands = asset_bands
@@ -125,46 +127,23 @@ class PlanetaryComputer(DataSource, TileStore):
125
127
  self.sort_ascending = sort_ascending
126
128
  self.timeout = timeout
127
129
  self.skip_items_missing_assets = skip_items_missing_assets
128
- self.cache_dir = cache_dir
129
130
  self.max_items_per_client = (
130
131
  max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
131
132
  )
132
133
 
133
- if self.cache_dir is not None:
134
+ if cache_dir is not None:
135
+ if context.ds_path is not None:
136
+ self.cache_dir = join_upath(context.ds_path, cache_dir)
137
+ else:
138
+ self.cache_dir = UPath(cache_dir)
139
+
134
140
  self.cache_dir.mkdir(parents=True, exist_ok=True)
141
+ else:
142
+ self.cache_dir = None
135
143
 
136
144
  self.client: pystac_client.Client | None = None
137
145
  self._client_item_count = 0
138
146
 
139
- @staticmethod
140
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "PlanetaryComputer":
141
- """Creates a new PlanetaryComputer instance from a configuration dictionary."""
142
- if config.data_source is None:
143
- raise ValueError("config.data_source is required")
144
- d = config.data_source.config_dict
145
- kwargs: dict[str, Any] = dict(
146
- collection_name=d["collection_name"],
147
- asset_bands=d["asset_bands"],
148
- )
149
-
150
- if "timeout_seconds" in d:
151
- kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
152
-
153
- if "cache_dir" in d:
154
- kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
155
-
156
- simple_optionals = [
157
- "query",
158
- "sort_by",
159
- "sort_ascending",
160
- "max_items_per_client",
161
- ]
162
- for k in simple_optionals:
163
- if k in d:
164
- kwargs[k] = d[k]
165
-
166
- return PlanetaryComputer(**kwargs)
167
-
168
147
  def _load_client(
169
148
  self,
170
149
  ) -> pystac_client.Client:
@@ -545,7 +524,6 @@ class PlanetaryComputer(DataSource, TileStore):
545
524
  layer_name: the name of this layer
546
525
  layer_cfg: the config of this layer
547
526
  """
548
- assert isinstance(layer_cfg, RasterLayerConfig)
549
527
  RasterMaterializer().materialize(
550
528
  TileStoreWithLayer(self, layer_name),
551
529
  window,
@@ -581,72 +559,48 @@ class Sentinel2(PlanetaryComputer):
581
559
 
582
560
  def __init__(
583
561
  self,
584
- assets: list[str] | None = None,
585
562
  harmonize: bool = False,
563
+ assets: list[str] | None = None,
564
+ context: DataSourceContext = DataSourceContext(),
586
565
  **kwargs: Any,
587
566
  ):
588
567
  """Initialize a new Sentinel2 instance.
589
568
 
590
569
  Args:
591
- assets: which assets in BANDS to ingest/materialize. None to ingest all
592
- assets.
593
570
  harmonize: harmonize pixel values across different processing baselines,
594
571
  see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
572
+ assets: list of asset names to ingest, or None to ingest all assets. This
573
+ is only used if the layer config is missing from the context.
574
+ context: the data source context.
595
575
  kwargs: other arguments to pass to PlanetaryComputer.
596
576
  """
597
577
  self.harmonize = harmonize
598
578
 
599
- if assets is None:
600
- asset_bands = self.BANDS
601
- else:
579
+ # Determine which assets we need based on the bands in the layer config.
580
+ if context.layer_config is not None:
581
+ asset_bands: dict[str, list[str]] = {}
582
+ for asset_key, band_names in self.BANDS.items():
583
+ # See if the bands provided by this asset intersect with the bands in
584
+ # at least one configured band set.
585
+ for band_set in context.layer_config.band_sets:
586
+ if not set(band_set.bands).intersection(set(band_names)):
587
+ continue
588
+ asset_bands[asset_key] = band_names
589
+ break
590
+ elif assets is not None:
602
591
  asset_bands = {asset_key: self.BANDS[asset_key] for asset_key in assets}
592
+ else:
593
+ asset_bands = self.BANDS
603
594
 
604
595
  super().__init__(
605
596
  collection_name=self.COLLECTION_NAME,
606
597
  asset_bands=asset_bands,
607
598
  # Skip since all of the items should have the same assets.
608
599
  skip_items_missing_assets=True,
600
+ context=context,
609
601
  **kwargs,
610
602
  )
611
603
 
612
- @staticmethod
613
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
614
- """Creates a new Sentinel2 instance from a configuration dictionary."""
615
- if config.data_source is None:
616
- raise ValueError("config.data_source is required")
617
- d = config.data_source.config_dict
618
-
619
- # Determine the needed assets based on the band sets.
620
- needed_assets: set[str] = set()
621
- for asset_key, asset_bands in Sentinel2.BANDS.items():
622
- for band_set in config.band_sets:
623
- if not set(band_set.bands).intersection(set(asset_bands)):
624
- continue
625
- needed_assets.add(asset_key)
626
-
627
- kwargs: dict[str, Any] = dict(
628
- assets=list(needed_assets),
629
- )
630
-
631
- if "timeout_seconds" in d:
632
- kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
633
-
634
- if "cache_dir" in d:
635
- kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
636
-
637
- simple_optionals = [
638
- "harmonize",
639
- "query",
640
- "sort_by",
641
- "sort_ascending",
642
- "max_items_per_client",
643
- ]
644
- for k in simple_optionals:
645
- if k in d:
646
- kwargs[k] = d[k]
647
-
648
- return Sentinel2(**kwargs)
649
-
650
604
  def _get_product_xml(self, item: PlanetaryComputerItem) -> ET.Element:
651
605
  asset_url = planetary_computer.sign(item.asset_urls["product-metadata"])
652
606
  response = requests.get(asset_url, timeout=self.timeout.total_seconds())
@@ -779,55 +733,42 @@ class Sentinel1(PlanetaryComputer):
779
733
 
780
734
  def __init__(
781
735
  self,
782
- band_names: list[str],
736
+ band_names: list[str] | None = None,
737
+ context: DataSourceContext = DataSourceContext(),
783
738
  **kwargs: Any,
784
739
  ):
785
740
  """Initialize a new Sentinel1 instance.
786
741
 
787
742
  Args:
788
- band_names: list of bands to try to ingest.
743
+ band_names: list of bands to try to ingest, if the layer config is missing
744
+ from the context.
745
+ context: the data source context.
789
746
  kwargs: additional arguments to pass to PlanetaryComputer.
790
747
  """
748
+ # Get band names from the config if possible. If it isn't in the context, then
749
+ # we have to use the provided band names.
750
+ if context.layer_config is not None:
751
+ band_names = list(
752
+ {
753
+ band
754
+ for band_set in context.layer_config.band_sets
755
+ for band in band_set.bands
756
+ }
757
+ )
758
+ if band_names is None:
759
+ raise ValueError(
760
+ "band_names must be set if layer config is not in the context"
761
+ )
762
+ # For Sentinel-1, the asset key should be the same as the band name (and all
763
+ # assets have one band).
791
764
  asset_bands = {band: [band] for band in band_names}
792
765
  super().__init__(
793
766
  collection_name=self.COLLECTION_NAME,
794
767
  asset_bands=asset_bands,
768
+ context=context,
795
769
  **kwargs,
796
770
  )
797
771
 
798
- @staticmethod
799
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1":
800
- """Creates a new Sentinel1 instance from a configuration dictionary."""
801
- if config.data_source is None:
802
- raise ValueError("config.data_source is required")
803
- d = config.data_source.config_dict
804
- band_names: set[str] = set()
805
- for band_set in config.band_sets:
806
- for band in band_set.bands:
807
- band_names.add(band)
808
-
809
- kwargs: dict[str, Any] = dict(
810
- band_names=list(band_names),
811
- )
812
-
813
- if "timeout_seconds" in d:
814
- kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
815
-
816
- if "cache_dir" in d:
817
- kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
818
-
819
- simple_optionals = [
820
- "query",
821
- "sort_by",
822
- "sort_ascending",
823
- "max_items_per_client",
824
- ]
825
- for k in simple_optionals:
826
- if k in d:
827
- kwargs[k] = d[k]
828
-
829
- return Sentinel1(**kwargs)
830
-
831
772
 
832
773
  class Naip(PlanetaryComputer):
833
774
  """A data source for NAIP data on Microsoft Planetary Computer.
@@ -840,42 +781,18 @@ class Naip(PlanetaryComputer):
840
781
 
841
782
  def __init__(
842
783
  self,
784
+ context: DataSourceContext = DataSourceContext(),
843
785
  **kwargs: Any,
844
786
  ):
845
787
  """Initialize a new Naip instance.
846
788
 
847
789
  Args:
848
- band_names: list of bands to try to ingest.
790
+ context: the data source context.
849
791
  kwargs: additional arguments to pass to PlanetaryComputer.
850
792
  """
851
793
  super().__init__(
852
794
  collection_name=self.COLLECTION_NAME,
853
795
  asset_bands=self.ASSET_BANDS,
796
+ context=context,
854
797
  **kwargs,
855
798
  )
856
-
857
- @staticmethod
858
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
859
- """Creates a new Naip instance from a configuration dictionary."""
860
- if config.data_source is None:
861
- raise ValueError("config.data_source is required")
862
- d = config.data_source.config_dict
863
- kwargs = {}
864
-
865
- if "timeout_seconds" in d:
866
- kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
867
-
868
- if "cache_dir" in d:
869
- kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
870
-
871
- simple_optionals = [
872
- "query",
873
- "sort_by",
874
- "sort_ascending",
875
- "max_items_per_client",
876
- ]
877
- for k in simple_optionals:
878
- if k in d:
879
- kwargs[k] = d[k]
880
-
881
- return Naip(**kwargs)
@@ -11,9 +11,9 @@ import requests.auth
11
11
  import shapely
12
12
  from upath import UPath
13
13
 
14
- from rslearn.config import QueryConfig, RasterLayerConfig
14
+ from rslearn.config import QueryConfig
15
15
  from rslearn.const import WGS84_PROJECTION
16
- from rslearn.data_sources import DataSource, Item
16
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
17
17
  from rslearn.data_sources.utils import match_candidate_items_to_window
18
18
  from rslearn.log_utils import get_logger
19
19
  from rslearn.tile_stores import TileStoreWithLayer
@@ -60,39 +60,28 @@ class CDL(DataSource):
60
60
 
61
61
  def __init__(
62
62
  self,
63
- band_name: str = "cdl",
64
63
  timeout: timedelta = timedelta(seconds=10),
64
+ context: DataSourceContext = DataSourceContext(),
65
65
  ):
66
66
  """Initialize a new CDL instance.
67
67
 
68
68
  Args:
69
- band_name: what to call the band.
70
69
  timeout: timeout for requests.
70
+ context: the data source context.
71
71
  """
72
- self.band_name = band_name
73
72
  self.timeout = timeout
74
73
 
75
- @staticmethod
76
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "CDL":
77
- """Creates a new CDL instance from a configuration dictionary."""
78
- if config.data_source is None:
79
- raise ValueError("config.data_source is required")
80
- d = config.data_source.config_dict
81
-
82
- # Get the band name chosen by the user.
83
- # There should be a single band set with a single band.
84
- if len(config.band_sets) != 1:
85
- raise ValueError("expected a single band set")
86
- if len(config.band_sets[0].bands) != 1:
87
- raise ValueError("expected band set to have a single band")
88
- kwargs: dict[str, Any] = {
89
- "band_name": config.band_sets[0].bands[0],
90
- }
91
-
92
- if "timeout_seconds" in d:
93
- kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
94
-
95
- return CDL(**kwargs)
74
+ # Get the band name from the layer config, which should have a single band set
75
+ # with a single band. If the layer config is not available in the context, we
76
+ # default to "cdl".
77
+ if context.layer_config is not None:
78
+ if len(context.layer_config.band_sets) != 1:
79
+ raise ValueError("expected a single band set")
80
+ if len(context.layer_config.band_sets[0].bands) != 1:
81
+ raise ValueError("expected band set to have a single band")
82
+ self.band_name = context.layer_config.band_sets[0].bands[0]
83
+ else:
84
+ self.band_name = "cdl"
96
85
 
97
86
  def get_item_by_name(self, name: str) -> Item:
98
87
  """Gets an item by name.
@@ -18,9 +18,9 @@ import requests
18
18
  import shapely
19
19
  from upath import UPath
20
20
 
21
- from rslearn.config import QueryConfig, RasterLayerConfig
21
+ from rslearn.config import QueryConfig
22
22
  from rslearn.const import WGS84_PROJECTION
23
- from rslearn.data_sources import DataSource, Item
23
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
24
24
  from rslearn.data_sources.utils import match_candidate_items_to_window
25
25
  from rslearn.tile_stores import TileStoreWithLayer
26
26
  from rslearn.utils import STGeometry
@@ -314,25 +314,24 @@ class LandsatOliTirs(DataSource):
314
314
 
315
315
  def __init__(
316
316
  self,
317
- config: RasterLayerConfig,
318
317
  username: str,
319
318
  sort_by: str | None = None,
320
319
  password: str | None = None,
321
320
  token: str | None = None,
322
321
  timeout: timedelta = timedelta(seconds=10),
322
+ context: DataSourceContext = DataSourceContext(),
323
323
  ):
324
324
  """Initialize a new LandsatOliTirs instance.
325
325
 
326
326
  Args:
327
- config: the LayerConfig of the layer containing this data source
328
327
  username: EROS username
329
328
  sort_by: can be "cloud_cover", default arbitrary order; only has effect for
330
329
  SpaceMode.WITHIN.
331
330
  password: EROS password (see M2MAPIClient).
332
331
  token: EROS application token (see M2MAPIClient).
333
332
  timeout: timeout for requests.
333
+ context: the data source context.
334
334
  """
335
- self.config = config
336
335
  self.sort_by = sort_by
337
336
  self.timeout = timeout
338
337
 
@@ -340,30 +339,6 @@ class LandsatOliTirs(DataSource):
340
339
  username, password=password, token=token, timeout=timeout
341
340
  )
342
341
 
343
- @staticmethod
344
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "LandsatOliTirs":
345
- """Creates a new LandsatOliTirs instance from a configuration dictionary."""
346
- if config.data_source is None:
347
- raise ValueError("data_source is required")
348
- d = config.data_source.config_dict
349
-
350
- kwargs = dict(
351
- config=config,
352
- username=d["username"],
353
- sort_by=d.get("sort_by"),
354
- )
355
-
356
- if "timeout_seconds" in d:
357
- kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
358
-
359
- # Optional config.
360
- for k in ["password", "token"]:
361
- if k not in d:
362
- continue
363
- kwargs[k] = d[k]
364
-
365
- return LandsatOliTirs(**kwargs)
366
-
367
342
  def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
368
343
  """Convert scene metadata from the API to a LandsatOliTirsItem."""
369
344
  metadata_dict = {}
@@ -8,8 +8,11 @@ import shapely
8
8
 
9
9
  from rslearn.config import QueryConfig, SpaceMode, TimeMode
10
10
  from rslearn.data_sources import Item
11
+ from rslearn.log_utils import get_logger
11
12
  from rslearn.utils import STGeometry, shp_intersects
12
13
 
14
+ logger = get_logger(__name__)
15
+
13
16
  MOSAIC_MIN_ITEM_COVERAGE = 0.1
14
17
  """Minimum fraction of area that item should cover when adding it to a mosaic group."""
15
18
 
@@ -298,6 +301,12 @@ def match_candidate_items_to_window(
298
301
 
299
302
  # Enforce minimum matches if set.
300
303
  if len(groups) < query_config.min_matches:
304
+ logger.warning(
305
+ "Window rejected: found %d matches (required: %d) for time range %s",
306
+ len(groups),
307
+ query_config.min_matches,
308
+ geometry.time_range if geometry.time_range else "unlimited",
309
+ )
301
310
  return []
302
311
 
303
312
  return groups
@@ -6,18 +6,17 @@ import os
6
6
  import shutil
7
7
  import tempfile
8
8
  import zipfile
9
- from typing import Any
10
9
 
11
10
  import requests
12
11
  from fsspec.implementations.local import LocalFileSystem
13
12
  from upath import UPath
14
13
 
15
- from rslearn.config import DataSourceConfig, LayerConfig, QueryConfig, RasterLayerConfig
16
- from rslearn.data_sources.local_files import LocalFiles
14
+ from rslearn.config import LayerType
15
+ from rslearn.data_sources.local_files import LocalFiles, RasterItemSpec
17
16
  from rslearn.log_utils import get_logger
18
17
  from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
19
18
 
20
- from .data_source import Item
19
+ from .data_source import DataSourceContext, Item
21
20
 
22
21
  logger = get_logger(__name__)
23
22
 
@@ -237,71 +236,65 @@ class WorldCereal(LocalFiles):
237
236
 
238
237
  def __init__(
239
238
  self,
240
- config: LayerConfig,
241
- band: str,
242
- worldcereal_dir: UPath,
239
+ worldcereal_dir: str,
240
+ band: str | None = None,
241
+ context: DataSourceContext = DataSourceContext(),
243
242
  ) -> None:
244
243
  """Create a new WorldCereal.
245
244
 
246
245
  Args:
247
- config: configuration for this layer.
248
- band: the worldcereal band being processed.
249
246
  worldcereal_dir: the directory to extract the WorldCereal GeoTIFF files. For
250
247
  high performance, this should be a local directory; if the dataset is
251
248
  remote, prefix with a protocol ("file://") to use a local directory
252
249
  instead of a path relative to the dataset path.
250
+ band: the worldcereal band to process. This will only be used if the layer
251
+ config is missing from the context.
252
+ context: the data source context.
253
253
  """
254
- self.band = band
255
- tif_dir, tif_filepath = self.download_worldcereal_data(band, worldcereal_dir)
254
+ if context.ds_path is not None:
255
+ worldcereal_upath = join_upath(context.ds_path, worldcereal_dir)
256
+ else:
257
+ worldcereal_upath = UPath(worldcereal_dir)
258
+
259
+ if context.layer_config is not None:
260
+ if len(context.layer_config.band_sets) != 1:
261
+ raise ValueError("expected a single band set")
262
+ if len(context.layer_config.band_sets[0].bands) != 1:
263
+ raise ValueError("expected band set to have a single band")
264
+ self.band = context.layer_config.band_sets[0].bands[0]
265
+ elif band is not None:
266
+ self.band = band
267
+ else:
268
+ raise ValueError("band must be set if layer config is not in the context")
269
+
270
+ tif_dir, tif_filepath = self.download_worldcereal_data(
271
+ self.band, worldcereal_upath
272
+ )
256
273
  all_aezs: set[int] = self.all_aezs_from_tifs(tif_filepath)
257
274
 
258
275
  # now that we have all our aezs, lets match them to the bands
259
- spec_dicts: list[dict] = []
276
+ item_specs: list[RasterItemSpec] = []
260
277
  for aez in all_aezs:
261
- spec_dict: dict[str, Any] = {
278
+ item_spec = RasterItemSpec(
279
+ fnames=[],
280
+ bands=[],
262
281
  # must be a str since we / with a posix path later
263
- "name": str(aez),
264
- "fnames": [],
265
- "bands": [],
266
- }
282
+ name=str(aez),
283
+ )
267
284
  aez_band_filepath = self.filepath_for_product_aez(tif_filepath, aez)
268
285
  if aez_band_filepath is not None:
269
- spec_dict["fnames"].append(aez_band_filepath.absolute().as_uri())
270
- spec_dict["bands"].append([band])
271
- spec_dicts.append(spec_dict)
272
- if len(spec_dicts) == 0:
273
- raise ValueError(f"No AEZ files found for {band}")
274
- # add this to the config
275
- if config.data_source is not None:
276
- if "item_specs" in config.data_source.config_dict:
277
- logger.warning(
278
- "Overwriting item_specs in WorldCereal config.data_source"
279
- )
280
- config.data_source.config_dict["item_specs"] = spec_dicts
281
- else:
282
- config.data_source = DataSourceConfig(
283
- name="rslearn.data_sources.WorldCereal",
284
- query_config=QueryConfig.from_config({}),
285
- config_dict={"item_specs": spec_dicts},
286
- )
287
-
288
- super().__init__(config, tif_dir)
289
-
290
- @staticmethod
291
- def from_config(config: LayerConfig, ds_path: UPath) -> "LocalFiles":
292
- """Creates a new LocalFiles instance from a configuration dictionary."""
293
- if config.data_source is None:
294
- raise ValueError("LocalFiles data source requires a data source config")
295
- d = config.data_source.config_dict
296
- assert isinstance(config, RasterLayerConfig)
297
- bandsets = config.band_sets
298
- assert len(bandsets) == 1
299
- assert len(bandsets[0].bands) == 1
300
- band = bandsets[0].bands[0]
301
- return WorldCereal(
302
- config=config,
303
- band=band,
304
- worldcereal_dir=join_upath(ds_path, d["worldcereal_dir"]),
286
+ item_spec.fnames.append(aez_band_filepath.absolute().as_uri())
287
+ assert item_spec.bands is not None
288
+ item_spec.bands.append([self.band])
289
+ item_specs.append(item_spec)
290
+ if len(item_specs) == 0:
291
+ raise ValueError(f"No AEZ files found for {self.band}")
292
+
293
+ super().__init__(
294
+ src_dir=tif_dir,
295
+ raster_item_specs=item_specs,
296
+ layer_type=LayerType.RASTER,
297
+ context=context,
305
298
  )
306
299
 
307
300
  @staticmethod
@@ -441,7 +434,7 @@ class WorldCereal(LocalFiles):
441
434
  cache_fname = self.src_dir / f"{self.band}_summary.json"
442
435
  if not cache_fname.exists():
443
436
  logger.debug("cache at %s does not exist, listing items", cache_fname)
444
- items = self.importer.list_items(self.config, self.src_dir)
437
+ items = self.importer.list_items(self.src_dir)
445
438
  serialized_items = [item.serialize() for item in items]
446
439
  with cache_fname.open("w") as f:
447
440
  json.dump(serialized_items, f)