rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +414 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +28 -26
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +370 -1
- rslearn/main.py +3 -3
- rslearn/models/clay/clay.py +14 -1
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/croma.py +26 -3
- rslearn/models/satlaspretrain.py +18 -4
- rslearn/models/terramind.py +19 -0
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
rslearn/data_sources/planet_basemap.py CHANGED

@@ -9,9 +9,9 @@ import requests
 import shapely
 from upath import UPath
 
-from rslearn.config import
+from rslearn.config import QueryConfig
 from rslearn.const import WGS84_PROJECTION
-from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources import DataSource, DataSourceContext, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.log_utils import get_logger
 from rslearn.tile_stores import TileStoreWithLayer

@@ -68,21 +68,20 @@ class PlanetBasemap(DataSource):
 
     def __init__(
         self,
-        config: RasterLayerConfig,
         series_id: str,
         bands: list[str],
         api_key: str | None = None,
+        context: DataSourceContext = DataSourceContext(),
     ):
         """Initialize a new Planet instance.
 
         Args:
-            config: the LayerConfig of the layer containing this data source
             series_id: the series of mosaics to use.
             bands: list of band names to use.
             api_key: optional Planet API key (it can also be provided via PL_API_KEY
                 environmnet variable).
+            context: the data source context
         """
-        self.config = config
         self.series_id = series_id
         self.bands = bands
 

@@ -94,26 +93,6 @@ class PlanetBasemap(DataSource):
         # Lazily load mosaics.
         self.mosaics: dict | None = None
 
-    @staticmethod
-    def from_config(config: LayerConfig, ds_path: UPath) -> "PlanetBasemap":
-        """Creates a new PlanetBasemap instance from a configuration dictionary."""
-        assert isinstance(config, RasterLayerConfig)
-        if config.data_source is None:
-            raise ValueError("data_source is required")
-        d = config.data_source.config_dict
-        kwargs = dict(
-            config=config,
-            series_id=d["series_id"],
-            bands=d["bands"],
-        )
-        optional_keys = [
-            "api_key",
-        ]
-        for optional_key in optional_keys:
-            if optional_key in d:
-                kwargs[optional_key] = d[optional_key]
-        return PlanetBasemap(**kwargs)
-
     def _load_mosaics(self) -> dict[str, STGeometry]:
         """Lazily load mosaics in the configured series_id from Planet API.
 
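For illustration only (not part of the diff): with `from_config()` removed, a `PlanetBasemap` is now constructed directly and handed a `DataSourceContext`. The `ds_path` keyword used to build the context below is an assumption based on the attributes this diff accesses, and the series/band values are placeholders.

```python
# Hypothetical usage sketch; DataSourceContext(ds_path=...) is an assumed way
# to build the context (only the ds_path attribute is visible in this diff),
# and the series/band values are placeholders.
from upath import UPath

from rslearn.data_sources import DataSourceContext
from rslearn.data_sources.planet_basemap import PlanetBasemap

data_source = PlanetBasemap(
    series_id="global_monthly",   # placeholder mosaic series
    bands=["b01", "b02", "b03"],  # placeholder band names
    api_key=None,                 # falls back to the PL_API_KEY env variable
    context=DataSourceContext(ds_path=UPath("/data/my_dataset")),  # assumed kwarg
)
```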
rslearn/data_sources/planetary_computer.py CHANGED

@@ -18,9 +18,9 @@ import shapely
 from rasterio.enums import Resampling
 from upath import UPath
 
-from rslearn.config import LayerConfig, QueryConfig
+from rslearn.config import LayerConfig, QueryConfig
 from rslearn.const import WGS84_PROJECTION
-from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources import DataSource, DataSourceContext, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.dataset import Window
 from rslearn.dataset.materialize import RasterMaterializer

@@ -96,8 +96,9 @@ class PlanetaryComputer(DataSource, TileStore):
         sort_ascending: bool = True,
         timeout: timedelta = timedelta(seconds=10),
         skip_items_missing_assets: bool = False,
-        cache_dir:
+        cache_dir: str | None = None,
         max_items_per_client: int | None = None,
+        context: DataSourceContext = DataSourceContext(),
     ):
         """Initialize a new PlanetaryComputer instance.
 

@@ -117,6 +118,7 @@ class PlanetaryComputer(DataSource, TileStore):
             max_items_per_client: number of STAC items to process before recreating
                 the client to prevent memory leaks from the resolved objects cache.
                 Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
+            context: the data source context.
         """
         self.collection_name = collection_name
         self.asset_bands = asset_bands

@@ -125,46 +127,23 @@ class PlanetaryComputer(DataSource, TileStore):
         self.sort_ascending = sort_ascending
         self.timeout = timeout
         self.skip_items_missing_assets = skip_items_missing_assets
-        self.cache_dir = cache_dir
         self.max_items_per_client = (
             max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
         )
 
-        if
+        if cache_dir is not None:
+            if context.ds_path is not None:
+                self.cache_dir = join_upath(context.ds_path, cache_dir)
+            else:
+                self.cache_dir = UPath(cache_dir)
+
             self.cache_dir.mkdir(parents=True, exist_ok=True)
+        else:
+            self.cache_dir = None
 
         self.client: pystac_client.Client | None = None
         self._client_item_count = 0
 
-    @staticmethod
-    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "PlanetaryComputer":
-        """Creates a new PlanetaryComputer instance from a configuration dictionary."""
-        if config.data_source is None:
-            raise ValueError("config.data_source is required")
-        d = config.data_source.config_dict
-        kwargs: dict[str, Any] = dict(
-            collection_name=d["collection_name"],
-            asset_bands=d["asset_bands"],
-        )
-
-        if "timeout_seconds" in d:
-            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
-
-        if "cache_dir" in d:
-            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
-
-        simple_optionals = [
-            "query",
-            "sort_by",
-            "sort_ascending",
-            "max_items_per_client",
-        ]
-        for k in simple_optionals:
-            if k in d:
-                kwargs[k] = d[k]
-
-        return PlanetaryComputer(**kwargs)
-
     def _load_client(
         self,
     ) -> pystac_client.Client:
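A hedged sketch of the new `cache_dir` behavior shown above: the argument is now a plain string that is resolved against `context.ds_path` via `join_upath` when a dataset path is available. The `DataSourceContext(ds_path=...)` keyword and the collection/asset values are assumptions, not taken from the package.

```python
# Sketch of the cache_dir resolution added above. DataSourceContext(ds_path=...)
# is an assumed constructor call; collection and asset names are placeholders.
from upath import UPath

from rslearn.data_sources import DataSourceContext
from rslearn.data_sources.planetary_computer import PlanetaryComputer

pc = PlanetaryComputer(
    collection_name="sentinel-2-l2a",              # placeholder STAC collection
    asset_bands={"B04": ["B04"], "B08": ["B08"]},  # placeholder asset -> bands map
    cache_dir="cache/planetary_computer",          # now a plain (relative) string
    context=DataSourceContext(ds_path=UPath("/data/my_dataset")),  # assumed kwarg
)
# With ds_path set, cache_dir is joined to it (join_upath), so items are cached
# under /data/my_dataset/cache/planetary_computer; without a ds_path it becomes
# UPath("cache/planetary_computer"); with cache_dir=None, no cache is used.
```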
@@ -545,7 +524,6 @@ class PlanetaryComputer(DataSource, TileStore):
             layer_name: the name of this layer
             layer_cfg: the config of this layer
         """
-        assert isinstance(layer_cfg, RasterLayerConfig)
         RasterMaterializer().materialize(
             TileStoreWithLayer(self, layer_name),
             window,
@@ -581,72 +559,48 @@ class Sentinel2(PlanetaryComputer):
 
     def __init__(
         self,
-        assets: list[str] | None = None,
         harmonize: bool = False,
+        assets: list[str] | None = None,
+        context: DataSourceContext = DataSourceContext(),
         **kwargs: Any,
     ):
         """Initialize a new Sentinel2 instance.
 
         Args:
-            assets: which assets in BANDS to ingest/materialize. None to ingest all
-                assets.
             harmonize: harmonize pixel values across different processing baselines,
                 see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
+            assets: list of asset names to ingest, or None to ingest all assets. This
+                is only used if the layer config is missing from the context.
+            context: the data source context.
             kwargs: other arguments to pass to PlanetaryComputer.
         """
         self.harmonize = harmonize
 
-
-
-
+        # Determine which assets we need based on the bands in the layer config.
+        if context.layer_config is not None:
+            asset_bands: dict[str, list[str]] = {}
+            for asset_key, band_names in self.BANDS.items():
+                # See if the bands provided by this asset intersect with the bands in
+                # at least one configured band set.
+                for band_set in context.layer_config.band_sets:
+                    if not set(band_set.bands).intersection(set(band_names)):
+                        continue
+                    asset_bands[asset_key] = band_names
+                    break
+        elif assets is not None:
             asset_bands = {asset_key: self.BANDS[asset_key] for asset_key in assets}
+        else:
+            asset_bands = self.BANDS
 
         super().__init__(
             collection_name=self.COLLECTION_NAME,
             asset_bands=asset_bands,
             # Skip since all of the items should have the same assets.
             skip_items_missing_assets=True,
+            context=context,
             **kwargs,
         )
 
-    @staticmethod
-    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
-        """Creates a new Sentinel2 instance from a configuration dictionary."""
-        if config.data_source is None:
-            raise ValueError("config.data_source is required")
-        d = config.data_source.config_dict
-
-        # Determine the needed assets based on the band sets.
-        needed_assets: set[str] = set()
-        for asset_key, asset_bands in Sentinel2.BANDS.items():
-            for band_set in config.band_sets:
-                if not set(band_set.bands).intersection(set(asset_bands)):
-                    continue
-                needed_assets.add(asset_key)
-
-        kwargs: dict[str, Any] = dict(
-            assets=list(needed_assets),
-        )
-
-        if "timeout_seconds" in d:
-            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
-
-        if "cache_dir" in d:
-            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
-
-        simple_optionals = [
-            "harmonize",
-            "query",
-            "sort_by",
-            "sort_ascending",
-            "max_items_per_client",
-        ]
-        for k in simple_optionals:
-            if k in d:
-                kwargs[k] = d[k]
-
-        return Sentinel2(**kwargs)
-
     def _get_product_xml(self, item: PlanetaryComputerItem) -> ET.Element:
         asset_url = planetary_computer.sign(item.asset_urls["product-metadata"])
         response = requests.get(asset_url, timeout=self.timeout.total_seconds())
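The asset-selection logic that `Sentinel2.__init__` now performs can be seen in isolation in the sketch below; the `BANDS` mapping and band sets are hypothetical stand-ins for the real class attribute and `context.layer_config.band_sets`.

```python
# Standalone sketch of the asset selection above: keep an asset if any
# configured band set uses at least one of its bands.
BANDS = {
    "B02": ["B02"],
    "B03": ["B03"],
    "B04": ["B04"],
    "visual": ["R", "G", "B"],
}
configured_band_sets = [["R", "G", "B"], ["B04"]]  # hypothetical band sets

asset_bands: dict[str, list[str]] = {}
for asset_key, band_names in BANDS.items():
    for bands in configured_band_sets:
        if set(bands).intersection(band_names):
            asset_bands[asset_key] = band_names
            break

print(asset_bands)  # {'B04': ['B04'], 'visual': ['R', 'G', 'B']}
```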
@@ -779,55 +733,42 @@ class Sentinel1(PlanetaryComputer):
 
     def __init__(
         self,
-        band_names: list[str],
+        band_names: list[str] | None = None,
+        context: DataSourceContext = DataSourceContext(),
         **kwargs: Any,
     ):
         """Initialize a new Sentinel1 instance.
 
         Args:
-            band_names: list of bands to try to ingest
+            band_names: list of bands to try to ingest, if the layer config is missing
+                from the context.
+            context: the data source context.
             kwargs: additional arguments to pass to PlanetaryComputer.
         """
+        # Get band names from the config if possible. If it isn't in the context, then
+        # we have to use the provided band names.
+        if context.layer_config is not None:
+            band_names = list(
+                {
+                    band
+                    for band_set in context.layer_config.band_sets
+                    for band in band_set.bands
+                }
+            )
+        if band_names is None:
+            raise ValueError(
+                "band_names must be set if layer config is not in the context"
+            )
+        # For Sentinel-1, the asset key should be the same as the band name (and all
+        # assets have one band).
         asset_bands = {band: [band] for band in band_names}
         super().__init__(
             collection_name=self.COLLECTION_NAME,
             asset_bands=asset_bands,
+            context=context,
             **kwargs,
         )
 
-    @staticmethod
-    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel1":
-        """Creates a new Sentinel1 instance from a configuration dictionary."""
-        if config.data_source is None:
-            raise ValueError("config.data_source is required")
-        d = config.data_source.config_dict
-        band_names: set[str] = set()
-        for band_set in config.band_sets:
-            for band in band_set.bands:
-                band_names.add(band)
-
-        kwargs: dict[str, Any] = dict(
-            band_names=list(band_names),
-        )
-
-        if "timeout_seconds" in d:
-            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
-
-        if "cache_dir" in d:
-            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
-
-        simple_optionals = [
-            "query",
-            "sort_by",
-            "sort_ascending",
-            "max_items_per_client",
-        ]
-        for k in simple_optionals:
-            if k in d:
-                kwargs[k] = d[k]
-
-        return Sentinel1(**kwargs)
-
 
 class Naip(PlanetaryComputer):
     """A data source for NAIP data on Microsoft Planetary Computer.
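Similarly, the sketch below isolates how `Sentinel1` now derives `asset_bands` from the layer config: de-duplicate band names across band sets, then map each band to a single-band asset of the same name. The band sets shown are hypothetical.

```python
# Standalone sketch of the band handling above; band sets are a hypothetical
# stand-in for context.layer_config.band_sets.
configured_band_sets = [["vv", "vh"], ["vv"]]

band_names = list({band for bands in configured_band_sets for band in bands})
asset_bands = {band: [band] for band in band_names}

print(sorted(asset_bands.items()))  # [('vh', ['vh']), ('vv', ['vv'])]
```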
@@ -840,42 +781,18 @@ class Naip(PlanetaryComputer):
 
     def __init__(
         self,
+        context: DataSourceContext = DataSourceContext(),
         **kwargs: Any,
     ):
         """Initialize a new Naip instance.
 
         Args:
-
+            context: the data source context.
             kwargs: additional arguments to pass to PlanetaryComputer.
         """
         super().__init__(
             collection_name=self.COLLECTION_NAME,
             asset_bands=self.ASSET_BANDS,
+            context=context,
             **kwargs,
         )
-
-    @staticmethod
-    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
-        """Creates a new Naip instance from a configuration dictionary."""
-        if config.data_source is None:
-            raise ValueError("config.data_source is required")
-        d = config.data_source.config_dict
-        kwargs = {}
-
-        if "timeout_seconds" in d:
-            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
-
-        if "cache_dir" in d:
-            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
-
-        simple_optionals = [
-            "query",
-            "sort_by",
-            "sort_ascending",
-            "max_items_per_client",
-        ]
-        for k in simple_optionals:
-            if k in d:
-                kwargs[k] = d[k]
-
-        return Naip(**kwargs)

rslearn/data_sources/usda_cdl.py CHANGED

@@ -11,9 +11,9 @@ import requests.auth
 import shapely
 from upath import UPath
 
-from rslearn.config import QueryConfig
+from rslearn.config import QueryConfig
 from rslearn.const import WGS84_PROJECTION
-from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources import DataSource, DataSourceContext, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.log_utils import get_logger
 from rslearn.tile_stores import TileStoreWithLayer

@@ -60,39 +60,28 @@ class CDL(DataSource):
 
     def __init__(
         self,
-        band_name: str = "cdl",
         timeout: timedelta = timedelta(seconds=10),
+        context: DataSourceContext = DataSourceContext(),
     ):
         """Initialize a new CDL instance.
 
         Args:
-            band_name: what to call the band.
             timeout: timeout for requests.
+            context: the data source context.
         """
-        self.band_name = band_name
         self.timeout = timeout
 
-
-
-
-        if
-
-
-
-
-
-
-
-        if len(config.band_sets[0].bands) != 1:
-            raise ValueError("expected band set to have a single band")
-        kwargs: dict[str, Any] = {
-            "band_name": config.band_sets[0].bands[0],
-        }
-
-        if "timeout_seconds" in d:
-            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
-
-        return CDL(**kwargs)
+        # Get the band name from the layer config, which should have a single band set
+        # with a single band. If the layer config is not available in the context, we
+        # default to "cdl".
+        if context.layer_config is not None:
+            if len(context.layer_config.band_sets) != 1:
+                raise ValueError("expected a single band set")
+            if len(context.layer_config.band_sets[0].bands) != 1:
+                raise ValueError("expected band set to have a single band")
+            self.band_name = context.layer_config.band_sets[0].bands[0]
+        else:
+            self.band_name = "cdl"
 
     def get_item_by_name(self, name: str) -> Item:
         """Gets an item by name.
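The band-name resolution that `CDL.__init__` now performs is easy to mirror in a standalone helper, as sketched below; the band-set contents are hypothetical.

```python
# Standalone mirror of the resolution above: require exactly one band set with
# exactly one band, and fall back to "cdl" when no layer config is available.
def resolve_band_name(band_sets: list[list[str]] | None) -> str:
    if band_sets is None:
        return "cdl"
    if len(band_sets) != 1:
        raise ValueError("expected a single band set")
    if len(band_sets[0]) != 1:
        raise ValueError("expected band set to have a single band")
    return band_sets[0][0]

print(resolve_band_name([["cropland"]]))  # cropland
print(resolve_band_name(None))            # cdl
```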
rslearn/data_sources/usgs_landsat.py CHANGED

@@ -18,9 +18,9 @@ import requests
 import shapely
 from upath import UPath
 
-from rslearn.config import QueryConfig
+from rslearn.config import QueryConfig
 from rslearn.const import WGS84_PROJECTION
-from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources import DataSource, DataSourceContext, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.tile_stores import TileStoreWithLayer
 from rslearn.utils import STGeometry

@@ -314,25 +314,24 @@ class LandsatOliTirs(DataSource):
 
     def __init__(
         self,
-        config: RasterLayerConfig,
         username: str,
         sort_by: str | None = None,
         password: str | None = None,
         token: str | None = None,
         timeout: timedelta = timedelta(seconds=10),
+        context: DataSourceContext = DataSourceContext(),
     ):
         """Initialize a new LandsatOliTirs instance.
 
         Args:
-            config: the LayerConfig of the layer containing this data source
             username: EROS username
             sort_by: can be "cloud_cover", default arbitrary order; only has effect for
                 SpaceMode.WITHIN.
             password: EROS password (see M2MAPIClient).
             token: EROS application token (see M2MAPIClient).
             timeout: timeout for requests.
+            context: the data source context.
         """
-        self.config = config
         self.sort_by = sort_by
         self.timeout = timeout
 

@@ -340,30 +339,6 @@ class LandsatOliTirs(DataSource):
             username, password=password, token=token, timeout=timeout
         )
 
-    @staticmethod
-    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "LandsatOliTirs":
-        """Creates a new LandsatOliTirs instance from a configuration dictionary."""
-        if config.data_source is None:
-            raise ValueError("data_source is required")
-        d = config.data_source.config_dict
-
-        kwargs = dict(
-            config=config,
-            username=d["username"],
-            sort_by=d.get("sort_by"),
-        )
-
-        if "timeout_seconds" in d:
-            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
-
-        # Optional config.
-        for k in ["password", "token"]:
-            if k not in d:
-                continue
-            kwargs[k] = d[k]
-
-        return LandsatOliTirs(**kwargs)
-
     def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
         """Convert scene metadata from the API to a LandsatOliTirsItem."""
         metadata_dict = {}
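For illustration only: with `from_config()` gone, a `LandsatOliTirs` is constructed directly with credentials plus a context. The keyword names match the new signature in the diff; the credential values are placeholders.

```python
# Hypothetical usage sketch with placeholder credentials.
from datetime import timedelta

from rslearn.data_sources import DataSourceContext
from rslearn.data_sources.usgs_landsat import LandsatOliTirs

data_source = LandsatOliTirs(
    username="my_eros_username",   # placeholder EROS username
    token="my_eros_app_token",     # placeholder EROS application token
    sort_by="cloud_cover",
    timeout=timedelta(seconds=30),
    context=DataSourceContext(),
)
```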

rslearn/data_sources/utils.py CHANGED

@@ -8,8 +8,11 @@ import shapely
 
 from rslearn.config import QueryConfig, SpaceMode, TimeMode
 from rslearn.data_sources import Item
+from rslearn.log_utils import get_logger
 from rslearn.utils import STGeometry, shp_intersects
 
+logger = get_logger(__name__)
+
 MOSAIC_MIN_ITEM_COVERAGE = 0.1
 """Minimum fraction of area that item should cover when adding it to a mosaic group."""
 

@@ -298,6 +301,12 @@ def match_candidate_items_to_window(
 
     # Enforce minimum matches if set.
     if len(groups) < query_config.min_matches:
+        logger.warning(
+            "Window rejected: found %d matches (required: %d) for time range %s",
+            len(groups),
+            query_config.min_matches,
+            geometry.time_range if geometry.time_range else "unlimited",
+        )
         return []
 
     return groups
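The added guard makes window rejection visible instead of silent. Below is a self-contained sketch of the same pattern using the standard `logging` module, with stand-ins for `groups`, `query_config.min_matches`, and `geometry.time_range`.

```python
# Self-contained sketch of the added guard: log a warning, then return no groups.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("rslearn.data_sources.utils")

groups: list[list[str]] = [["item_a"]]  # one mosaic group found
min_matches = 2                         # stand-in for query_config.min_matches
time_range = None                       # stand-in for geometry.time_range

if len(groups) < min_matches:
    logger.warning(
        "Window rejected: found %d matches (required: %d) for time range %s",
        len(groups),
        min_matches,
        time_range if time_range else "unlimited",
    )
    groups = []
```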
rslearn/data_sources/worldcereal.py CHANGED

@@ -6,18 +6,17 @@ import os
 import shutil
 import tempfile
 import zipfile
-from typing import Any
 
 import requests
 from fsspec.implementations.local import LocalFileSystem
 from upath import UPath
 
-from rslearn.config import
-from rslearn.data_sources.local_files import LocalFiles
+from rslearn.config import LayerType
+from rslearn.data_sources.local_files import LocalFiles, RasterItemSpec
 from rslearn.log_utils import get_logger
 from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
 
-from .data_source import Item
+from .data_source import DataSourceContext, Item
 
 logger = get_logger(__name__)
 

@@ -237,71 +236,65 @@ class WorldCereal(LocalFiles):
 
     def __init__(
         self,
-
-        band: str,
-
+        worldcereal_dir: str,
+        band: str | None = None,
+        context: DataSourceContext = DataSourceContext(),
     ) -> None:
         """Create a new WorldCereal.
 
         Args:
-            config: configuration for this layer.
-            band: the worldcereal band being processed.
             worldcereal_dir: the directory to extract the WorldCereal GeoTIFF files. For
                 high performance, this should be a local directory; if the dataset is
                 remote, prefix with a protocol ("file://") to use a local directory
                 instead of a path relative to the dataset path.
+            band: the worldcereal band to process. This will only be used if the layer
+                config is missing from the context.
+            context: the data source context.
         """
-
-
+        if context.ds_path is not None:
+            worldcereal_upath = join_upath(context.ds_path, worldcereal_dir)
+        else:
+            worldcereal_upath = UPath(worldcereal_dir)
+
+        if context.layer_config is not None:
+            if len(context.layer_config.band_sets) != 1:
+                raise ValueError("expected a single band set")
+            if len(context.layer_config.band_sets[0].bands) != 1:
+                raise ValueError("expected band set to have a single band")
+            self.band = context.layer_config.band_sets[0].bands[0]
+        elif band is not None:
+            self.band = band
+        else:
+            raise ValueError("band must be set if layer config is not in the context")
+
+        tif_dir, tif_filepath = self.download_worldcereal_data(
+            self.band, worldcereal_upath
+        )
         all_aezs: set[int] = self.all_aezs_from_tifs(tif_filepath)
 
         # now that we have all our aezs, lets match them to the bands
-
+        item_specs: list[RasterItemSpec] = []
         for aez in all_aezs:
-
+            item_spec = RasterItemSpec(
+                fnames=[],
+                bands=[],
                 # must be a str since we / with a posix path later
-
-
-                "bands": [],
-            }
+                name=str(aez),
+            )
             aez_band_filepath = self.filepath_for_product_aez(tif_filepath, aez)
             if aez_band_filepath is not None:
-
-
-
-
-
-
-
-
-
-
-
-
-            else:
-                config.data_source = DataSourceConfig(
-                    name="rslearn.data_sources.WorldCereal",
-                    query_config=QueryConfig.from_config({}),
-                    config_dict={"item_specs": spec_dicts},
-                )
-
-        super().__init__(config, tif_dir)
-
-    @staticmethod
-    def from_config(config: LayerConfig, ds_path: UPath) -> "LocalFiles":
-        """Creates a new LocalFiles instance from a configuration dictionary."""
-        if config.data_source is None:
-            raise ValueError("LocalFiles data source requires a data source config")
-        d = config.data_source.config_dict
-        assert isinstance(config, RasterLayerConfig)
-        bandsets = config.band_sets
-        assert len(bandsets) == 1
-        assert len(bandsets[0].bands) == 1
-        band = bandsets[0].bands[0]
-        return WorldCereal(
-            config=config,
-            band=band,
-            worldcereal_dir=join_upath(ds_path, d["worldcereal_dir"]),
+                item_spec.fnames.append(aez_band_filepath.absolute().as_uri())
+                assert item_spec.bands is not None
+                item_spec.bands.append([self.band])
+            item_specs.append(item_spec)
+        if len(item_specs) == 0:
+            raise ValueError(f"No AEZ files found for {self.band}")
+
+        super().__init__(
+            src_dir=tif_dir,
+            raster_item_specs=item_specs,
+            layer_type=LayerType.RASTER,
+            context=context,
         )
 
     @staticmethod

@@ -441,7 +434,7 @@ class WorldCereal(LocalFiles):
         cache_fname = self.src_dir / f"{self.band}_summary.json"
         if not cache_fname.exists():
             logger.debug("cache at %s does not exist, listing items", cache_fname)
-            items = self.importer.list_items(self.
+            items = self.importer.list_items(self.src_dir)
             serialized_items = [item.serialize() for item in items]
             with cache_fname.open("w") as f:
                 json.dump(serialized_items, f)