rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +22 -13
- rslearn/data_sources/__init__.py +8 -0
- rslearn/data_sources/aws_landsat.py +27 -18
- rslearn/data_sources/aws_open_data.py +41 -42
- rslearn/data_sources/copernicus.py +148 -2
- rslearn/data_sources/data_source.py +17 -10
- rslearn/data_sources/gcp_public_data.py +177 -100
- rslearn/data_sources/geotiff.py +1 -0
- rslearn/data_sources/google_earth_engine.py +17 -15
- rslearn/data_sources/local_files.py +59 -32
- rslearn/data_sources/openstreetmap.py +27 -23
- rslearn/data_sources/planet.py +10 -9
- rslearn/data_sources/planet_basemap.py +303 -0
- rslearn/data_sources/raster_source.py +23 -13
- rslearn/data_sources/usgs_landsat.py +56 -27
- rslearn/data_sources/utils.py +13 -6
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/xyz_tiles.py +8 -9
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +16 -5
- rslearn/dataset/manage.py +9 -4
- rslearn/dataset/materialize.py +26 -5
- rslearn/dataset/window.py +5 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +123 -59
- rslearn/models/clip.py +62 -0
- rslearn/models/conv.py +56 -0
- rslearn/models/faster_rcnn.py +2 -19
- rslearn/models/fpn.py +1 -1
- rslearn/models/module_wrapper.py +43 -0
- rslearn/models/molmo.py +65 -0
- rslearn/models/multitask.py +1 -1
- rslearn/models/pooling_decoder.py +4 -2
- rslearn/models/satlaspretrain.py +4 -7
- rslearn/models/simple_time_series.py +61 -55
- rslearn/models/ssl4eo_s12.py +9 -9
- rslearn/models/swin.py +22 -21
- rslearn/models/unet.py +4 -2
- rslearn/models/upsample.py +35 -0
- rslearn/tile_stores/file.py +6 -3
- rslearn/tile_stores/tile_store.py +19 -7
- rslearn/train/callbacks/freeze_unfreeze.py +3 -3
- rslearn/train/data_module.py +5 -4
- rslearn/train/dataset.py +79 -36
- rslearn/train/lightning_module.py +15 -11
- rslearn/train/prediction_writer.py +22 -11
- rslearn/train/tasks/classification.py +9 -8
- rslearn/train/tasks/detection.py +94 -37
- rslearn/train/tasks/multi_task.py +1 -1
- rslearn/train/tasks/regression.py +8 -4
- rslearn/train/tasks/segmentation.py +23 -19
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +6 -2
- rslearn/train/transforms/crop.py +6 -2
- rslearn/train/transforms/flip.py +5 -1
- rslearn/train/transforms/normalize.py +9 -5
- rslearn/train/transforms/pad.py +1 -1
- rslearn/train/transforms/transform.py +3 -3
- rslearn/utils/__init__.py +4 -5
- rslearn/utils/array.py +2 -2
- rslearn/utils/feature.py +1 -1
- rslearn/utils/fsspec.py +70 -1
- rslearn/utils/geometry.py +155 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +81 -73
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/utils.py +11 -3
- rslearn/utils/vector_format.py +113 -17
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
- rslearn-0.0.2.dist-info/RECORD +94 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
- rslearn/utils/mgrs.py +0 -24
- rslearn-0.0.1.dist-info/RECORD +0 -88
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
"""Data source for Planet Labs Basemaps API."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import tempfile
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import rasterio
|
|
9
|
+
import requests
|
|
10
|
+
import shapely
|
|
11
|
+
from upath import UPath
|
|
12
|
+
|
|
13
|
+
from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
|
|
14
|
+
from rslearn.const import WGS84_PROJECTION
|
|
15
|
+
from rslearn.data_sources import DataSource, Item
|
|
16
|
+
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
17
|
+
from rslearn.log_utils import get_logger
|
|
18
|
+
from rslearn.tile_stores import PrefixedTileStore, TileStore
|
|
19
|
+
from rslearn.utils import STGeometry
|
|
20
|
+
|
|
21
|
+
from .raster_source import get_needed_projections, ingest_raster
|
|
22
|
+
|
|
23
|
+
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class PlanetItem(Item):
    """An Item that additionally records which Basemaps mosaic and quad it refers to."""

    def __init__(self, name: str, geometry: STGeometry, mosaic_id: str, quad_id: str):
        """Create a new PlanetItem.

        Args:
            name: the item name (combination of mosaic and quad ID).
            geometry: the geometry associated with this quad.
            mosaic_id: the ID of the mosaic in the API.
            quad_id: the ID of the quad in the API.
        """
        super().__init__(name, geometry)
        self.mosaic_id = mosaic_id
        self.quad_id = quad_id

    def serialize(self) -> dict:
        """Return a JSON-encodable dict: the base Item fields plus both IDs."""
        serialized = super().serialize()
        serialized.update(mosaic_id=self.mosaic_id, quad_id=self.quad_id)
        return serialized

    @staticmethod
    def deserialize(d: dict) -> Item:
        """Rebuild a PlanetItem from a dict produced by serialize()."""
        # Item is the only base class, so this is the same call the
        # super()-based lookup would make.
        base_item = Item.deserialize(d)
        mosaic_id = d["mosaic_id"]
        quad_id = d["quad_id"]
        return PlanetItem(
            name=base_item.name,
            geometry=base_item.geometry,
            mosaic_id=mosaic_id,
            quad_id=quad_id,
        )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ApiError(Exception):
    """Raised when the Planet Labs API responds with an error."""
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class PlanetBasemap(DataSource):
    """A data source for Planet Labs Basemaps API."""

    api_url = "https://api.planet.com/basemaps/v1/"

    def __init__(
        self,
        config: RasterLayerConfig,
        series_id: str,
        bands: list[str],
        api_key: str | None = None,
    ):
        """Initialize a new PlanetBasemap instance.

        The mosaics of the series are listed from the API at construction time
        and cached in self.mosaics (mosaic ID -> STGeometry with the mosaic's
        bbox and acquisition time range).

        Args:
            config: the LayerConfig of the layer containing this data source
            series_id: the series of mosaics to use.
            bands: list of band names to use.
            api_key: optional Planet API key (it can also be provided via PL_API_KEY
                environment variable).

        Raises:
            KeyError: if api_key is None and PL_API_KEY is not set.
            ApiError: if listing the mosaics fails.
        """
        self.config = config
        self.bands = bands

        self.session = requests.Session()
        if api_key is None:
            api_key = os.environ["PL_API_KEY"]
        # The API key is sent as the basic-auth username with an empty password.
        self.session.auth = (api_key, "")

        # List mosaics.
        self.mosaics: dict[str, STGeometry] = {}
        for mosaic_dict in self._api_get_paginate(
            path=f"series/{series_id}/mosaics", list_key="mosaics"
        ):
            shp = shapely.box(*mosaic_dict["bbox"])
            time_range = (
                datetime.fromisoformat(mosaic_dict["first_acquired"]),
                datetime.fromisoformat(mosaic_dict["last_acquired"]),
            )
            geom = STGeometry(WGS84_PROJECTION, shp, time_range)
            self.mosaics[mosaic_dict["id"]] = geom

    @staticmethod
    def from_config(config: LayerConfig, ds_path: UPath) -> "PlanetBasemap":
        """Creates a new PlanetBasemap instance from a configuration dictionary.

        Args:
            config: the layer config; must be a RasterLayerConfig with a
                data_source whose config_dict has "series_id" and "bands"
                (and optionally "api_key").
            ds_path: the dataset path (unused here).

        Returns:
            the new PlanetBasemap.

        Raises:
            ValueError: if the config has no data_source.
        """
        assert isinstance(config, RasterLayerConfig)
        if config.data_source is None:
            raise ValueError("data_source is required")
        d = config.data_source.config_dict
        kwargs = dict(
            config=config,
            series_id=d["series_id"],
            bands=d["bands"],
        )
        optional_keys = [
            "api_key",
        ]
        for optional_key in optional_keys:
            if optional_key in d:
                kwargs[optional_key] = d[optional_key]
        return PlanetBasemap(**kwargs)

    def _api_get(
        self,
        path: str | None = None,
        url: str | None = None,
        query_args: dict[str, str] | None = None,
    ) -> list[Any] | dict[str, Any]:
        """Perform a GET request on the API.

        Args:
            path: the path to GET, like "series".
            url: the full URL to GET. Only one of path or url should be set; if
                both are given, path takes precedence.
            query_args: optional params to include with the request.

        Returns:
            the JSON response data.

        Raises:
            ValueError: if no URL can be determined from path/url.
            ApiError: if the API returned an error response.
        """
        if path is None and url is None:
            raise ValueError("One of path or url must be set")

        if path:
            url = self.api_url + path
        if url is None:
            # Reached when path is an empty string and no url was given.
            raise ValueError("url is required")

        # An empty/None query_args results in no query string being added.
        response = self.session.get(url, params=query_args or None)

        if response.status_code != 200:
            raise ApiError(
                f"{url}: got status code {response.status_code}: {response.text}"
            )
        return response.json()

    def _api_get_paginate(
        self, path: str, list_key: str, query_args: dict[str, str] | None = None
    ) -> list:
        """Get all items in a paginated response.

        Pages are followed via the "_next" entry of each response's "_links".

        Args:
            path: the path to GET.
            list_key: the key in the response containing the list that should be
                concatenated across all available pages.
            query_args: optional params to include with the requests.

        Returns:
            the concatenated list of items.

        Raises:
            ApiError: if the API returned an error response, or a page is not a
                JSON object.
        """
        next_url = self.api_url + path
        items: list = []
        while True:
            json_data = self._api_get(url=next_url, query_args=query_args)
            if not isinstance(json_data, dict):
                # Retrying here would request the same URL again without any
                # bound, so treat an unexpected payload as an API error.
                raise ApiError(
                    f"{next_url}: expected dict in paginated response, "
                    f"got {type(json_data)}"
                )
            items.extend(json_data[list_key])

            if "_next" not in json_data["_links"]:
                return items
            next_url = json_data["_links"]["_next"]

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[PlanetItem]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.
        """
        groups = []
        for geometry in geometries:
            geom_bbox = geometry.to_projection(WGS84_PROJECTION).shp.bounds
            geom_bbox_str = ",".join([str(value) for value in geom_bbox])

            # Find the relevant mosaics that the geometry intersects.
            # For each relevant mosaic, identify the intersecting quads.
            items = []
            for mosaic_id, mosaic_geom in self.mosaics.items():
                if not geometry.intersects(mosaic_geom):
                    continue
                logger.info(f"found mosaic {mosaic_geom} for geom {geometry}")
                # List all quads that intersect the current geometry's
                # longitude/latitude bbox in this mosaic.
                for quad_dict in self._api_get_paginate(
                    path=f"mosaics/{mosaic_id}/quads",
                    list_key="items",
                    query_args={"bbox": geom_bbox_str},
                ):
                    logger.info(f"found quad {quad_dict}")
                    shp = shapely.box(*quad_dict["bbox"])
                    # Quads inherit the time range of the mosaic they belong to.
                    geom = STGeometry(WGS84_PROJECTION, shp, mosaic_geom.time_range)
                    quad_id = quad_dict["id"]
                    items.append(
                        PlanetItem(f"{mosaic_id}_{quad_id}", geom, mosaic_id, quad_id)
                    )
            logger.info(f"found {len(items)} items for geom {geometry}")
            cur_groups = match_candidate_items_to_window(geometry, items, query_config)
            groups.append(cur_groups)

        return groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return PlanetItem.deserialize(serialized_item)

    def _download_quad(self, item: PlanetItem, local_fname: str) -> None:
        """Download the full GeoTIFF of the item's quad to local_fname.

        Raises:
            ApiError: if the API returned a non-200 response.
        """
        download_url = (
            self.api_url + f"mosaics/{item.mosaic_id}/quads/{item.quad_id}/full"
        )
        # Use the response as a context manager so the streamed connection is
        # released back to the session even if the download fails midway.
        with self.session.get(
            download_url, allow_redirects=True, stream=True
        ) as response:
            if response.status_code != 200:
                raise ApiError(
                    f"{download_url}: got status code {response.status_code}: "
                    f"{response.text}"
                )
            with open(local_fname, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

    def ingest(
        self,
        tile_store: TileStore,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Items for which get_needed_projections reports nothing to ingest are
        skipped without downloading.

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item

        Raises:
            ApiError: if downloading a quad fails.
        """
        for item, cur_geometries in zip(items, geometries):
            band_names = self.bands
            cur_tile_store = PrefixedTileStore(
                tile_store, (item.name, "_".join(band_names))
            )
            needed_projections = get_needed_projections(
                cur_tile_store, band_names, self.config.band_sets, cur_geometries
            )
            if not needed_projections:
                continue

            assert isinstance(item, PlanetItem)
            with tempfile.TemporaryDirectory() as tmp_dir:
                local_fname = os.path.join(tmp_dir, "temp.tif")
                self._download_quad(item, local_fname)

                with rasterio.open(local_fname) as raster:
                    for projection in needed_projections:
                        ingest_raster(
                            tile_store=cur_tile_store,
                            raster=raster,
                            projection=projection,
                            time_range=item.geometry.time_range,
                            layer_config=self.config,
                        )
|
|
@@ -15,10 +15,13 @@ from rasterio.crs import CRS
|
|
|
15
15
|
from rslearn.config import BandSetConfig, RasterFormatConfig, RasterLayerConfig
|
|
16
16
|
from rslearn.const import TILE_SIZE
|
|
17
17
|
from rslearn.dataset import Window
|
|
18
|
+
from rslearn.log_utils import get_logger
|
|
18
19
|
from rslearn.tile_stores import LayerMetadata, TileStore
|
|
19
20
|
from rslearn.utils import Projection, STGeometry
|
|
20
21
|
from rslearn.utils.raster_format import load_raster_format
|
|
21
22
|
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
22
25
|
|
|
23
26
|
class ArrayWithTransform:
|
|
24
27
|
"""Stores an array along with the transform associated with the array."""
|
|
@@ -70,7 +73,7 @@ class ArrayWithTransform:
|
|
|
70
73
|
"""
|
|
71
74
|
return self.array
|
|
72
75
|
|
|
73
|
-
def close(self):
|
|
76
|
+
def close(self) -> None:
|
|
74
77
|
"""This is to mimic the rasterio.DatasetReader API.
|
|
75
78
|
|
|
76
79
|
The close function is a no-op.
|
|
@@ -144,23 +147,25 @@ def get_needed_projections(
|
|
|
144
147
|
list of Projection objects for which the item has not been ingested yet
|
|
145
148
|
"""
|
|
146
149
|
# Identify which band set configs are relevant to this raster.
|
|
147
|
-
|
|
148
|
-
|
|
150
|
+
raster_bands_set = set(raster_bands)
|
|
151
|
+
relevant_band_set_list = []
|
|
149
152
|
for band_set in band_sets:
|
|
150
153
|
is_match = False
|
|
154
|
+
if band_set.bands is None:
|
|
155
|
+
continue
|
|
151
156
|
for band in band_set.bands:
|
|
152
|
-
if band not in
|
|
157
|
+
if band not in raster_bands_set:
|
|
153
158
|
continue
|
|
154
159
|
is_match = True
|
|
155
160
|
break
|
|
156
161
|
if not is_match:
|
|
157
162
|
continue
|
|
158
|
-
|
|
163
|
+
relevant_band_set_list.append(band_set)
|
|
159
164
|
|
|
160
165
|
all_projections = {geometry.projection for geometry in geometries}
|
|
161
166
|
needed_projections = []
|
|
162
167
|
for projection in all_projections:
|
|
163
|
-
for band_set in
|
|
168
|
+
for band_set in relevant_band_set_list:
|
|
164
169
|
final_projection, _ = band_set.get_final_projection_and_bounds(
|
|
165
170
|
projection, None
|
|
166
171
|
)
|
|
@@ -216,16 +221,17 @@ def ingest_raster(
|
|
|
216
221
|
else:
|
|
217
222
|
# Compute the suggested target transform.
|
|
218
223
|
# rasterio negates the y resolution itself so here we have to negate it.
|
|
224
|
+
raster_bounds: rasterio.coords.BoundingBox = raster.bounds
|
|
219
225
|
(dst_transform, dst_width, dst_height) = (
|
|
220
226
|
rasterio.warp.calculate_default_transform(
|
|
221
227
|
# Source info.
|
|
222
228
|
src_crs=raster.crs,
|
|
223
229
|
width=raster.width,
|
|
224
230
|
height=raster.height,
|
|
225
|
-
left=
|
|
226
|
-
bottom=
|
|
227
|
-
right=
|
|
228
|
-
top=
|
|
231
|
+
left=raster_bounds.left,
|
|
232
|
+
bottom=raster_bounds.bottom,
|
|
233
|
+
right=raster_bounds.right,
|
|
234
|
+
top=raster_bounds.top,
|
|
229
235
|
# Destination info.
|
|
230
236
|
dst_crs=projection.crs,
|
|
231
237
|
resolution=(projection.x_resolution, -projection.y_resolution),
|
|
@@ -258,7 +264,7 @@ def materialize_raster(
|
|
|
258
264
|
window: Window,
|
|
259
265
|
layer_name: str,
|
|
260
266
|
band_cfg: BandSetConfig,
|
|
261
|
-
):
|
|
267
|
+
) -> None:
|
|
262
268
|
"""Materialize a given raster for a window.
|
|
263
269
|
|
|
264
270
|
Currently it is only supported for materializing one band set.
|
|
@@ -272,7 +278,8 @@ def materialize_raster(
|
|
|
272
278
|
window_projection, window_bounds = band_cfg.get_final_projection_and_bounds(
|
|
273
279
|
window.projection, window.bounds
|
|
274
280
|
)
|
|
275
|
-
|
|
281
|
+
if window_bounds is None:
|
|
282
|
+
raise ValueError(f"No windowbounds specified for {layer_name}")
|
|
276
283
|
# Re-project to just extract the window.
|
|
277
284
|
array = raster.read()
|
|
278
285
|
window_width = window_bounds[2] - window_bounds[0]
|
|
@@ -297,7 +304,10 @@ def materialize_raster(
|
|
|
297
304
|
dst_transform=dst_transform,
|
|
298
305
|
resampling=rasterio.enums.Resampling.bilinear,
|
|
299
306
|
)
|
|
300
|
-
|
|
307
|
+
if band_cfg.bands is None or band_cfg.format is None:
|
|
308
|
+
raise ValueError(
|
|
309
|
+
f"No bands or format specified for {layer_name} materialization"
|
|
310
|
+
)
|
|
301
311
|
# Write the array to layer directory.
|
|
302
312
|
layer_dir = window.path / "layers" / layer_name
|
|
303
313
|
out_dir = layer_dir / "_".join(band_cfg.bands)
|
|
@@ -1,4 +1,7 @@
|
|
|
1
|
-
"""Data source for Landsat data from USGS M2M API.
|
|
1
|
+
"""Data source for Landsat data from USGS M2M API.
|
|
2
|
+
|
|
3
|
+
# TODO: Handle the requests in a helper function for none checking
|
|
4
|
+
"""
|
|
2
5
|
|
|
3
6
|
import io
|
|
4
7
|
import json
|
|
@@ -15,7 +18,7 @@ import requests
|
|
|
15
18
|
import shapely
|
|
16
19
|
from upath import UPath
|
|
17
20
|
|
|
18
|
-
from rslearn.config import
|
|
21
|
+
from rslearn.config import QueryConfig, RasterLayerConfig
|
|
19
22
|
from rslearn.const import WGS84_PROJECTION
|
|
20
23
|
from rslearn.data_sources import DataSource, Item
|
|
21
24
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
@@ -36,8 +39,9 @@ class M2MAPIClient:
|
|
|
36
39
|
|
|
37
40
|
api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
|
|
38
41
|
pagination_size = 1000
|
|
42
|
+
TIMEOUT = 1000000 # Set very high to start
|
|
39
43
|
|
|
40
|
-
def __init__(self, username, password):
|
|
44
|
+
def __init__(self, username: str, password: str) -> None:
|
|
41
45
|
"""Initialize a new M2MAPIClient.
|
|
42
46
|
|
|
43
47
|
Args:
|
|
@@ -47,7 +51,9 @@ class M2MAPIClient:
|
|
|
47
51
|
self.username = username
|
|
48
52
|
self.password = password
|
|
49
53
|
json_data = json.dumps({"username": self.username, "password": self.password})
|
|
50
|
-
response = requests.post(
|
|
54
|
+
response = requests.post(
|
|
55
|
+
self.api_url + "login", data=json_data, timeout=self.TIMEOUT
|
|
56
|
+
)
|
|
51
57
|
response.raise_for_status()
|
|
52
58
|
self.auth_token = response.json()["data"]
|
|
53
59
|
|
|
@@ -67,24 +73,26 @@ class M2MAPIClient:
|
|
|
67
73
|
self.api_url + endpoint,
|
|
68
74
|
headers={"X-Auth-Token": self.auth_token},
|
|
69
75
|
data=json.dumps(data),
|
|
76
|
+
timeout=self.TIMEOUT,
|
|
70
77
|
)
|
|
71
78
|
response.raise_for_status()
|
|
72
79
|
if response.text:
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
80
|
+
response_dict = response.json()
|
|
81
|
+
|
|
82
|
+
if response_dict["errorMessage"]:
|
|
83
|
+
raise APIException(response_dict["errorMessage"])
|
|
84
|
+
return response_dict
|
|
77
85
|
return None
|
|
78
86
|
|
|
79
|
-
def close(self):
|
|
87
|
+
def close(self) -> None:
|
|
80
88
|
"""Logout from the API."""
|
|
81
89
|
self.request("logout")
|
|
82
90
|
|
|
83
|
-
def __enter__(self):
|
|
91
|
+
def __enter__(self) -> "M2MAPIClient":
|
|
84
92
|
"""Enter function to provide with semantics."""
|
|
85
93
|
return self
|
|
86
94
|
|
|
87
|
-
def __exit__(self):
|
|
95
|
+
def __exit__(self) -> None:
|
|
88
96
|
"""Exit function to provide with semantics.
|
|
89
97
|
|
|
90
98
|
Logs out the API.
|
|
@@ -100,7 +108,10 @@ class M2MAPIClient:
|
|
|
100
108
|
Returns:
|
|
101
109
|
list of filter objects
|
|
102
110
|
"""
|
|
103
|
-
|
|
111
|
+
response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
|
|
112
|
+
if response_dict is None:
|
|
113
|
+
raise APIException("No response from API")
|
|
114
|
+
return response_dict["data"]
|
|
104
115
|
|
|
105
116
|
def scene_search(
|
|
106
117
|
self,
|
|
@@ -119,7 +130,7 @@ class M2MAPIClient:
|
|
|
119
130
|
bbox: optional spatial filter
|
|
120
131
|
metadata_filter: optional metadata filter dict
|
|
121
132
|
"""
|
|
122
|
-
base_data = {"datasetName": dataset_name, "sceneFilter": {}}
|
|
133
|
+
base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
|
|
123
134
|
if acquisition_time_range:
|
|
124
135
|
base_data["sceneFilter"]["acquisitionFilter"] = {
|
|
125
136
|
"start": acquisition_time_range[0].isoformat(),
|
|
@@ -146,7 +157,10 @@ class M2MAPIClient:
|
|
|
146
157
|
cur_data = base_data.copy()
|
|
147
158
|
cur_data["startingNumber"] = starting_number
|
|
148
159
|
cur_data["maxResults"] = self.pagination_size
|
|
149
|
-
|
|
160
|
+
response_dict = self.request("scene-search", cur_data)
|
|
161
|
+
if response_dict is None:
|
|
162
|
+
raise APIException("No response from API")
|
|
163
|
+
data = response_dict["data"]
|
|
150
164
|
results.extend(data["results"])
|
|
151
165
|
if data["recordsReturned"] < self.pagination_size:
|
|
152
166
|
break
|
|
@@ -164,14 +178,17 @@ class M2MAPIClient:
|
|
|
164
178
|
Returns:
|
|
165
179
|
full scene metadata
|
|
166
180
|
"""
|
|
167
|
-
|
|
181
|
+
response_dict = self.request(
|
|
168
182
|
"scene-metadata",
|
|
169
183
|
{
|
|
170
184
|
"datasetName": dataset_name,
|
|
171
185
|
"entityId": entity_id,
|
|
172
186
|
"metadataType": "full",
|
|
173
187
|
},
|
|
174
|
-
)
|
|
188
|
+
)
|
|
189
|
+
if response_dict is None:
|
|
190
|
+
raise APIException("No response from API")
|
|
191
|
+
return response_dict["data"]
|
|
175
192
|
|
|
176
193
|
def get_downloadable_products(
|
|
177
194
|
self, dataset_name: str, entity_id: str
|
|
@@ -186,7 +203,10 @@ class M2MAPIClient:
|
|
|
186
203
|
list of downloadable products
|
|
187
204
|
"""
|
|
188
205
|
data = {"datasetName": dataset_name, "entityIds": [entity_id]}
|
|
189
|
-
|
|
206
|
+
response_dict = self.request("download-options", data)
|
|
207
|
+
if response_dict is None:
|
|
208
|
+
raise APIException("No response from API")
|
|
209
|
+
return response_dict["data"]
|
|
190
210
|
|
|
191
211
|
def get_download_url(self, entity_id: str, product_id: str) -> str:
|
|
192
212
|
"""Get the download URL for a given product.
|
|
@@ -204,9 +224,15 @@ class M2MAPIClient:
|
|
|
204
224
|
{"label": label, "entityId": entity_id, "productId": product_id}
|
|
205
225
|
]
|
|
206
226
|
}
|
|
207
|
-
|
|
227
|
+
response_dict = self.request("download-request", data)
|
|
228
|
+
if response_dict is None:
|
|
229
|
+
raise APIException("No response from API")
|
|
230
|
+
response = response_dict["data"]
|
|
208
231
|
while True:
|
|
209
|
-
|
|
232
|
+
response_dict = self.request("download-retrieve", {"label": label})
|
|
233
|
+
if response_dict is None:
|
|
234
|
+
raise APIException("No response from API")
|
|
235
|
+
response = response_dict["data"]
|
|
210
236
|
if len(response["available"]) > 0:
|
|
211
237
|
return response["available"][0]["url"]
|
|
212
238
|
if len(response["requested"]) == 0:
|
|
@@ -264,7 +290,7 @@ class LandsatOliTirs(DataSource):
|
|
|
264
290
|
|
|
265
291
|
def __init__(
|
|
266
292
|
self,
|
|
267
|
-
config:
|
|
293
|
+
config: RasterLayerConfig,
|
|
268
294
|
username: str,
|
|
269
295
|
password: str,
|
|
270
296
|
max_time_delta: timedelta = timedelta(days=30),
|
|
@@ -289,9 +315,10 @@ class LandsatOliTirs(DataSource):
|
|
|
289
315
|
self.client = M2MAPIClient(username, password)
|
|
290
316
|
|
|
291
317
|
@staticmethod
|
|
292
|
-
def from_config(config:
|
|
318
|
+
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "LandsatOliTirs":
|
|
293
319
|
"""Creates a new LandsatOliTirs instance from a configuration dictionary."""
|
|
294
|
-
|
|
320
|
+
if config.data_source is None:
|
|
321
|
+
raise ValueError("data_source is required")
|
|
295
322
|
d = config.data_source.config_dict
|
|
296
323
|
if "max_time_delta" in d:
|
|
297
324
|
max_time_delta = timedelta(seconds=pytimeparse.parse(d["max_time_delta"]))
|
|
@@ -328,7 +355,7 @@ class LandsatOliTirs(DataSource):
|
|
|
328
355
|
|
|
329
356
|
def get_items(
|
|
330
357
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
331
|
-
) -> list[list[list[
|
|
358
|
+
) -> list[list[list[LandsatOliTirsItem]]]:
|
|
332
359
|
"""Get a list of items in the data source intersecting the given geometries.
|
|
333
360
|
|
|
334
361
|
Args:
|
|
@@ -400,7 +427,7 @@ class LandsatOliTirs(DataSource):
|
|
|
400
427
|
assert isinstance(serialized_item, dict)
|
|
401
428
|
return LandsatOliTirsItem.deserialize(serialized_item)
|
|
402
429
|
|
|
403
|
-
def _get_download_urls(self, item: Item) -> dict[str, str]:
|
|
430
|
+
def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
|
|
404
431
|
"""Gets the download URLs for each band.
|
|
405
432
|
|
|
406
433
|
Args:
|
|
@@ -438,7 +465,7 @@ class LandsatOliTirs(DataSource):
|
|
|
438
465
|
download_urls = self._get_download_urls(item)
|
|
439
466
|
for _, (display_id, download_url) in download_urls.items():
|
|
440
467
|
buf = io.BytesIO()
|
|
441
|
-
with requests.get(download_url, stream=True) as r:
|
|
468
|
+
with requests.get(download_url, stream=True, timeout=self.TIMEOUT) as r:
|
|
442
469
|
r.raise_for_status()
|
|
443
470
|
shutil.copyfileobj(r.raw, buf)
|
|
444
471
|
buf.seek(0)
|
|
@@ -447,7 +474,7 @@ class LandsatOliTirs(DataSource):
|
|
|
447
474
|
def ingest(
|
|
448
475
|
self,
|
|
449
476
|
tile_store: TileStore,
|
|
450
|
-
items: list[
|
|
477
|
+
items: list[LandsatOliTirsItem],
|
|
451
478
|
geometries: list[list[STGeometry]],
|
|
452
479
|
) -> None:
|
|
453
480
|
"""Ingest items into the given tile store.
|
|
@@ -471,7 +498,9 @@ class LandsatOliTirs(DataSource):
|
|
|
471
498
|
continue
|
|
472
499
|
|
|
473
500
|
buf = io.BytesIO()
|
|
474
|
-
with requests.get(
|
|
501
|
+
with requests.get(
|
|
502
|
+
download_urls[band][1], stream=True, timeout=self.TIMEOUT
|
|
503
|
+
) as r:
|
|
475
504
|
r.raise_for_status()
|
|
476
505
|
shutil.copyfileobj(r.raw, buf)
|
|
477
506
|
buf.seek(0)
|
rslearn/data_sources/utils.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Utilities shared by data sources."""
|
|
2
2
|
|
|
3
|
+
from typing import TypeVar
|
|
4
|
+
|
|
3
5
|
from rslearn.config import QueryConfig, SpaceMode, TimeMode
|
|
4
6
|
from rslearn.data_sources import Item
|
|
5
7
|
from rslearn.utils import STGeometry, shp_intersects
|
|
@@ -11,10 +13,12 @@ MOSAIC_REMAINDER_EPSILON = 0.01
|
|
|
11
13
|
"""Fraction of original geometry area below which mosaic is considered to contain the
|
|
12
14
|
entire geometry."""
|
|
13
15
|
|
|
16
|
+
ItemType = TypeVar("ItemType", bound=Item)
|
|
17
|
+
|
|
14
18
|
|
|
15
19
|
def match_candidate_items_to_window(
|
|
16
|
-
geometry: STGeometry, items: list[
|
|
17
|
-
) -> list[list[
|
|
20
|
+
geometry: STGeometry, items: list[ItemType], query_config: QueryConfig
|
|
21
|
+
) -> list[list[ItemType]]:
|
|
18
22
|
"""Match candidate items to a window based on the query configuration.
|
|
19
23
|
|
|
20
24
|
Candidate items should be collected that intersect with the window's spatial
|
|
@@ -45,17 +49,20 @@ def match_candidate_items_to_window(
|
|
|
45
49
|
items = [
|
|
46
50
|
item
|
|
47
51
|
for item in items
|
|
48
|
-
if not item.
|
|
52
|
+
if not item.geometry.time_range
|
|
53
|
+
or item.geometry.time_range[1] <= geometry.time_range[0]
|
|
49
54
|
]
|
|
50
55
|
elif query_config.time_mode == TimeMode.AFTER:
|
|
51
56
|
items = [
|
|
52
57
|
item
|
|
53
58
|
for item in items
|
|
54
|
-
if not item.time_range
|
|
55
|
-
or item.time_range[0] >= geometry.time_range[1]
|
|
59
|
+
if not item.geometry.time_range
|
|
60
|
+
or item.geometry.time_range[0] >= geometry.time_range[1]
|
|
56
61
|
]
|
|
57
62
|
items.sort(
|
|
58
|
-
key=lambda item: geometry.distance_to_time_range(
|
|
63
|
+
key=lambda item: geometry.distance_to_time_range(
|
|
64
|
+
item.geometry.time_range
|
|
65
|
+
)
|
|
59
66
|
)
|
|
60
67
|
|
|
61
68
|
# Now apply space mode.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Placeholder for a vector data source."""
|