rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +414 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +28 -26
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +370 -1
- rslearn/main.py +3 -3
- rslearn/models/clay/clay.py +14 -1
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/croma.py +26 -3
- rslearn/models/satlaspretrain.py +18 -4
- rslearn/models/terramind.py +19 -0
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
|
@@ -17,9 +17,9 @@ from earthdaily import EDSClient, EDSConfig
|
|
|
17
17
|
from rasterio.enums import Resampling
|
|
18
18
|
from upath import UPath
|
|
19
19
|
|
|
20
|
-
from rslearn.config import LayerConfig, QueryConfig
|
|
20
|
+
from rslearn.config import LayerConfig, QueryConfig
|
|
21
21
|
from rslearn.const import WGS84_PROJECTION
|
|
22
|
-
from rslearn.data_sources import DataSource, Item
|
|
22
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
23
23
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
24
24
|
from rslearn.dataset import Window
|
|
25
25
|
from rslearn.dataset.materialize import RasterMaterializer
|
|
@@ -81,10 +81,11 @@ class EarthDaily(DataSource, TileStore):
|
|
|
81
81
|
sort_ascending: bool = True,
|
|
82
82
|
timeout: timedelta = timedelta(seconds=10),
|
|
83
83
|
skip_items_missing_assets: bool = False,
|
|
84
|
-
cache_dir: UPath | None = None,
|
|
84
|
+
cache_dir: str | None = None,
|
|
85
85
|
max_retries: int = 3,
|
|
86
86
|
retry_backoff_factor: float = 5.0,
|
|
87
87
|
service_name: Literal["platform"] = "platform",
|
|
88
|
+
context: DataSourceContext = DataSourceContext(),
|
|
88
89
|
):
|
|
89
90
|
"""Initialize a new EarthDaily instance.
|
|
90
91
|
|
|
@@ -108,6 +109,7 @@ class EarthDaily(DataSource, TileStore):
|
|
|
108
109
|
`(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
|
|
109
110
|
service_name: the service name, only "platform" is supported, the other
|
|
110
111
|
services "legacy" and "internal" are not supported.
|
|
112
|
+
context: the data source context.
|
|
111
113
|
"""
|
|
112
114
|
self.collection_name = collection_name
|
|
113
115
|
self.asset_bands = asset_bands
|
|
@@ -116,51 +118,25 @@ class EarthDaily(DataSource, TileStore):
|
|
|
116
118
|
self.sort_ascending = sort_ascending
|
|
117
119
|
self.timeout = timeout
|
|
118
120
|
self.skip_items_missing_assets = skip_items_missing_assets
|
|
119
|
-
self.cache_dir = cache_dir
|
|
120
121
|
self.max_retries = max_retries
|
|
121
122
|
self.retry_backoff_factor = retry_backoff_factor
|
|
122
123
|
self.service_name = service_name
|
|
123
124
|
|
|
124
125
|
if cache_dir is not None:
|
|
125
|
-
|
|
126
|
+
# Use dataset path as root if provided.
|
|
127
|
+
if context.ds_path is not None:
|
|
128
|
+
self.cache_dir = join_upath(context.ds_path, cache_dir)
|
|
129
|
+
else:
|
|
130
|
+
self.cache_dir = UPath(cache_dir)
|
|
131
|
+
|
|
126
132
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
133
|
+
else:
|
|
134
|
+
self.cache_dir = None
|
|
127
135
|
|
|
128
136
|
self.eds_client: EDSClient | None = None
|
|
129
137
|
self.client: pystac_client.Client | None = None
|
|
130
138
|
self.collection: pystac_client.CollectionClient | None = None
|
|
131
139
|
|
|
132
|
-
@staticmethod
|
|
133
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "EarthDaily":
|
|
134
|
-
"""Creates a new EarthDaily instance from a configuration dictionary."""
|
|
135
|
-
if config.data_source is None:
|
|
136
|
-
raise ValueError("config.data_source is required")
|
|
137
|
-
d = config.data_source.config_dict
|
|
138
|
-
|
|
139
|
-
kwargs: dict[str, Any] = dict(
|
|
140
|
-
collection_name=d["collection_name"],
|
|
141
|
-
service_name=d["service_name"],
|
|
142
|
-
asset_bands=d["asset_bands"],
|
|
143
|
-
)
|
|
144
|
-
|
|
145
|
-
if "timeout_seconds" in d:
|
|
146
|
-
kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
|
|
147
|
-
|
|
148
|
-
if "cache_dir" in d:
|
|
149
|
-
kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
|
|
150
|
-
|
|
151
|
-
if "max_retries" in d:
|
|
152
|
-
kwargs["max_retries"] = d["max_retries"]
|
|
153
|
-
|
|
154
|
-
if "retry_backoff_factor" in d:
|
|
155
|
-
kwargs["retry_backoff_factor"] = d["retry_backoff_factor"]
|
|
156
|
-
|
|
157
|
-
simple_optionals = ["query", "sort_by", "sort_ascending"]
|
|
158
|
-
for k in simple_optionals:
|
|
159
|
-
if k in d:
|
|
160
|
-
kwargs[k] = d[k]
|
|
161
|
-
|
|
162
|
-
return EarthDaily(**kwargs)
|
|
163
|
-
|
|
164
140
|
def _load_client(
|
|
165
141
|
self,
|
|
166
142
|
) -> tuple[EDSClient, pystac_client.Client, pystac_client.CollectionClient]:
|
|
@@ -499,7 +475,6 @@ class EarthDaily(DataSource, TileStore):
|
|
|
499
475
|
layer_name: the name of this layer
|
|
500
476
|
layer_cfg: the config of this layer
|
|
501
477
|
"""
|
|
502
|
-
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
503
478
|
RasterMaterializer().materialize(
|
|
504
479
|
TileStoreWithLayer(self, layer_name),
|
|
505
480
|
window,
|
|
@@ -12,9 +12,9 @@ import requests.auth
|
|
|
12
12
|
import shapely
|
|
13
13
|
from upath import UPath
|
|
14
14
|
|
|
15
|
-
from rslearn.config import QueryConfig,
|
|
15
|
+
from rslearn.config import QueryConfig, SpaceMode
|
|
16
16
|
from rslearn.const import WGS84_PROJECTION
|
|
17
|
-
from rslearn.data_sources import DataSource, Item
|
|
17
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
18
18
|
from rslearn.log_utils import get_logger
|
|
19
19
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
20
20
|
from rslearn.utils.geometry import STGeometry
|
|
@@ -47,8 +47,8 @@ class SRTM(DataSource):
|
|
|
47
47
|
self,
|
|
48
48
|
username: str | None = None,
|
|
49
49
|
password: str | None = None,
|
|
50
|
-
band_name: str = "srtm",
|
|
51
50
|
timeout: timedelta = timedelta(seconds=10),
|
|
51
|
+
context: DataSourceContext = DataSourceContext(),
|
|
52
52
|
):
|
|
53
53
|
"""Initialize a new SRTM instance.
|
|
54
54
|
|
|
@@ -57,10 +57,19 @@ class SRTM(DataSource):
|
|
|
57
57
|
NASA_EARTHDATA_USERNAME environment variable.
|
|
58
58
|
password: NASA Earthdata account password. If not set, it is read from the
|
|
59
59
|
NASA_EARTHDATA_PASSWORD environment variable.
|
|
60
|
-
band_name: what to call the band.
|
|
61
60
|
timeout: timeout for requests.
|
|
61
|
+
context: the data source context.
|
|
62
62
|
"""
|
|
63
|
-
|
|
63
|
+
# Get band name from context if possible, falling back to "srtm".
|
|
64
|
+
if context.layer_config is not None:
|
|
65
|
+
if len(context.layer_config.band_sets) != 1:
|
|
66
|
+
raise ValueError("expected a single band set")
|
|
67
|
+
if len(context.layer_config.band_sets[0].bands) != 1:
|
|
68
|
+
raise ValueError("expected band set to have a single band")
|
|
69
|
+
self.band_name = context.layer_config.band_sets[0].bands[0]
|
|
70
|
+
else:
|
|
71
|
+
self.band_name = "srtm"
|
|
72
|
+
|
|
64
73
|
self.timeout = timeout
|
|
65
74
|
|
|
66
75
|
if username is None:
|
|
@@ -73,33 +82,6 @@ class SRTM(DataSource):
|
|
|
73
82
|
|
|
74
83
|
self.session = requests.session()
|
|
75
84
|
|
|
76
|
-
@staticmethod
|
|
77
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "SRTM":
|
|
78
|
-
"""Creates a new SRTM instance from a configuration dictionary."""
|
|
79
|
-
if config.data_source is None:
|
|
80
|
-
raise ValueError("config.data_source is required")
|
|
81
|
-
d = config.data_source.config_dict
|
|
82
|
-
|
|
83
|
-
# Get the band name chosen by the user.
|
|
84
|
-
# There should be a single band set with a single band.
|
|
85
|
-
if len(config.band_sets) != 1:
|
|
86
|
-
raise ValueError("expected a single band set")
|
|
87
|
-
if len(config.band_sets[0].bands) != 1:
|
|
88
|
-
raise ValueError("expected band set to have a single band")
|
|
89
|
-
kwargs: dict[str, Any] = {
|
|
90
|
-
"band_name": config.band_sets[0].bands[0],
|
|
91
|
-
}
|
|
92
|
-
|
|
93
|
-
if "timeout_seconds" in d:
|
|
94
|
-
kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
|
|
95
|
-
|
|
96
|
-
simple_optionals = ["username", "password"]
|
|
97
|
-
for k in simple_optionals:
|
|
98
|
-
if k in d:
|
|
99
|
-
kwargs[k] = d[k]
|
|
100
|
-
|
|
101
|
-
return SRTM(**kwargs)
|
|
102
|
-
|
|
103
85
|
def get_item_by_name(self, name: str) -> Item:
|
|
104
86
|
"""Gets an item by name.
|
|
105
87
|
|
|
@@ -10,11 +10,10 @@ from typing import Any
|
|
|
10
10
|
import fiona
|
|
11
11
|
import requests
|
|
12
12
|
from rasterio.crs import CRS
|
|
13
|
-
from upath import UPath
|
|
14
13
|
|
|
15
|
-
from rslearn.config import QueryConfig
|
|
14
|
+
from rslearn.config import QueryConfig
|
|
16
15
|
from rslearn.const import WGS84_PROJECTION
|
|
17
|
-
from rslearn.data_sources import DataSource, Item
|
|
16
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
18
17
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
19
18
|
from rslearn.log_utils import get_logger
|
|
20
19
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
@@ -108,12 +107,9 @@ class EuroCrops(DataSource[EuroCropsItem]):
|
|
|
108
107
|
}
|
|
109
108
|
TIMEOUT = timedelta(seconds=10)
|
|
110
109
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
if config.data_source is None:
|
|
115
|
-
raise ValueError("data_source is required")
|
|
116
|
-
return EuroCrops()
|
|
110
|
+
def __init__(self, context: DataSourceContext = DataSourceContext()):
|
|
111
|
+
"""Create a new EuroCrops."""
|
|
112
|
+
pass
|
|
117
113
|
|
|
118
114
|
def _get_all_items(self) -> list[EuroCropsItem]:
|
|
119
115
|
"""Get a list of all available items in the data source."""
|
|
@@ -18,10 +18,9 @@ import tqdm
|
|
|
18
18
|
from google.cloud import bigquery, storage
|
|
19
19
|
from upath import UPath
|
|
20
20
|
|
|
21
|
-
from rslearn.config import QueryConfig
|
|
21
|
+
from rslearn.config import QueryConfig
|
|
22
22
|
from rslearn.const import WGS84_PROJECTION
|
|
23
|
-
from rslearn.data_sources import DataSource, Item
|
|
24
|
-
from rslearn.data_sources.raster_source import is_raster_needed
|
|
23
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
25
24
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
26
25
|
from rslearn.log_utils import get_logger
|
|
27
26
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
@@ -157,19 +156,19 @@ class Sentinel2(DataSource):
|
|
|
157
156
|
|
|
158
157
|
def __init__(
|
|
159
158
|
self,
|
|
160
|
-
|
|
161
|
-
index_cache_dir: UPath,
|
|
159
|
+
index_cache_dir: str,
|
|
162
160
|
sort_by: str | None = None,
|
|
163
161
|
use_rtree_index: bool = True,
|
|
164
162
|
harmonize: bool = False,
|
|
165
163
|
rtree_time_range: tuple[datetime, datetime] | None = None,
|
|
166
|
-
rtree_cache_dir: UPath | None = None,
|
|
164
|
+
rtree_cache_dir: str | None = None,
|
|
167
165
|
use_bigquery: bool | None = None,
|
|
166
|
+
bands: list[str] | None = None,
|
|
167
|
+
context: DataSourceContext = DataSourceContext(),
|
|
168
168
|
):
|
|
169
169
|
"""Initialize a new Sentinel2 instance.
|
|
170
170
|
|
|
171
171
|
Args:
|
|
172
|
-
config: the LayerConfig of the layer containing this data source.
|
|
173
172
|
index_cache_dir: local directory to cache the index contents, as well as
|
|
174
173
|
individual product metadata files.
|
|
175
174
|
sort_by: can be "cloud_cover", default arbitrary order; only has effect for
|
|
@@ -193,6 +192,9 @@ class Sentinel2(DataSource):
|
|
|
193
192
|
credentials, set use_bigquery=False and use_rtree_index=False. The
|
|
194
193
|
default value is None which enables BigQuery when use_rtree_index=True
|
|
195
194
|
and disables when use_rtree_index=False.
|
|
195
|
+
bands: the bands to download, or None to download all bands. This is only
|
|
196
|
+
used if the layer config is not in the context.
|
|
197
|
+
context: the data source context.
|
|
196
198
|
"""
|
|
197
199
|
if use_bigquery is None:
|
|
198
200
|
use_bigquery = use_rtree_index
|
|
@@ -201,22 +203,52 @@ class Sentinel2(DataSource):
|
|
|
201
203
|
"use_bigquery must be enabled if use_rtree_index is enabled"
|
|
202
204
|
)
|
|
203
205
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
+
# Resolve index_cache_dir and rtree_cache_dir depending on dataset context.
|
|
207
|
+
if context.ds_path is not None:
|
|
208
|
+
self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
|
|
209
|
+
else:
|
|
210
|
+
self.index_cache_dir = UPath(index_cache_dir)
|
|
211
|
+
|
|
212
|
+
if rtree_cache_dir is None:
|
|
213
|
+
self.rtree_cache_dir = self.index_cache_dir
|
|
214
|
+
elif context.ds_path is not None:
|
|
215
|
+
self.rtree_cache_dir = join_upath(context.ds_path, rtree_cache_dir)
|
|
216
|
+
else:
|
|
217
|
+
self.rtree_cache_dir = UPath(rtree_cache_dir)
|
|
218
|
+
|
|
206
219
|
self.sort_by = sort_by
|
|
207
220
|
self.harmonize = harmonize
|
|
208
221
|
self.use_bigquery = use_bigquery
|
|
209
222
|
|
|
210
223
|
self.index_cache_dir.mkdir(parents=True, exist_ok=True)
|
|
211
224
|
|
|
225
|
+
# Determine the subset of bands that are needed based on the layer config.
|
|
226
|
+
self.needed_bands: list[tuple[str, list[str]]]
|
|
227
|
+
if context.layer_config is not None:
|
|
228
|
+
self.needed_bands = []
|
|
229
|
+
for fname, cur_bands in self.BANDS:
|
|
230
|
+
# See if the bands provided by this file intersect with the bands in at
|
|
231
|
+
# least one configured band set.
|
|
232
|
+
for band_set in context.layer_config.band_sets:
|
|
233
|
+
if not set(band_set.bands).intersection(cur_bands):
|
|
234
|
+
continue
|
|
235
|
+
self.needed_bands.append((fname, cur_bands))
|
|
236
|
+
break
|
|
237
|
+
elif bands is not None:
|
|
238
|
+
self.needed_bands = []
|
|
239
|
+
for fname, cur_bands in self.BANDS:
|
|
240
|
+
if not set(bands).intersection(cur_bands):
|
|
241
|
+
continue
|
|
242
|
+
self.needed_bands.append((fname, cur_bands))
|
|
243
|
+
else:
|
|
244
|
+
self.needed_bands = list(self.BANDS)
|
|
245
|
+
|
|
212
246
|
self.bucket = storage.Client.create_anonymous_client().bucket(self.BUCKET_NAME)
|
|
213
247
|
self.rtree_index: Any | None = None
|
|
214
248
|
if use_rtree_index:
|
|
215
249
|
from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree
|
|
216
250
|
|
|
217
|
-
|
|
218
|
-
rtree_cache_dir = self.index_cache_dir
|
|
219
|
-
rtree_cache_dir.mkdir(parents=True, exist_ok=True)
|
|
251
|
+
self.rtree_cache_dir.mkdir(parents=True, exist_ok=True)
|
|
220
252
|
|
|
221
253
|
def build_fn(index: RtreeIndex) -> None:
|
|
222
254
|
"""Build the RtreeIndex from items in the data source."""
|
|
@@ -226,34 +258,7 @@ class Sentinel2(DataSource):
|
|
|
226
258
|
for shp in flatten_shape(item.geometry.shp):
|
|
227
259
|
index.insert(shp.bounds, json.dumps(item.serialize()))
|
|
228
260
|
|
|
229
|
-
self.rtree_index = get_cached_rtree(rtree_cache_dir, build_fn)
|
|
230
|
-
|
|
231
|
-
@staticmethod
|
|
232
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Sentinel2":
|
|
233
|
-
"""Creates a new Sentinel2 instance from a configuration dictionary."""
|
|
234
|
-
if config.data_source is None:
|
|
235
|
-
raise ValueError("config.data_source is required")
|
|
236
|
-
d = config.data_source.config_dict
|
|
237
|
-
kwargs = dict(
|
|
238
|
-
config=config,
|
|
239
|
-
index_cache_dir=join_upath(ds_path, d["index_cache_dir"]),
|
|
240
|
-
)
|
|
241
|
-
|
|
242
|
-
if "rtree_time_range" in d:
|
|
243
|
-
kwargs["rtree_time_range"] = (
|
|
244
|
-
datetime.fromisoformat(d["rtree_time_range"][0]),
|
|
245
|
-
datetime.fromisoformat(d["rtree_time_range"][1]),
|
|
246
|
-
)
|
|
247
|
-
|
|
248
|
-
if "rtree_cache_dir" in d:
|
|
249
|
-
kwargs["rtree_cache_dir"] = join_upath(ds_path, d["rtree_cache_dir"])
|
|
250
|
-
|
|
251
|
-
simple_optionals = ["sort_by", "use_rtree_index", "harmonize", "use_bigquery"]
|
|
252
|
-
for k in simple_optionals:
|
|
253
|
-
if k in d:
|
|
254
|
-
kwargs[k] = d[k]
|
|
255
|
-
|
|
256
|
-
return Sentinel2(**kwargs)
|
|
261
|
+
self.rtree_index = get_cached_rtree(self.rtree_cache_dir, build_fn)
|
|
257
262
|
|
|
258
263
|
def _read_bigquery(
|
|
259
264
|
self,
|
|
@@ -833,9 +838,7 @@ class Sentinel2(DataSource):
|
|
|
833
838
|
geometries: a list of geometries needed for each item
|
|
834
839
|
"""
|
|
835
840
|
for item in items:
|
|
836
|
-
for suffix, band_names in self.BANDS:
|
|
837
|
-
if not is_raster_needed(band_names, self.config.band_sets):
|
|
838
|
-
continue
|
|
841
|
+
for suffix, band_names in self.needed_bands:
|
|
839
842
|
if tile_store.is_raster_ready(item.name, band_names):
|
|
840
843
|
continue
|
|
841
844
|
|
|
@@ -20,7 +20,7 @@ from google.cloud import storage
|
|
|
20
20
|
from upath import UPath
|
|
21
21
|
|
|
22
22
|
import rslearn.data_sources.utils
|
|
23
|
-
from rslearn.config import DType, LayerConfig
|
|
23
|
+
from rslearn.config import DType, LayerConfig
|
|
24
24
|
from rslearn.const import WGS84_PROJECTION
|
|
25
25
|
from rslearn.dataset.materialize import RasterMaterializer
|
|
26
26
|
from rslearn.dataset.window import Window
|
|
@@ -36,7 +36,7 @@ from rslearn.utils.raster_format import (
|
|
|
36
36
|
)
|
|
37
37
|
from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree
|
|
38
38
|
|
|
39
|
-
from .data_source import DataSource, Item, QueryConfig
|
|
39
|
+
from .data_source import DataSource, DataSourceContext, Item, QueryConfig
|
|
40
40
|
|
|
41
41
|
logger = get_logger(__name__)
|
|
42
42
|
|
|
@@ -61,34 +61,55 @@ class GEE(DataSource, TileStore):
|
|
|
61
61
|
self,
|
|
62
62
|
collection_name: str,
|
|
63
63
|
gcs_bucket_name: str,
|
|
64
|
-
|
|
65
|
-
index_cache_dir: UPath,
|
|
64
|
+
index_cache_dir: str,
|
|
66
65
|
service_account_name: str,
|
|
67
66
|
service_account_credentials: str,
|
|
67
|
+
bands: list[str] | None = None,
|
|
68
68
|
filters: list[tuple[str, Any]] | None = None,
|
|
69
69
|
dtype: DType | None = None,
|
|
70
|
+
context: DataSourceContext = DataSourceContext(),
|
|
70
71
|
) -> None:
|
|
71
72
|
"""Initialize a new GEE instance.
|
|
72
73
|
|
|
73
74
|
Args:
|
|
74
75
|
collection_name: the Earth Engine ImageCollection to ingest images from
|
|
75
76
|
gcs_bucket_name: the Cloud Storage bucket to export GEE images to
|
|
76
|
-
bands: the list of bands to ingest
|
|
77
77
|
index_cache_dir: cache directory to store rtree index
|
|
78
78
|
service_account_name: name of the service account to use for authentication
|
|
79
79
|
service_account_credentials: service account credentials filename
|
|
80
|
+
bands: the list of bands to ingest, in case the layer config is not present
|
|
81
|
+
in the context.
|
|
80
82
|
filters: optional list of tuples (property_name, property_value) to filter
|
|
81
83
|
images (using ee.Filter.eq)
|
|
82
84
|
dtype: optional desired array data type. If the data obtained from GEE does
|
|
83
85
|
not match this type, then it is converted.
|
|
86
|
+
context: the data source context.
|
|
84
87
|
"""
|
|
85
88
|
self.collection_name = collection_name
|
|
86
89
|
self.gcs_bucket_name = gcs_bucket_name
|
|
87
|
-
self.bands = bands
|
|
88
|
-
self.index_cache_dir = index_cache_dir
|
|
89
90
|
self.filters = filters
|
|
90
91
|
self.dtype = dtype
|
|
91
92
|
|
|
93
|
+
# Get index cache dir depending on dataset path.
|
|
94
|
+
if context.ds_path is not None:
|
|
95
|
+
self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
|
|
96
|
+
else:
|
|
97
|
+
self.index_cache_dir = UPath(index_cache_dir)
|
|
98
|
+
|
|
99
|
+
# Get bands we need to export.
|
|
100
|
+
if context.layer_config is not None:
|
|
101
|
+
self.bands = [
|
|
102
|
+
band
|
|
103
|
+
for band_set in context.layer_config.band_sets
|
|
104
|
+
for band in band_set.bands
|
|
105
|
+
]
|
|
106
|
+
elif bands is not None:
|
|
107
|
+
self.bands = bands
|
|
108
|
+
else:
|
|
109
|
+
raise ValueError(
|
|
110
|
+
"bands must be specified if layer_config is not present in the context"
|
|
111
|
+
)
|
|
112
|
+
|
|
92
113
|
self.bucket = storage.Client().bucket(self.gcs_bucket_name)
|
|
93
114
|
|
|
94
115
|
credentials = ee.ServiceAccountCredentials(
|
|
@@ -99,27 +120,6 @@ class GEE(DataSource, TileStore):
|
|
|
99
120
|
self.index_cache_dir.mkdir(parents=True, exist_ok=True)
|
|
100
121
|
self.rtree_index = get_cached_rtree(self.index_cache_dir, self._build_index)
|
|
101
122
|
|
|
102
|
-
@staticmethod
|
|
103
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "GEE":
|
|
104
|
-
"""Creates a new GEE instance from a configuration dictionary."""
|
|
105
|
-
if config.data_source is None:
|
|
106
|
-
raise ValueError("data_source is required in config")
|
|
107
|
-
d = config.data_source.config_dict
|
|
108
|
-
bands = [band for band_set in config.band_sets for band in band_set.bands]
|
|
109
|
-
kwargs = {
|
|
110
|
-
"collection_name": d["collection_name"],
|
|
111
|
-
"gcs_bucket_name": d["gcs_bucket_name"],
|
|
112
|
-
"bands": bands,
|
|
113
|
-
"service_account_name": d["service_account_name"],
|
|
114
|
-
"service_account_credentials": d["service_account_credentials"],
|
|
115
|
-
"filters": d.get("filters"),
|
|
116
|
-
"index_cache_dir": join_upath(ds_path, d["index_cache_dir"]),
|
|
117
|
-
}
|
|
118
|
-
if "dtype" in d:
|
|
119
|
-
kwargs["dtype"] = DType(d["dtype"])
|
|
120
|
-
|
|
121
|
-
return GEE(**kwargs)
|
|
122
|
-
|
|
123
123
|
def get_collection(self) -> ee.ImageCollection:
|
|
124
124
|
"""Returns the Earth Engine image collection for this data source."""
|
|
125
125
|
image_collection = ee.ImageCollection(self.collection_name)
|
|
@@ -578,7 +578,6 @@ class GEE(DataSource, TileStore):
|
|
|
578
578
|
layer_name: the name of this layer
|
|
579
579
|
layer_cfg: the config of this layer
|
|
580
580
|
"""
|
|
581
|
-
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
582
581
|
RasterMaterializer().materialize(
|
|
583
582
|
TileStoreWithLayer(self, layer_name),
|
|
584
583
|
window,
|
|
@@ -600,9 +599,10 @@ class GoogleSatelliteEmbeddings(GEE):
|
|
|
600
599
|
def __init__(
|
|
601
600
|
self,
|
|
602
601
|
gcs_bucket_name: str,
|
|
603
|
-
index_cache_dir: UPath,
|
|
602
|
+
index_cache_dir: str,
|
|
604
603
|
service_account_name: str,
|
|
605
604
|
service_account_credentials: str,
|
|
605
|
+
context: DataSourceContext = DataSourceContext(),
|
|
606
606
|
):
|
|
607
607
|
"""Create a new GoogleSatelliteEmbeddings. See GEE for the arguments."""
|
|
608
608
|
super().__init__(
|
|
@@ -612,22 +612,9 @@ class GoogleSatelliteEmbeddings(GEE):
|
|
|
612
612
|
index_cache_dir=index_cache_dir,
|
|
613
613
|
service_account_name=service_account_name,
|
|
614
614
|
service_account_credentials=service_account_credentials,
|
|
615
|
+
context=context,
|
|
615
616
|
)
|
|
616
617
|
|
|
617
|
-
@staticmethod
|
|
618
|
-
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "GEE":
|
|
619
|
-
"""Creates a new GEE instance from a configuration dictionary."""
|
|
620
|
-
if config.data_source is None:
|
|
621
|
-
raise ValueError("data_source is required in config")
|
|
622
|
-
d = config.data_source.config_dict
|
|
623
|
-
kwargs = {
|
|
624
|
-
"gcs_bucket_name": d["gcs_bucket_name"],
|
|
625
|
-
"index_cache_dir": join_upath(ds_path, d["index_cache_dir"]),
|
|
626
|
-
"service_account_name": d["service_account_name"],
|
|
627
|
-
"service_account_credentials": d["service_account_credentials"],
|
|
628
|
-
}
|
|
629
|
-
return GoogleSatelliteEmbeddings(**kwargs)
|
|
630
|
-
|
|
631
618
|
# Override to add conversion to uint16.
|
|
632
619
|
def item_to_image(self, item: Item) -> ee.image.Image:
|
|
633
620
|
"""Get the Image corresponding to the Item."""
|