rslearn 0.0.21__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/aws_open_data.py +11 -15
- rslearn/data_sources/aws_sentinel2_element84.py +374 -0
- rslearn/data_sources/gcp_public_data.py +16 -0
- rslearn/data_sources/planetary_computer.py +28 -257
- rslearn/data_sources/soilgrids.py +331 -0
- rslearn/data_sources/stac.py +255 -0
- rslearn/models/attention_pooling.py +5 -2
- rslearn/train/lightning_module.py +3 -3
- rslearn/train/tasks/embedding.py +2 -2
- rslearn/train/tasks/task.py +4 -2
- rslearn/utils/geometry.py +2 -2
- rslearn/utils/stac.py +173 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.22.dist-info}/METADATA +4 -1
- {rslearn-0.0.21.dist-info → rslearn-0.0.22.dist-info}/RECORD +19 -15
- {rslearn-0.0.21.dist-info → rslearn-0.0.22.dist-info}/WHEEL +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.22.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.22.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.22.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.22.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
"""Data on Planetary Computer."""
|
|
2
2
|
|
|
3
|
-
import json
|
|
4
3
|
import os
|
|
5
4
|
import tempfile
|
|
6
5
|
import xml.etree.ElementTree as ET
|
|
@@ -10,18 +9,14 @@ from typing import Any
|
|
|
10
9
|
import affine
|
|
11
10
|
import numpy.typing as npt
|
|
12
11
|
import planetary_computer
|
|
13
|
-
import pystac
|
|
14
|
-
import pystac_client
|
|
15
12
|
import rasterio
|
|
16
13
|
import requests
|
|
17
|
-
import shapely
|
|
18
14
|
from rasterio.enums import Resampling
|
|
19
15
|
from upath import UPath
|
|
20
16
|
|
|
21
|
-
from rslearn.config import LayerConfig
|
|
22
|
-
from rslearn.
|
|
23
|
-
from rslearn.data_sources import
|
|
24
|
-
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
17
|
+
from rslearn.config import LayerConfig
|
|
18
|
+
from rslearn.data_sources import DataSourceContext
|
|
19
|
+
from rslearn.data_sources.stac import SourceItem, StacDataSource
|
|
25
20
|
from rslearn.dataset import Window
|
|
26
21
|
from rslearn.dataset.materialize import RasterMaterializer
|
|
27
22
|
from rslearn.log_utils import get_logger
|
|
@@ -35,38 +30,7 @@ from .copernicus import get_harmonize_callback
|
|
|
35
30
|
logger = get_logger(__name__)
|
|
36
31
|
|
|
37
32
|
|
|
38
|
-
class
|
|
39
|
-
"""An item in the PlanetaryComputer data source."""
|
|
40
|
-
|
|
41
|
-
def __init__(self, name: str, geometry: STGeometry, asset_urls: dict[str, str]):
|
|
42
|
-
"""Creates a new PlanetaryComputerItem.
|
|
43
|
-
|
|
44
|
-
Args:
|
|
45
|
-
name: unique name of the item
|
|
46
|
-
geometry: the spatial and temporal extent of the item
|
|
47
|
-
asset_urls: map from asset key to the unsigned asset URL.
|
|
48
|
-
"""
|
|
49
|
-
super().__init__(name, geometry)
|
|
50
|
-
self.asset_urls = asset_urls
|
|
51
|
-
|
|
52
|
-
def serialize(self) -> dict[str, Any]:
|
|
53
|
-
"""Serializes the item to a JSON-encodable dictionary."""
|
|
54
|
-
d = super().serialize()
|
|
55
|
-
d["asset_urls"] = self.asset_urls
|
|
56
|
-
return d
|
|
57
|
-
|
|
58
|
-
@staticmethod
|
|
59
|
-
def deserialize(d: dict[str, Any]) -> "PlanetaryComputerItem":
|
|
60
|
-
"""Deserializes an item from a JSON-decoded dictionary."""
|
|
61
|
-
item = super(PlanetaryComputerItem, PlanetaryComputerItem).deserialize(d)
|
|
62
|
-
return PlanetaryComputerItem(
|
|
63
|
-
name=item.name,
|
|
64
|
-
geometry=item.geometry,
|
|
65
|
-
asset_urls=d["asset_urls"],
|
|
66
|
-
)
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
class PlanetaryComputer(DataSource, TileStore):
|
|
33
|
+
class PlanetaryComputer(StacDataSource, TileStore):
|
|
70
34
|
"""Modality-agnostic data source for data on Microsoft Planetary Computer.
|
|
71
35
|
|
|
72
36
|
If there is a subclass available for a modality, it is recommended to use the
|
|
@@ -83,10 +47,6 @@ class PlanetaryComputer(DataSource, TileStore):
|
|
|
83
47
|
|
|
84
48
|
STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"
|
|
85
49
|
|
|
86
|
-
# Default threshold for recreating the STAC client to prevent memory leaks
|
|
87
|
-
# from the pystac Catalog's resolved objects cache growing unbounded
|
|
88
|
-
DEFAULT_MAX_ITEMS_PER_CLIENT = 1000
|
|
89
|
-
|
|
90
50
|
def __init__(
|
|
91
51
|
self,
|
|
92
52
|
collection_name: str,
|
|
@@ -97,7 +57,6 @@ class PlanetaryComputer(DataSource, TileStore):
|
|
|
97
57
|
timeout: timedelta = timedelta(seconds=10),
|
|
98
58
|
skip_items_missing_assets: bool = False,
|
|
99
59
|
cache_dir: str | None = None,
|
|
100
|
-
max_items_per_client: int | None = None,
|
|
101
60
|
context: DataSourceContext = DataSourceContext(),
|
|
102
61
|
):
|
|
103
62
|
"""Initialize a new PlanetaryComputer instance.
|
|
@@ -115,228 +74,40 @@ class PlanetaryComputer(DataSource, TileStore):
|
|
|
115
74
|
cache_dir: optional directory to cache items by name, including asset URLs.
|
|
116
75
|
If not set, there will be no cache and instead STAC requests will be
|
|
117
76
|
needed each time.
|
|
118
|
-
max_items_per_client: number of STAC items to process before recreating
|
|
119
|
-
the client to prevent memory leaks from the resolved objects cache.
|
|
120
|
-
Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
|
|
121
77
|
context: the data source context.
|
|
122
78
|
"""
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
self.query = query
|
|
126
|
-
self.sort_by = sort_by
|
|
127
|
-
self.sort_ascending = sort_ascending
|
|
128
|
-
self.timeout = timeout
|
|
129
|
-
self.skip_items_missing_assets = skip_items_missing_assets
|
|
130
|
-
self.max_items_per_client = (
|
|
131
|
-
max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
|
|
132
|
-
)
|
|
133
|
-
|
|
79
|
+
# Determine the cache_dir to use.
|
|
80
|
+
cache_upath: UPath | None = None
|
|
134
81
|
if cache_dir is not None:
|
|
135
82
|
if context.ds_path is not None:
|
|
136
|
-
|
|
83
|
+
cache_upath = join_upath(context.ds_path, cache_dir)
|
|
137
84
|
else:
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
141
|
-
else:
|
|
142
|
-
self.cache_dir = None
|
|
85
|
+
cache_upath = UPath(cache_dir)
|
|
143
86
|
|
|
144
|
-
|
|
145
|
-
self._client_item_count = 0
|
|
87
|
+
cache_upath.mkdir(parents=True, exist_ok=True)
|
|
146
88
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
We don't load it when creating the data source because it takes time and caller
|
|
153
|
-
may not be calling get_items. Additionally, loading it during the get_items
|
|
154
|
-
call enables leveraging the retry loop functionality in
|
|
155
|
-
prepare_dataset_windows.
|
|
89
|
+
# We pass required_assets to StacDataSource if skip_items_missing_assets is set.
|
|
90
|
+
required_assets: list[str] | None = None
|
|
91
|
+
if skip_items_missing_assets:
|
|
92
|
+
required_assets = list(asset_bands.keys())
|
|
156
93
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
if self._client_item_count < self.max_items_per_client:
|
|
167
|
-
return self.client
|
|
168
|
-
|
|
169
|
-
# Recreate client to clear the resolved objects cache
|
|
170
|
-
current_client = self.client
|
|
171
|
-
logger.debug(
|
|
172
|
-
"Recreating STAC client after processing %d items (threshold: %d)",
|
|
173
|
-
self._client_item_count,
|
|
174
|
-
self.max_items_per_client,
|
|
94
|
+
super().__init__(
|
|
95
|
+
endpoint=self.STAC_ENDPOINT,
|
|
96
|
+
collection_name=collection_name,
|
|
97
|
+
query=query,
|
|
98
|
+
sort_by=sort_by,
|
|
99
|
+
sort_ascending=sort_ascending,
|
|
100
|
+
required_assets=required_assets,
|
|
101
|
+
cache_dir=cache_upath,
|
|
175
102
|
)
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
client_root.clear_children()
|
|
180
|
-
self._client_item_count = 0
|
|
181
|
-
self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
|
|
182
|
-
return self.client
|
|
183
|
-
|
|
184
|
-
def _stac_item_to_item(self, stac_item: pystac.Item) -> PlanetaryComputerItem:
|
|
185
|
-
shp = shapely.geometry.shape(stac_item.geometry)
|
|
186
|
-
|
|
187
|
-
# Get time range.
|
|
188
|
-
metadata = stac_item.common_metadata
|
|
189
|
-
if metadata.start_datetime is not None and metadata.end_datetime is not None:
|
|
190
|
-
time_range = (
|
|
191
|
-
metadata.start_datetime,
|
|
192
|
-
metadata.end_datetime,
|
|
193
|
-
)
|
|
194
|
-
elif stac_item.datetime is not None:
|
|
195
|
-
time_range = (stac_item.datetime, stac_item.datetime)
|
|
196
|
-
else:
|
|
197
|
-
raise ValueError(
|
|
198
|
-
f"item {stac_item.id} unexpectedly missing start_datetime, end_datetime, and datetime"
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
geom = STGeometry(WGS84_PROJECTION, shp, time_range)
|
|
202
|
-
asset_urls = {
|
|
203
|
-
asset_key: asset_obj.href
|
|
204
|
-
for asset_key, asset_obj in stac_item.assets.items()
|
|
205
|
-
}
|
|
206
|
-
return PlanetaryComputerItem(stac_item.id, geom, asset_urls)
|
|
207
|
-
|
|
208
|
-
def get_item_by_name(self, name: str) -> PlanetaryComputerItem:
|
|
209
|
-
"""Gets an item by name.
|
|
210
|
-
|
|
211
|
-
Args:
|
|
212
|
-
name: the name of the item to get
|
|
213
|
-
|
|
214
|
-
Returns:
|
|
215
|
-
the item object
|
|
216
|
-
"""
|
|
217
|
-
# If cache_dir is set, we cache the item. First here we check if it is already
|
|
218
|
-
# in the cache.
|
|
219
|
-
cache_fname: UPath | None = None
|
|
220
|
-
if self.cache_dir:
|
|
221
|
-
cache_fname = self.cache_dir / f"{name}.json"
|
|
222
|
-
if cache_fname is not None and cache_fname.exists():
|
|
223
|
-
with cache_fname.open() as f:
|
|
224
|
-
return PlanetaryComputerItem.deserialize(json.load(f))
|
|
225
|
-
|
|
226
|
-
# No cache or not in cache, so we need to make the STAC request.
|
|
227
|
-
logger.debug("Getting STAC item {name}")
|
|
228
|
-
client = self._load_client()
|
|
229
|
-
|
|
230
|
-
search_result = client.search(ids=[name], collections=[self.collection_name])
|
|
231
|
-
stac_items = list(search_result.items())
|
|
232
|
-
|
|
233
|
-
if not stac_items:
|
|
234
|
-
raise ValueError(
|
|
235
|
-
f"Item {name} not found in collection {self.collection_name}"
|
|
236
|
-
)
|
|
237
|
-
if len(stac_items) > 1:
|
|
238
|
-
raise ValueError(
|
|
239
|
-
f"Multiple items found for ID {name} in collection {self.collection_name}"
|
|
240
|
-
)
|
|
241
|
-
|
|
242
|
-
stac_item = stac_items[0]
|
|
243
|
-
item = self._stac_item_to_item(stac_item)
|
|
244
|
-
|
|
245
|
-
# Track items processed for client recreation threshold (after deserialization)
|
|
246
|
-
self._client_item_count += 1
|
|
247
|
-
|
|
248
|
-
# Finally we cache it if cache_dir is set.
|
|
249
|
-
if cache_fname is not None:
|
|
250
|
-
with cache_fname.open("w") as f:
|
|
251
|
-
json.dump(item.serialize(), f)
|
|
252
|
-
|
|
253
|
-
return item
|
|
254
|
-
|
|
255
|
-
def get_items(
|
|
256
|
-
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
257
|
-
) -> list[list[list[PlanetaryComputerItem]]]:
|
|
258
|
-
"""Get a list of items in the data source intersecting the given geometries.
|
|
259
|
-
|
|
260
|
-
Args:
|
|
261
|
-
geometries: the spatiotemporal geometries
|
|
262
|
-
query_config: the query configuration
|
|
263
|
-
|
|
264
|
-
Returns:
|
|
265
|
-
List of groups of items that should be retrieved for each geometry.
|
|
266
|
-
"""
|
|
267
|
-
client = self._load_client()
|
|
268
|
-
|
|
269
|
-
groups = []
|
|
270
|
-
for geometry in geometries:
|
|
271
|
-
# Get potentially relevant items from the collection by performing one search
|
|
272
|
-
# for each requested geometry.
|
|
273
|
-
wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
|
|
274
|
-
logger.debug("performing STAC search for geometry %s", wgs84_geometry)
|
|
275
|
-
result = client.search(
|
|
276
|
-
collections=[self.collection_name],
|
|
277
|
-
intersects=shapely.to_geojson(wgs84_geometry.shp),
|
|
278
|
-
datetime=wgs84_geometry.time_range,
|
|
279
|
-
query=self.query,
|
|
280
|
-
)
|
|
281
|
-
stac_items = [item for item in result.items()]
|
|
282
|
-
# Track items processed for client recreation threshold (after deserialization)
|
|
283
|
-
self._client_item_count += len(stac_items)
|
|
284
|
-
logger.debug("STAC search yielded %d items", len(stac_items))
|
|
285
|
-
|
|
286
|
-
if self.skip_items_missing_assets:
|
|
287
|
-
# Filter out items that are missing any of the assets in self.asset_bands.
|
|
288
|
-
good_stac_items = []
|
|
289
|
-
for stac_item in stac_items:
|
|
290
|
-
good = True
|
|
291
|
-
for asset_key in self.asset_bands.keys():
|
|
292
|
-
if asset_key in stac_item.assets:
|
|
293
|
-
continue
|
|
294
|
-
good = False
|
|
295
|
-
break
|
|
296
|
-
if good:
|
|
297
|
-
good_stac_items.append(stac_item)
|
|
298
|
-
logger.debug(
|
|
299
|
-
"skip_items_missing_assets filter from %d to %d items",
|
|
300
|
-
len(stac_items),
|
|
301
|
-
len(good_stac_items),
|
|
302
|
-
)
|
|
303
|
-
stac_items = good_stac_items
|
|
304
|
-
|
|
305
|
-
if self.sort_by is not None:
|
|
306
|
-
stac_items.sort(
|
|
307
|
-
key=lambda stac_item: stac_item.properties[self.sort_by],
|
|
308
|
-
reverse=not self.sort_ascending,
|
|
309
|
-
)
|
|
310
|
-
|
|
311
|
-
candidate_items = [
|
|
312
|
-
self._stac_item_to_item(stac_item) for stac_item in stac_items
|
|
313
|
-
]
|
|
314
|
-
|
|
315
|
-
# Since we made the STAC request, might as well save these to the cache.
|
|
316
|
-
if self.cache_dir is not None:
|
|
317
|
-
for item in candidate_items:
|
|
318
|
-
cache_fname = self.cache_dir / f"{item.name}.json"
|
|
319
|
-
if cache_fname.exists():
|
|
320
|
-
continue
|
|
321
|
-
with cache_fname.open("w") as f:
|
|
322
|
-
json.dump(item.serialize(), f)
|
|
323
|
-
|
|
324
|
-
cur_groups = match_candidate_items_to_window(
|
|
325
|
-
geometry, candidate_items, query_config
|
|
326
|
-
)
|
|
327
|
-
groups.append(cur_groups)
|
|
328
|
-
|
|
329
|
-
return groups
|
|
330
|
-
|
|
331
|
-
def deserialize_item(self, serialized_item: Any) -> PlanetaryComputerItem:
|
|
332
|
-
"""Deserializes an item from JSON-decoded data."""
|
|
333
|
-
assert isinstance(serialized_item, dict)
|
|
334
|
-
return PlanetaryComputerItem.deserialize(serialized_item)
|
|
103
|
+
self.asset_bands = asset_bands
|
|
104
|
+
self.timeout = timeout
|
|
105
|
+
self.skip_items_missing_assets = skip_items_missing_assets
|
|
335
106
|
|
|
336
107
|
def ingest(
|
|
337
108
|
self,
|
|
338
109
|
tile_store: TileStoreWithLayer,
|
|
339
|
-
items: list[
|
|
110
|
+
items: list[SourceItem],
|
|
340
111
|
geometries: list[list[STGeometry]],
|
|
341
112
|
) -> None:
|
|
342
113
|
"""Ingest items into the given tile store.
|
|
@@ -512,7 +283,7 @@ class PlanetaryComputer(DataSource, TileStore):
|
|
|
512
283
|
def materialize(
|
|
513
284
|
self,
|
|
514
285
|
window: Window,
|
|
515
|
-
item_groups: list[list[
|
|
286
|
+
item_groups: list[list[SourceItem]],
|
|
516
287
|
layer_name: str,
|
|
517
288
|
layer_cfg: LayerConfig,
|
|
518
289
|
) -> None:
|
|
@@ -601,7 +372,7 @@ class Sentinel2(PlanetaryComputer):
|
|
|
601
372
|
**kwargs,
|
|
602
373
|
)
|
|
603
374
|
|
|
604
|
-
def _get_product_xml(self, item:
|
|
375
|
+
def _get_product_xml(self, item: SourceItem) -> ET.Element:
|
|
605
376
|
asset_url = planetary_computer.sign(item.asset_urls["product-metadata"])
|
|
606
377
|
response = requests.get(asset_url, timeout=self.timeout.total_seconds())
|
|
607
378
|
response.raise_for_status()
|
|
@@ -610,7 +381,7 @@ class Sentinel2(PlanetaryComputer):
|
|
|
610
381
|
def ingest(
|
|
611
382
|
self,
|
|
612
383
|
tile_store: TileStoreWithLayer,
|
|
613
|
-
items: list[
|
|
384
|
+
items: list[SourceItem],
|
|
614
385
|
geometries: list[list[STGeometry]],
|
|
615
386
|
) -> None:
|
|
616
387
|
"""Ingest items into the given tile store.
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Data source for SoilGrids via the `soilgrids` Python package.
|
|
2
|
+
|
|
3
|
+
This source is intended to be used with `ingest: false` (direct materialization),
|
|
4
|
+
since data is fetched on-demand per window.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import tempfile
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
import rasterio
|
|
15
|
+
import rasterio.warp
|
|
16
|
+
import shapely
|
|
17
|
+
from rasterio.crs import CRS
|
|
18
|
+
from rasterio.enums import Resampling
|
|
19
|
+
from upath import UPath
|
|
20
|
+
|
|
21
|
+
from rslearn.config import LayerConfig, QueryConfig
|
|
22
|
+
from rslearn.dataset import Window
|
|
23
|
+
from rslearn.dataset.materialize import RasterMaterializer
|
|
24
|
+
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
25
|
+
from rslearn.utils import PixelBounds, Projection, STGeometry
|
|
26
|
+
from rslearn.utils.geometry import get_global_geometry
|
|
27
|
+
from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
|
|
28
|
+
|
|
29
|
+
from .data_source import DataSource, DataSourceContext, Item
|
|
30
|
+
from .utils import match_candidate_items_to_window
|
|
31
|
+
|
|
32
|
+
SOILGRIDS_NODATA_VALUE = -32768.0
|
|
33
|
+
"""Default nodata value used by SoilGrids GeoTIFF responses (GEOTIFF_INT16)."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _crs_to_rasterio(crs: str) -> CRS:
    """Parse a CRS string as used by `soilgrids` into a rasterio CRS (best effort).

    Args:
        crs: the CRS string to parse.

    Returns:
        the parsed rasterio CRS.

    Raises:
        Exception: whatever `CRS.from_string` raised, when the string also contains
            no integer token that could be interpreted as an EPSG code.
    """
    try:
        parsed = CRS.from_string(crs)
    except Exception:
        # rasterio could not parse the string directly. Treat ":" as a separator
        # and look for purely numeric tokens; the last one is taken as the EPSG
        # code.
        numeric_tokens = [
            token for token in crs.replace(":", " ").split() if token.isdigit()
        ]
        if not numeric_tokens:
            # Nothing to fall back on, so surface the original parsing error.
            raise
        return CRS.from_epsg(int(numeric_tokens[-1]))
    return parsed
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _crs_to_soilgrids_urn(crs: str) -> str:
|
|
50
|
+
"""Convert common CRS spellings to the URN form expected by `soilgrids`.
|
|
51
|
+
|
|
52
|
+
The `soilgrids` package compares CRS strings against the supported CRS URNs from
|
|
53
|
+
OWSLib (e.g. "urn:ogc:def:crs:EPSG::3857"). This helper allows users to specify
|
|
54
|
+
simpler forms like "EPSG:3857" while still working.
|
|
55
|
+
"""
|
|
56
|
+
s = crs.strip()
|
|
57
|
+
|
|
58
|
+
# If already an EPSG URN, canonicalize to the form soilgrids expects.
|
|
59
|
+
if s.lower().startswith("urn:ogc:def:crs:") and "epsg" in s.lower():
|
|
60
|
+
parts = [p for p in s.replace(":", " ").split() if p.isdigit()]
|
|
61
|
+
if parts:
|
|
62
|
+
return f"urn:ogc:def:crs:EPSG::{parts[-1]}"
|
|
63
|
+
return s
|
|
64
|
+
|
|
65
|
+
# Accept "EPSG:3857", "epsg:3857", or other strings containing an EPSG code.
|
|
66
|
+
if "epsg" in s.lower():
|
|
67
|
+
parts = [p for p in s.replace(":", " ").split() if p.isdigit()]
|
|
68
|
+
if parts:
|
|
69
|
+
return f"urn:ogc:def:crs:EPSG::{parts[-1]}"
|
|
70
|
+
|
|
71
|
+
return s
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class SoilGrids(DataSource, TileStore):
    """Access SoilGrids coverages as an rslearn raster data source.

    The configured coverage is represented as a single item that matches all
    windows. Rasters are downloaded on demand through the `soilgrids` package when
    `read_raster` is called, so this class also serves as the tile store during
    materialization. Ingestion is not supported.
    """

    def __init__(
        self,
        service_id: str,
        coverage_id: str,
        crs: str = "EPSG:3857",
        width: int | None = None,
        height: int | None = None,
        resx: float | None = None,
        resy: float | None = None,
        response_crs: str | None = None,
        band_names: list[str] = ["B1"],
        context: DataSourceContext = DataSourceContext(),
    ):
        """Create a new SoilGrids data source.

        Args:
            service_id: SoilGrids map service id (e.g., "clay", "phh2o").
            coverage_id: coverage id within the service (e.g., "clay_0-5cm_mean").
            crs: request CRS string. Simple forms like "EPSG:3857" are accepted and
                converted to the URN form used by `soilgrids` (such as
                "urn:ogc:def:crs:EPSG::4326" or "urn:ogc:def:crs:EPSG::152160")
                when the request is made.
            width: optional WCS WIDTH parameter. Required by SoilGrids WCS when CRS is
                EPSG:4326.
            height: optional WCS HEIGHT parameter.
            resx: optional WCS RESX parameter (projection units / pixel).
            resy: optional WCS RESY parameter (projection units / pixel).
            response_crs: optional response CRS (defaults to `crs`).
            band_names: band names exposed to rslearn. For a single coverage, this
                should have length 1.
            context: rslearn data source context.

        Raises:
            ValueError: if band_names does not have exactly one entry, if only one
                of width/height or only one of resx/resy is given, or if both
                width/height and resx/resy are given.
        """
        if len(band_names) != 1:
            raise ValueError("SoilGrids currently supports only single-band coverages")
        if (width is None) != (height is None):
            raise ValueError("width and height must be specified together")
        if (resx is None) != (resy is None):
            raise ValueError("resx and resy must be specified together")
        if width is not None and resx is not None:
            raise ValueError("specify either width/height or resx/resy, not both")

        self.service_id = service_id
        self.coverage_id = coverage_id
        self.crs = crs
        self.width = width
        self.height = height
        self.resx = resx
        self.resy = resy
        self.response_crs = response_crs
        # Copy the list so that the shared default argument (or a list owned by the
        # caller) is never aliased by this instance.
        self.band_names = list(band_names)

        # Represent the coverage as a single item that matches all windows.
        item_name = f"{self.service_id}:{self.coverage_id}"
        self._items = [Item(item_name, get_global_geometry(time_range=None))]

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get item groups matching each requested geometry.

        Args:
            geometries: the spatiotemporal geometries to match.
            query_config: the query configuration.

        Returns:
            one list of item groups per geometry, in the same order as geometries.
        """
        return [
            match_candidate_items_to_window(geometry, self._items, query_config)
            for geometry in geometries
        ]

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserialize an item from JSON-decoded data."""
        return Item.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest is not supported (direct materialization only).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "SoilGrids is intended for direct materialization; set data_source.ingest=false."
        )

    def is_raster_ready(
        self, layer_name: str, item_name: str, bands: list[str]
    ) -> bool:
        """Return whether the requested raster is ready (always true for direct reads)."""
        return True

    def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
        """Return the band sets available for this coverage.

        The returned list is a copy, so callers cannot modify this data source's
        band configuration through it.
        """
        return [list(self.band_names)]

    def get_raster_bounds(
        self, layer_name: str, item_name: str, bands: list[str], projection: Projection
    ) -> PixelBounds:
        """Return (approximate) bounds for this raster in the requested projection."""
        # We don't know bounds without an extra metadata request; treat as "very large"
        # so materialization always attempts reads for windows.
        return (-(10**9), -(10**9), 10**9, 10**9)

    def _download_geotiff(
        self,
        west: float,
        south: float,
        east: float,
        north: float,
        output: str,
        width: int | None,
        height: int | None,
        resx: float | None,
        resy: float | None,
    ) -> None:
        """Download a subset of the configured coverage to a local GeoTIFF.

        Args:
            west: western edge of the bounding box, in the request CRS.
            south: southern edge of the bounding box, in the request CRS.
            east: eastern edge of the bounding box, in the request CRS.
            north: northern edge of the bounding box, in the request CRS.
            output: local filename passed to `soilgrids` as the output path.
            width: WCS WIDTH parameter; only sent when height is also set.
            height: WCS HEIGHT parameter; only sent when width is also set.
            resx: WCS RESX parameter; only sent when width/height are not sent and
                resy is also set.
            resy: WCS RESY parameter; see resx.
        """
        # Imported lazily so the `soilgrids` package is only needed when a download
        # is actually performed.
        from soilgrids import SoilGrids as SoilGridsClient

        client = SoilGridsClient()
        kwargs: dict[str, Any] = dict(
            service_id=self.service_id,
            coverage_id=self.coverage_id,
            crs=_crs_to_soilgrids_urn(self.crs),
            west=west,
            south=south,
            east=east,
            north=north,
            output=output,
        )
        # The output grid is specified either by size or by resolution, never both.
        if width is not None and height is not None:
            kwargs["width"] = width
            kwargs["height"] = height
        elif resx is not None and resy is not None:
            kwargs["resx"] = resx
            kwargs["resy"] = resy

        if self.response_crs is not None:
            kwargs["response_crs"] = _crs_to_soilgrids_urn(self.response_crs)

        client.get_coverage_data(**kwargs)

    def read_raster(
        self,
        layer_name: str,
        item_name: str,
        bands: list[str],
        projection: Projection,
        bounds: PixelBounds,
        resampling: Resampling = Resampling.bilinear,
    ) -> npt.NDArray[Any]:
        """Read and reproject a SoilGrids coverage subset into the requested grid.

        Args:
            layer_name: the layer name (unused).
            item_name: the item name (unused, there is only one coverage).
            bands: the requested bands, which must equal the configured band_names.
            projection: the projection of the output grid.
            bounds: the pixel bounds of the output grid under projection.
            resampling: the resampling method used when reprojecting.

        Returns:
            a float32 array of shape (1, height, width) covering bounds. The band's
            scale and offset from the downloaded GeoTIFF are applied to valid
            pixels; nodata pixels hold the GeoTIFF's nodata value, or
            SOILGRIDS_NODATA_VALUE if the GeoTIFF does not declare one.

        Raises:
            ValueError: if bands does not match the configured band names.
        """
        if bands != self.band_names:
            raise ValueError(
                f"expected request for bands {self.band_names} but got {bands}"
            )

        # Compute bounding box in CRS coordinates for the request.
        request_crs = _crs_to_rasterio(self.crs)
        request_projection = Projection(request_crs, 1.0, 1.0)
        request_geom = STGeometry(projection, shapely.box(*bounds), None).to_projection(
            request_projection
        )
        west, south, east, north = request_geom.shp.bounds

        # Determine output grid for the WCS request.
        #
        # If the user explicitly configured an output grid (width/height or resx/resy),
        # we respect it.
        #
        # Otherwise, default to requesting at ~250 m resolution in the request CRS
        # (when it is projected), and then reprojecting to the window grid.
        #
        # For EPSG:4326 requests, SoilGrids WCS requires WIDTH/HEIGHT, so we default
        # to matching the window pixel size.
        window_width = bounds[2] - bounds[0]
        window_height = bounds[3] - bounds[1]

        out_width = self.width
        out_height = self.height
        out_resx = self.resx
        out_resy = self.resy

        if request_crs.to_epsg() == 4326 and out_width is None:
            # Required by the SoilGrids WCS for EPSG:4326; resx/resy is not accepted.
            out_width = window_width
            out_height = window_height
            out_resx = None
            out_resy = None
        elif out_width is None and out_resx is None:
            # Default to native-ish SoilGrids resolution (~250 m) in projected CRSs.
            out_resx = 250.0
            out_resy = 250.0

        with tempfile.TemporaryDirectory(prefix="rslearn_soilgrids_") as tmpdir:
            output_path = str(UPath(tmpdir) / "coverage.tif")
            self._download_geotiff(
                west=west,
                south=south,
                east=east,
                north=north,
                output=output_path,
                width=out_width,
                height=out_height,
                resx=out_resx,
                resy=out_resy,
            )

            with rasterio.open(output_path) as src:
                src_array = src.read(1).astype(np.float32)
                src_nodata = src.nodata
                scale = float(src.scales[0]) if src.scales else 1.0
                offset = float(src.offsets[0]) if src.offsets else 0.0

                if src_nodata is not None:
                    # Only rescale valid pixels so nodata pixels keep the nodata
                    # value and are still recognized during reprojection.
                    valid_mask = src_array != float(src_nodata)
                    src_array[valid_mask] = src_array[valid_mask] * scale + offset
                    dst_nodata = float(src_nodata)
                    src_nodata_val = dst_nodata
                else:
                    src_array = src_array * scale + offset
                    dst_nodata = SOILGRIDS_NODATA_VALUE
                    src_nodata_val = None

                # reproject is given (band, height, width) arrays; the destination
                # is pre-filled with nodata.
                src_chw = src_array[None, :, :]
                dst = np.full(
                    (1, window_height, window_width),
                    dst_nodata,
                    dtype=np.float32,
                )
                dst_transform = get_transform_from_projection_and_bounds(
                    projection, bounds
                )

                rasterio.warp.reproject(
                    source=src_chw,
                    src_crs=src.crs,
                    src_transform=src.transform,
                    src_nodata=src_nodata_val,
                    destination=dst,
                    dst_crs=projection.crs,
                    dst_transform=dst_transform,
                    dst_nodata=dst_nodata,
                    resampling=resampling,
                )
                return dst

    def materialize(
        self,
        window: Window,
        item_groups: list[list[Item]],
        layer_name: str,
        layer_cfg: LayerConfig,
    ) -> None:
        """Materialize a window by reading from SoilGrids on-demand.

        Args:
            window: the window to materialize.
            item_groups: the item groups matched to the window.
            layer_name: the name of the layer being materialized.
            layer_cfg: the configuration of that layer.
        """
        # This data source acts as its own tile store, so the materializer reads
        # rasters through read_raster above.
        RasterMaterializer().materialize(
            TileStoreWithLayer(self, layer_name),
            window,
            layer_name,
            layer_cfg,
            item_groups,
        )
|