rslearn 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/aws_open_data.py +11 -15
- rslearn/data_sources/aws_sentinel2_element84.py +374 -0
- rslearn/data_sources/gcp_public_data.py +16 -0
- rslearn/data_sources/planetary_computer.py +78 -257
- rslearn/data_sources/soilgrids.py +331 -0
- rslearn/data_sources/stac.py +275 -0
- rslearn/main.py +4 -1
- rslearn/models/attention_pooling.py +5 -2
- rslearn/train/lightning_module.py +24 -11
- rslearn/train/tasks/embedding.py +2 -2
- rslearn/train/tasks/multi_task.py +8 -5
- rslearn/train/tasks/per_pixel_regression.py +1 -1
- rslearn/train/tasks/segmentation.py +143 -21
- rslearn/train/tasks/task.py +4 -2
- rslearn/utils/geometry.py +2 -2
- rslearn/utils/stac.py +173 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/METADATA +4 -1
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/RECORD +23 -19
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/WHEEL +1 -1
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Data source for SoilGrids via the `soilgrids` Python package.
|
|
2
|
+
|
|
3
|
+
This source is intended to be used with `ingest: false` (direct materialization),
|
|
4
|
+
since data is fetched on-demand per window.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import tempfile
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
import rasterio
|
|
15
|
+
import rasterio.warp
|
|
16
|
+
import shapely
|
|
17
|
+
from rasterio.crs import CRS
|
|
18
|
+
from rasterio.enums import Resampling
|
|
19
|
+
from upath import UPath
|
|
20
|
+
|
|
21
|
+
from rslearn.config import LayerConfig, QueryConfig
|
|
22
|
+
from rslearn.dataset import Window
|
|
23
|
+
from rslearn.dataset.materialize import RasterMaterializer
|
|
24
|
+
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
25
|
+
from rslearn.utils import PixelBounds, Projection, STGeometry
|
|
26
|
+
from rslearn.utils.geometry import get_global_geometry
|
|
27
|
+
from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
|
|
28
|
+
|
|
29
|
+
from .data_source import DataSource, DataSourceContext, Item
|
|
30
|
+
from .utils import match_candidate_items_to_window
|
|
31
|
+
|
|
32
|
+
# Used as the fallback dst_nodata in SoilGrids.read_raster when the downloaded
# GeoTIFF does not declare its own nodata value.
SOILGRIDS_NODATA_VALUE = -32768.0
"""Default nodata value used by SoilGrids GeoTIFF responses (GEOTIFF_INT16)."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _crs_to_rasterio(crs: str) -> CRS:
    """Best-effort conversion of CRS strings used by `soilgrids` to rasterio CRS."""
    try:
        return CRS.from_string(crs)
    except Exception:
        # rasterio could not parse the string directly. As a fallback, look for
        # numeric tokens (e.g. the "3857" in "urn:ogc:def:crs:EPSG::3857") and
        # build the CRS from the last one, treating it as an EPSG code.
        numeric_tokens = [
            token for token in crs.replace(":", " ").split() if token.isdigit()
        ]
        if not numeric_tokens:
            raise
        return CRS.from_epsg(int(numeric_tokens[-1]))
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _crs_to_soilgrids_urn(crs: str) -> str:
|
|
50
|
+
"""Convert common CRS spellings to the URN form expected by `soilgrids`.
|
|
51
|
+
|
|
52
|
+
The `soilgrids` package compares CRS strings against the supported CRS URNs from
|
|
53
|
+
OWSLib (e.g. "urn:ogc:def:crs:EPSG::3857"). This helper allows users to specify
|
|
54
|
+
simpler forms like "EPSG:3857" while still working.
|
|
55
|
+
"""
|
|
56
|
+
s = crs.strip()
|
|
57
|
+
|
|
58
|
+
# If already an EPSG URN, canonicalize to the form soilgrids expects.
|
|
59
|
+
if s.lower().startswith("urn:ogc:def:crs:") and "epsg" in s.lower():
|
|
60
|
+
parts = [p for p in s.replace(":", " ").split() if p.isdigit()]
|
|
61
|
+
if parts:
|
|
62
|
+
return f"urn:ogc:def:crs:EPSG::{parts[-1]}"
|
|
63
|
+
return s
|
|
64
|
+
|
|
65
|
+
# Accept "EPSG:3857", "epsg:3857", or other strings containing an EPSG code.
|
|
66
|
+
if "epsg" in s.lower():
|
|
67
|
+
parts = [p for p in s.replace(":", " ").split() if p.isdigit()]
|
|
68
|
+
if parts:
|
|
69
|
+
return f"urn:ogc:def:crs:EPSG::{parts[-1]}"
|
|
70
|
+
|
|
71
|
+
return s
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class SoilGrids(DataSource, TileStore):
    """Access SoilGrids coverages as an rslearn raster data source."""

    def __init__(
        self,
        service_id: str,
        coverage_id: str,
        crs: str = "EPSG:3857",
        width: int | None = None,
        height: int | None = None,
        resx: float | None = None,
        resy: float | None = None,
        response_crs: str | None = None,
        band_names: list[str] | None = None,
        context: DataSourceContext = DataSourceContext(),
    ):
        """Create a new SoilGrids data source.

        Args:
            service_id: SoilGrids map service id (e.g., "clay", "phh2o").
            coverage_id: coverage id within the service (e.g., "clay_0-5cm_mean").
            crs: request CRS string passed through to `soilgrids.SoilGrids`, typically
                a URN like "urn:ogc:def:crs:EPSG::4326" or "urn:ogc:def:crs:EPSG::152160".
            width: optional WCS WIDTH parameter. Required by SoilGrids WCS when CRS is
                EPSG:4326.
            height: optional WCS HEIGHT parameter.
            resx: optional WCS RESX parameter (projection units / pixel).
            resy: optional WCS RESY parameter (projection units / pixel).
            response_crs: optional response CRS (defaults to `crs`).
            band_names: band names exposed to rslearn; defaults to ["B1"]. For a
                single coverage, this should have length 1.
            context: rslearn data source context.

        Raises:
            ValueError: if band_names does not have length 1, if only one of
                width/height or resx/resy is set, or if both pairs are set.
        """
        # Use a None sentinel rather than a mutable list default argument, which
        # would be shared across all calls.
        if band_names is None:
            band_names = ["B1"]
        if len(band_names) != 1:
            raise ValueError("SoilGrids currently supports only single-band coverages")
        if (width is None) != (height is None):
            raise ValueError("width and height must be specified together")
        if (resx is None) != (resy is None):
            raise ValueError("resx and resy must be specified together")
        if width is not None and resx is not None:
            raise ValueError("specify either width/height or resx/resy, not both")

        self.service_id = service_id
        self.coverage_id = coverage_id
        self.crs = crs
        self.width = width
        self.height = height
        self.resx = resx
        self.resy = resy
        self.response_crs = response_crs
        self.band_names = band_names

        # Represent the coverage as a single item that matches all windows.
        item_name = f"{self.service_id}:{self.coverage_id}"
        self._items = [Item(item_name, get_global_geometry(time_range=None))]

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get item groups matching each requested geometry."""
        groups = []
        for geometry in geometries:
            cur_groups = match_candidate_items_to_window(
                geometry, self._items, query_config
            )
            groups.append(cur_groups)
        return groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserialize an item from JSON-decoded data."""
        return Item.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest is not supported (direct materialization only).

        Raises:
            NotImplementedError: always; this source only supports direct
                materialization (set data_source.ingest=false).
        """
        raise NotImplementedError(
            "SoilGrids is intended for direct materialization; set data_source.ingest=false."
        )

    def is_raster_ready(
        self, layer_name: str, item_name: str, bands: list[str]
    ) -> bool:
        """Return whether the requested raster is ready (always true for direct reads)."""
        return True

    def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
        """Return the band sets available for this coverage."""
        return [self.band_names]

    def get_raster_bounds(
        self, layer_name: str, item_name: str, bands: list[str], projection: Projection
    ) -> PixelBounds:
        """Return (approximate) bounds for this raster in the requested projection."""
        # We don't know bounds without an extra metadata request; treat as "very large"
        # so materialization always attempts reads for windows.
        return (-(10**9), -(10**9), 10**9, 10**9)

    def _download_geotiff(
        self,
        west: float,
        south: float,
        east: float,
        north: float,
        output: str,
        width: int | None,
        height: int | None,
        resx: float | None,
        resy: float | None,
    ) -> None:
        """Download a coverage subset as a GeoTIFF via the `soilgrids` package.

        Args:
            west: west edge of the request bounding box (in self.crs units).
            south: south edge of the request bounding box.
            east: east edge of the request bounding box.
            north: north edge of the request bounding box.
            output: local filename to write the GeoTIFF to.
            width: optional WCS WIDTH parameter (paired with height).
            height: optional WCS HEIGHT parameter.
            resx: optional WCS RESX parameter (paired with resy).
            resy: optional WCS RESY parameter.
        """
        # Imported lazily so the soilgrids package is only required when this
        # data source is actually used.
        from soilgrids import SoilGrids as SoilGridsClient

        client = SoilGridsClient()
        kwargs: dict[str, Any] = dict(
            service_id=self.service_id,
            coverage_id=self.coverage_id,
            crs=_crs_to_soilgrids_urn(self.crs),
            west=west,
            south=south,
            east=east,
            north=north,
            output=output,
        )
        if width is not None and height is not None:
            kwargs["width"] = width
            kwargs["height"] = height
        elif resx is not None and resy is not None:
            kwargs["resx"] = resx
            kwargs["resy"] = resy

        if self.response_crs is not None:
            kwargs["response_crs"] = _crs_to_soilgrids_urn(self.response_crs)

        client.get_coverage_data(**kwargs)

    def read_raster(
        self,
        layer_name: str,
        item_name: str,
        bands: list[str],
        projection: Projection,
        bounds: PixelBounds,
        resampling: Resampling = Resampling.bilinear,
    ) -> npt.NDArray[Any]:
        """Read and reproject a SoilGrids coverage subset into the requested grid.

        Args:
            layer_name: the layer name (unused; one coverage per source).
            item_name: the item name (unused; one coverage per source).
            bands: the requested bands; must equal self.band_names.
            projection: the projection of the output grid.
            bounds: pixel bounds of the output grid in that projection.
            resampling: rasterio resampling method for the reprojection.

        Returns:
            a (1, height, width) float32 array covering bounds, with nodata pixels
            filled with the source nodata value (or SOILGRIDS_NODATA_VALUE if the
            source does not declare one).

        Raises:
            ValueError: if bands does not match self.band_names.
        """
        if bands != self.band_names:
            raise ValueError(
                f"expected request for bands {self.band_names} but got {bands}"
            )

        # Compute bounding box in CRS coordinates for the request.
        request_crs = _crs_to_rasterio(self.crs)
        request_projection = Projection(request_crs, 1.0, 1.0)
        request_geom = STGeometry(projection, shapely.box(*bounds), None).to_projection(
            request_projection
        )
        west, south, east, north = request_geom.shp.bounds

        # Determine output grid for the WCS request.
        #
        # If the user explicitly configured an output grid (width/height or resx/resy),
        # we respect it.
        #
        # Otherwise, default to requesting at ~250 m resolution in the request CRS
        # (when it is projected), and then reprojecting to the window grid.
        #
        # For EPSG:4326 requests, SoilGrids WCS requires WIDTH/HEIGHT, so we default
        # to matching the window pixel size.
        window_width = bounds[2] - bounds[0]
        window_height = bounds[3] - bounds[1]

        out_width = self.width
        out_height = self.height
        out_resx = self.resx
        out_resy = self.resy

        if request_crs.to_epsg() == 4326 and out_width is None:
            # Required by the SoilGrids WCS for EPSG:4326; resx/resy is not accepted.
            out_width = window_width
            out_height = window_height
            out_resx = None
            out_resy = None
        elif out_width is None and out_resx is None:
            # Default to native-ish SoilGrids resolution (~250 m) in projected CRSs.
            out_resx = 250.0
            out_resy = 250.0

        with tempfile.TemporaryDirectory(prefix="rslearn_soilgrids_") as tmpdir:
            output_path = str(UPath(tmpdir) / "coverage.tif")
            self._download_geotiff(
                west=west,
                south=south,
                east=east,
                north=north,
                output=output_path,
                width=out_width,
                height=out_height,
                resx=out_resx,
                resy=out_resy,
            )

            with rasterio.open(output_path) as src:
                src_array = src.read(1).astype(np.float32)
                src_nodata = src.nodata
                scale = float(src.scales[0]) if src.scales else 1.0
                offset = float(src.offsets[0]) if src.offsets else 0.0

                if src_nodata is not None:
                    # Apply scale/offset only to valid pixels so nodata stays exact.
                    valid_mask = src_array != float(src_nodata)
                    src_array[valid_mask] = src_array[valid_mask] * scale + offset
                    dst_nodata = float(src_nodata)
                    src_nodata_val = dst_nodata
                else:
                    src_array = src_array * scale + offset
                    dst_nodata = SOILGRIDS_NODATA_VALUE
                    src_nodata_val = None

                src_chw = src_array[None, :, :]
                dst = np.full(
                    (1, bounds[3] - bounds[1], bounds[2] - bounds[0]),
                    dst_nodata,
                    dtype=np.float32,
                )
                dst_transform = get_transform_from_projection_and_bounds(
                    projection, bounds
                )

                rasterio.warp.reproject(
                    source=src_chw,
                    src_crs=src.crs,
                    src_transform=src.transform,
                    src_nodata=src_nodata_val,
                    destination=dst,
                    dst_crs=projection.crs,
                    dst_transform=dst_transform,
                    dst_nodata=dst_nodata,
                    resampling=resampling,
                )
        return dst

    def materialize(
        self,
        window: Window,
        item_groups: list[list[Item]],
        layer_name: str,
        layer_cfg: LayerConfig,
    ) -> None:
        """Materialize a window by reading from SoilGrids on-demand."""
        RasterMaterializer().materialize(
            TileStoreWithLayer(self, layer_name),
            window,
            layer_name,
            layer_cfg,
            item_groups,
        )
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""A partial data source implementation providing get_items using a STAC API."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import shapely
|
|
8
|
+
from upath import UPath
|
|
9
|
+
|
|
10
|
+
from rslearn.config import QueryConfig
|
|
11
|
+
from rslearn.const import WGS84_PROJECTION
|
|
12
|
+
from rslearn.data_sources.data_source import Item, ItemLookupDataSource
|
|
13
|
+
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
14
|
+
from rslearn.log_utils import get_logger
|
|
15
|
+
from rslearn.utils.geometry import STGeometry
|
|
16
|
+
from rslearn.utils.stac import StacClient, StacItem
|
|
17
|
+
|
|
18
|
+
logger = get_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SourceItem(Item):
    """An item in the StacDataSource data source."""

    def __init__(
        self,
        name: str,
        geometry: STGeometry,
        asset_urls: dict[str, str],
        properties: dict[str, str],
    ):
        """Creates a new SourceItem.

        Args:
            name: unique name of the item
            geometry: the spatial and temporal extent of the item
            asset_urls: map from asset key to the unsigned asset URL.
            properties: properties requested by the data source implementation.
        """
        super().__init__(name, geometry)
        self.asset_urls = asset_urls
        self.properties = properties

    def serialize(self) -> dict[str, Any]:
        """Serializes the item to a JSON-encodable dictionary."""
        serialized = super().serialize()
        serialized.update(asset_urls=self.asset_urls, properties=self.properties)
        return serialized

    @staticmethod
    def deserialize(d: dict[str, Any]) -> "SourceItem":
        """Deserializes an item from a JSON-decoded dictionary."""
        # Let the base class recover the name/geometry, then attach the extra
        # SourceItem fields from the raw dictionary.
        base = super(SourceItem, SourceItem).deserialize(d)
        return SourceItem(
            name=base.name,
            geometry=base.geometry,
            asset_urls=d["asset_urls"],
            properties=d["properties"],
        )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class StacDataSource(ItemLookupDataSource[SourceItem]):
    """A partial data source implementing get_items using a STAC API.

    This is a helper class that full implementations can extend to not have to worry
    about the get_items and get_item_by_name implementation.
    """

    def __init__(
        self,
        endpoint: str,
        collection_name: str,
        query: dict[str, Any] | None = None,
        sort_by: str | None = None,
        sort_ascending: bool = True,
        required_assets: list[str] | None = None,
        cache_dir: UPath | None = None,
        limit: int = 100,
        properties_to_record: list[str] | None = None,
    ):
        """Create a new StacDataSource.

        Args:
            endpoint: the STAC endpoint to use.
            collection_name: the STAC collection name.
            query: optional STAC query dict to include in searches, e.g. {"eo:cloud_cover": {"lt": 50}}.
            sort_by: sort results by this STAC property.
            sort_ascending: if sort_by is set, sort in ascending order (default).
                Otherwise sort in descending order.
            required_assets: if set, we ignore items that do not have all of these
                asset keys.
            cache_dir: optional cache directory to cache items. This is recommended if
                allowing direct materialization from the data source, since it will
                likely be necessary to make lots of get_item_by_name calls during
                materialization. TODO: give direct materialization access to the Item
                object.
            limit: limit to pass to search queries.
            properties_to_record: if these properties on the STAC item exist, they
                are retained in the SourceItem when we initialize it. Defaults to
                recording no properties.
        """
        self.client = StacClient(endpoint)
        self.collection_name = collection_name
        self.query = query
        self.sort_by = sort_by
        self.sort_ascending = sort_ascending
        self.required_assets = required_assets
        self.cache_dir = cache_dir
        self.limit = limit
        # Use a None sentinel rather than a mutable list default argument, which
        # would be shared across all instances.
        self.properties_to_record = (
            properties_to_record if properties_to_record is not None else []
        )

    def _stac_item_to_item(self, stac_item: StacItem) -> SourceItem:
        """Convert a StacItem into a SourceItem, validating required fields.

        Raises:
            ValueError: if the STAC item lacks a geometry, time range, or assets.
        """
        # Make sure geometry, time range, and assets are set.
        if stac_item.geometry is None:
            raise ValueError("got unexpected item with no geometry")
        if stac_item.time_range is None:
            raise ValueError("got unexpected item with no time range")
        if stac_item.assets is None:
            raise ValueError("got unexpected item with no assets")

        shp = shapely.geometry.shape(stac_item.geometry)
        geom = STGeometry(WGS84_PROJECTION, shp, stac_item.time_range)
        asset_urls = {
            asset_key: asset_obj.href
            for asset_key, asset_obj in stac_item.assets.items()
        }

        # Keep any properties requested by the data source implementation.
        properties = {}
        for prop_name in self.properties_to_record:
            if prop_name not in stac_item.properties:
                continue
            properties[prop_name] = stac_item.properties[prop_name]

        return SourceItem(stac_item.id, geom, asset_urls, properties)

    def _get_search_time_range(
        self, geometry: STGeometry
    ) -> datetime | tuple[datetime, datetime] | None:
        """Get time range to include in STAC API search.

        By default, we filter STAC searches to the window's time range. Subclasses can
        override this to disable time filtering for "static" datasets.

        Args:
            geometry: the geometry we are searching for.

        Returns:
            the time range (or timestamp) to pass to the STAC search, or None to avoid
            temporal filtering in the search request.
        """
        # Note: StacClient.search accepts either a datetime or a (start, end) tuple.
        return geometry.time_range

    def get_item_by_name(self, name: str) -> SourceItem:
        """Gets an item by name.

        Args:
            name: the name of the item to get

        Returns:
            the item object

        Raises:
            ValueError: if zero or multiple items match the given name.
        """
        # If cache_dir is set, we cache the item. First here we check if it is already
        # in the cache.
        cache_fname: UPath | None = None
        if self.cache_dir:
            cache_fname = self.cache_dir / f"{name}.json"
        if cache_fname is not None and cache_fname.exists():
            with cache_fname.open() as f:
                return SourceItem.deserialize(json.load(f))

        # No cache or not in cache, so we need to make the STAC request.
        logger.debug(f"Getting STAC item {name}")
        stac_items = self.client.search(ids=[name], collections=[self.collection_name])

        if len(stac_items) == 0:
            raise ValueError(
                f"Item {name} not found in collection {self.collection_name}"
            )
        if len(stac_items) > 1:
            raise ValueError(
                f"Multiple items found for ID {name} in collection {self.collection_name}"
            )

        stac_item = stac_items[0]
        item = self._stac_item_to_item(stac_item)

        # Finally we cache it if cache_dir is set.
        if cache_fname is not None:
            with cache_fname.open("w") as f:
                json.dump(item.serialize(), f)

        return item

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[SourceItem]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.
        """
        groups = []
        for geometry in geometries:
            # Get potentially relevant items from the collection by performing one search
            # for each requested geometry.
            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
            logger.debug("performing STAC search for geometry %s", wgs84_geometry)
            search_time_range = self._get_search_time_range(wgs84_geometry)
            stac_items = self.client.search(
                collections=[self.collection_name],
                intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
                date_time=search_time_range,
                query=self.query,
                limit=self.limit,
            )
            logger.debug("STAC search yielded %d items", len(stac_items))

            if self.required_assets is not None:
                # Filter out items missing any of the assets in self.required_assets.
                good_stac_items = []
                for stac_item in stac_items:
                    if stac_item.assets is None:
                        raise ValueError(f"got STAC item {stac_item.id} with no assets")

                    if all(
                        asset_key in stac_item.assets
                        for asset_key in self.required_assets
                    ):
                        good_stac_items.append(stac_item)
                logger.debug(
                    "required_assets filter from %d to %d items",
                    len(stac_items),
                    len(good_stac_items),
                )
                stac_items = good_stac_items

            if self.sort_by is not None:
                # Bind to a local so the sort key closure does not capture self.
                sort_by = self.sort_by
                stac_items.sort(
                    key=lambda stac_item: stac_item.properties[sort_by],
                    reverse=not self.sort_ascending,
                )

            candidate_items = [
                self._stac_item_to_item(stac_item) for stac_item in stac_items
            ]

            # Since we made the STAC request, might as well save these to the cache.
            if self.cache_dir is not None:
                for item in candidate_items:
                    cache_fname = self.cache_dir / f"{item.name}.json"
                    if cache_fname.exists():
                        continue
                    with cache_fname.open("w") as f:
                        json.dump(item.serialize(), f)

            cur_groups = match_candidate_items_to_window(
                geometry, candidate_items, query_config
            )
            groups.append(cur_groups)

        return groups

    def deserialize_item(self, serialized_item: Any) -> SourceItem:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return SourceItem.deserialize(serialized_item)
|
rslearn/main.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import argparse
|
|
4
4
|
import multiprocessing
|
|
5
|
+
import os
|
|
5
6
|
import random
|
|
6
7
|
import sys
|
|
7
8
|
import time
|
|
@@ -45,6 +46,7 @@ handler_registry = {}
|
|
|
45
46
|
ItemType = TypeVar("ItemType", bound="Item")
|
|
46
47
|
|
|
47
48
|
MULTIPROCESSING_CONTEXT = "forkserver"
|
|
49
|
+
MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"
|
|
48
50
|
|
|
49
51
|
|
|
50
52
|
def register_handler(category: Any, command: str) -> Callable:
|
|
@@ -837,7 +839,8 @@ def model_predict() -> None:
|
|
|
837
839
|
def main() -> None:
|
|
838
840
|
"""CLI entrypoint."""
|
|
839
841
|
try:
|
|
840
|
-
|
|
842
|
+
mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
|
|
843
|
+
multiprocessing.set_start_method(mp_context)
|
|
841
844
|
except RuntimeError as e:
|
|
842
845
|
logger.error(
|
|
843
846
|
f"Multiprocessing context already set to {multiprocessing.get_context()}: "
|
|
@@ -150,8 +150,11 @@ class AttentionPool(IntermediateComponent):
|
|
|
150
150
|
D // self.num_heads
|
|
151
151
|
)
|
|
152
152
|
attn_weights = F.softmax(attn_scores, dim=-1)
|
|
153
|
-
x = torch.matmul(attn_weights, v) # [B,
|
|
154
|
-
|
|
153
|
+
x = torch.matmul(attn_weights, v) # [B*H*W, num_heads, 1, D_head]
|
|
154
|
+
x = x.squeeze(-2) # [B*H*W, num_heads, D_head]
|
|
155
|
+
return rearrange(
|
|
156
|
+
x, "(b h w) nh dh -> b (nh dh) h w", b=B, h=H, w=W
|
|
157
|
+
) # [B, D, H, W]
|
|
155
158
|
|
|
156
159
|
def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
|
|
157
160
|
"""Forward pass for attention pooling linear probe.
|