rslearn 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +1 -1
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/models/concatenate_features.py +6 -1
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +27 -27
- rslearn/train/dataset.py +109 -62
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +3 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/regression.py +1 -1
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/RECORD +49 -45
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
|
@@ -4,12 +4,9 @@
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import io
|
|
7
|
-
import json
|
|
8
7
|
import os
|
|
9
8
|
import shutil
|
|
10
9
|
import tempfile
|
|
11
|
-
import time
|
|
12
|
-
import uuid
|
|
13
10
|
from collections.abc import Generator
|
|
14
11
|
from datetime import UTC, datetime, timedelta
|
|
15
12
|
from typing import Any, BinaryIO
|
|
@@ -24,246 +21,7 @@ from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
|
24
21
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
25
22
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
26
23
|
from rslearn.utils import STGeometry
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class APIException(Exception):
|
|
30
|
-
"""Exception raised for M2M API errors."""
|
|
31
|
-
|
|
32
|
-
pass
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class M2MAPIClient:
|
|
36
|
-
"""An API client for interacting with the USGS M2M API."""
|
|
37
|
-
|
|
38
|
-
api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
|
|
39
|
-
pagination_size = 1000
|
|
40
|
-
|
|
41
|
-
def __init__(
|
|
42
|
-
self,
|
|
43
|
-
username: str,
|
|
44
|
-
password: str | None = None,
|
|
45
|
-
token: str | None = None,
|
|
46
|
-
timeout: timedelta = timedelta(seconds=120),
|
|
47
|
-
) -> None:
|
|
48
|
-
"""Initialize a new M2MAPIClient.
|
|
49
|
-
|
|
50
|
-
Args:
|
|
51
|
-
username: the EROS username
|
|
52
|
-
password: the EROS password
|
|
53
|
-
token: the application token. One of password or token must be specified.
|
|
54
|
-
timeout: timeout for requests.
|
|
55
|
-
"""
|
|
56
|
-
self.username = username
|
|
57
|
-
self.timeout = timeout
|
|
58
|
-
|
|
59
|
-
if password is not None and token is not None:
|
|
60
|
-
raise ValueError("only one of password or token can be specified")
|
|
61
|
-
|
|
62
|
-
if password is not None:
|
|
63
|
-
json_data = json.dumps({"username": self.username, "password": password})
|
|
64
|
-
response = requests.post(
|
|
65
|
-
self.api_url + "login",
|
|
66
|
-
data=json_data,
|
|
67
|
-
timeout=self.timeout.total_seconds(),
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
elif token is not None:
|
|
71
|
-
json_data = json.dumps({"username": username, "token": token})
|
|
72
|
-
response = requests.post(
|
|
73
|
-
self.api_url + "login-token",
|
|
74
|
-
data=json_data,
|
|
75
|
-
timeout=self.timeout.total_seconds(),
|
|
76
|
-
)
|
|
77
|
-
|
|
78
|
-
else:
|
|
79
|
-
raise ValueError("one of password or token must be specified")
|
|
80
|
-
|
|
81
|
-
response.raise_for_status()
|
|
82
|
-
self.auth_token = response.json()["data"]
|
|
83
|
-
|
|
84
|
-
def request(
|
|
85
|
-
self, endpoint: str, data: dict[str, Any] | None = None
|
|
86
|
-
) -> dict[str, Any] | None:
|
|
87
|
-
"""Make a request to the API.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
endpoint: the endpoint to call
|
|
91
|
-
data: POST data to pass
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
JSON response data if any
|
|
95
|
-
"""
|
|
96
|
-
response = requests.post(
|
|
97
|
-
self.api_url + endpoint,
|
|
98
|
-
headers={"X-Auth-Token": self.auth_token},
|
|
99
|
-
data=json.dumps(data),
|
|
100
|
-
timeout=self.timeout.total_seconds(),
|
|
101
|
-
)
|
|
102
|
-
response.raise_for_status()
|
|
103
|
-
if response.text:
|
|
104
|
-
response_dict = response.json()
|
|
105
|
-
|
|
106
|
-
if response_dict["errorMessage"]:
|
|
107
|
-
raise APIException(response_dict["errorMessage"])
|
|
108
|
-
return response_dict
|
|
109
|
-
return None
|
|
110
|
-
|
|
111
|
-
def close(self) -> None:
|
|
112
|
-
"""Logout from the API."""
|
|
113
|
-
self.request("logout")
|
|
114
|
-
|
|
115
|
-
def __enter__(self) -> "M2MAPIClient":
|
|
116
|
-
"""Enter function to provide with semantics."""
|
|
117
|
-
return self
|
|
118
|
-
|
|
119
|
-
def __exit__(self) -> None:
|
|
120
|
-
"""Exit function to provide with semantics.
|
|
121
|
-
|
|
122
|
-
Logs out the API.
|
|
123
|
-
"""
|
|
124
|
-
self.close()
|
|
125
|
-
|
|
126
|
-
def get_filters(self, dataset_name: str) -> list[dict[str, Any]]:
|
|
127
|
-
"""Returns filters available for the given dataset.
|
|
128
|
-
|
|
129
|
-
Args:
|
|
130
|
-
dataset_name: the dataset name e.g. landsat_ot_c2_l1
|
|
131
|
-
|
|
132
|
-
Returns:
|
|
133
|
-
list of filter objects
|
|
134
|
-
"""
|
|
135
|
-
response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
|
|
136
|
-
if response_dict is None:
|
|
137
|
-
raise APIException("No response from API")
|
|
138
|
-
return response_dict["data"]
|
|
139
|
-
|
|
140
|
-
def scene_search(
|
|
141
|
-
self,
|
|
142
|
-
dataset_name: str,
|
|
143
|
-
acquisition_time_range: tuple[datetime, datetime] | None = None,
|
|
144
|
-
cloud_cover_range: tuple[int, int] | None = None,
|
|
145
|
-
bbox: tuple[float, float, float, float] | None = None,
|
|
146
|
-
metadata_filter: dict[str, Any] | None = None,
|
|
147
|
-
) -> list[dict[str, Any]]:
|
|
148
|
-
"""Search for scenes matching the arguments.
|
|
149
|
-
|
|
150
|
-
Args:
|
|
151
|
-
dataset_name: the dataset name e.g. landsat_ot_c2_l1
|
|
152
|
-
acquisition_time_range: optional filter on the acquisition time
|
|
153
|
-
cloud_cover_range: optional filter on the cloud cover
|
|
154
|
-
bbox: optional spatial filter
|
|
155
|
-
metadata_filter: optional metadata filter dict
|
|
156
|
-
"""
|
|
157
|
-
base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
|
|
158
|
-
if acquisition_time_range:
|
|
159
|
-
base_data["sceneFilter"]["acquisitionFilter"] = {
|
|
160
|
-
"start": acquisition_time_range[0].isoformat(),
|
|
161
|
-
"end": acquisition_time_range[1].isoformat(),
|
|
162
|
-
}
|
|
163
|
-
if cloud_cover_range:
|
|
164
|
-
base_data["sceneFilter"]["cloudCoverFilter"] = {
|
|
165
|
-
"min": cloud_cover_range[0],
|
|
166
|
-
"max": cloud_cover_range[1],
|
|
167
|
-
"includeUnknown": False,
|
|
168
|
-
}
|
|
169
|
-
if bbox:
|
|
170
|
-
base_data["sceneFilter"]["spatialFilter"] = {
|
|
171
|
-
"filterType": "mbr",
|
|
172
|
-
"lowerLeft": {"longitude": bbox[0], "latitude": bbox[1]},
|
|
173
|
-
"upperRight": {"longitude": bbox[2], "latitude": bbox[3]},
|
|
174
|
-
}
|
|
175
|
-
if metadata_filter:
|
|
176
|
-
base_data["sceneFilter"]["metadataFilter"] = metadata_filter
|
|
177
|
-
|
|
178
|
-
starting_number = 1
|
|
179
|
-
results = []
|
|
180
|
-
while True:
|
|
181
|
-
cur_data = base_data.copy()
|
|
182
|
-
cur_data["startingNumber"] = starting_number
|
|
183
|
-
cur_data["maxResults"] = self.pagination_size
|
|
184
|
-
response_dict = self.request("scene-search", cur_data)
|
|
185
|
-
if response_dict is None:
|
|
186
|
-
raise APIException("No response from API")
|
|
187
|
-
data = response_dict["data"]
|
|
188
|
-
results.extend(data["results"])
|
|
189
|
-
if data["recordsReturned"] < self.pagination_size:
|
|
190
|
-
break
|
|
191
|
-
starting_number += self.pagination_size
|
|
192
|
-
|
|
193
|
-
return results
|
|
194
|
-
|
|
195
|
-
def get_scene_metadata(self, dataset_name: str, entity_id: str) -> dict[str, Any]:
|
|
196
|
-
"""Get detailed metadata for a scene.
|
|
197
|
-
|
|
198
|
-
Args:
|
|
199
|
-
dataset_name: the dataset name in which to search
|
|
200
|
-
entity_id: the entity ID of the scene
|
|
201
|
-
|
|
202
|
-
Returns:
|
|
203
|
-
full scene metadata
|
|
204
|
-
"""
|
|
205
|
-
response_dict = self.request(
|
|
206
|
-
"scene-metadata",
|
|
207
|
-
{
|
|
208
|
-
"datasetName": dataset_name,
|
|
209
|
-
"entityId": entity_id,
|
|
210
|
-
"metadataType": "full",
|
|
211
|
-
},
|
|
212
|
-
)
|
|
213
|
-
if response_dict is None:
|
|
214
|
-
raise APIException("No response from API")
|
|
215
|
-
return response_dict["data"]
|
|
216
|
-
|
|
217
|
-
def get_downloadable_products(
|
|
218
|
-
self, dataset_name: str, entity_id: str
|
|
219
|
-
) -> list[dict[str, Any]]:
|
|
220
|
-
"""Get the downloadable products for a given scene.
|
|
221
|
-
|
|
222
|
-
Args:
|
|
223
|
-
dataset_name: the dataset name
|
|
224
|
-
entity_id: the entity ID of the scene
|
|
225
|
-
|
|
226
|
-
Returns:
|
|
227
|
-
list of downloadable products
|
|
228
|
-
"""
|
|
229
|
-
data = {"datasetName": dataset_name, "entityIds": [entity_id]}
|
|
230
|
-
response_dict = self.request("download-options", data)
|
|
231
|
-
if response_dict is None:
|
|
232
|
-
raise APIException("No response from API")
|
|
233
|
-
return response_dict["data"]
|
|
234
|
-
|
|
235
|
-
def get_download_url(self, entity_id: str, product_id: str) -> str:
|
|
236
|
-
"""Get the download URL for a given product.
|
|
237
|
-
|
|
238
|
-
Args:
|
|
239
|
-
entity_id: the entity ID of the product
|
|
240
|
-
product_id: the product ID of the product
|
|
241
|
-
|
|
242
|
-
Returns:
|
|
243
|
-
the download URL
|
|
244
|
-
"""
|
|
245
|
-
label = str(uuid.uuid4())
|
|
246
|
-
data = {
|
|
247
|
-
"downloads": [
|
|
248
|
-
{"label": label, "entityId": entity_id, "productId": product_id}
|
|
249
|
-
]
|
|
250
|
-
}
|
|
251
|
-
response_dict = self.request("download-request", data)
|
|
252
|
-
if response_dict is None:
|
|
253
|
-
raise APIException("No response from API")
|
|
254
|
-
response = response_dict["data"]
|
|
255
|
-
while True:
|
|
256
|
-
response_dict = self.request("download-retrieve", {"label": label})
|
|
257
|
-
if response_dict is None:
|
|
258
|
-
raise APIException("No response from API")
|
|
259
|
-
response = response_dict["data"]
|
|
260
|
-
if len(response["available"]) > 0:
|
|
261
|
-
return response["available"][0]["url"]
|
|
262
|
-
if len(response["requested"]) == 0:
|
|
263
|
-
raise Exception("Did not get download URL")
|
|
264
|
-
if response["requested"][0].get("url"):
|
|
265
|
-
return response["requested"][0]["url"]
|
|
266
|
-
time.sleep(10)
|
|
24
|
+
from rslearn.utils.m2m_api import APIException, M2MAPIClient
|
|
267
25
|
|
|
268
26
|
|
|
269
27
|
class LandsatOliTirsItem(Item):
|
|
@@ -314,30 +72,26 @@ class LandsatOliTirs(DataSource):
|
|
|
314
72
|
|
|
315
73
|
def __init__(
|
|
316
74
|
self,
|
|
317
|
-
username: str,
|
|
318
|
-
sort_by: str | None = None,
|
|
319
|
-
password: str | None = None,
|
|
75
|
+
username: str | None = None,
|
|
320
76
|
token: str | None = None,
|
|
77
|
+
sort_by: str | None = None,
|
|
321
78
|
timeout: timedelta = timedelta(seconds=10),
|
|
322
79
|
context: DataSourceContext = DataSourceContext(),
|
|
323
80
|
):
|
|
324
81
|
"""Initialize a new LandsatOliTirs instance.
|
|
325
82
|
|
|
326
83
|
Args:
|
|
327
|
-
username: EROS username
|
|
84
|
+
username: EROS username (see M2MAPIClient).
|
|
85
|
+
token: EROS application token (see M2MAPIClient).
|
|
328
86
|
sort_by: can be "cloud_cover", default arbitrary order; only has effect for
|
|
329
87
|
SpaceMode.WITHIN.
|
|
330
|
-
password: EROS password (see M2MAPIClient).
|
|
331
|
-
token: EROS application token (see M2MAPIClient).
|
|
332
88
|
timeout: timeout for requests.
|
|
333
89
|
context: the data source context.
|
|
334
90
|
"""
|
|
335
91
|
self.sort_by = sort_by
|
|
336
92
|
self.timeout = timeout
|
|
337
93
|
|
|
338
|
-
self.client = M2MAPIClient(
|
|
339
|
-
username, password=password, token=token, timeout=timeout
|
|
340
|
-
)
|
|
94
|
+
self.client = M2MAPIClient(username=username, token=token, timeout=timeout)
|
|
341
95
|
|
|
342
96
|
def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
|
|
343
97
|
"""Convert scene metadata from the API to a LandsatOliTirsItem."""
|
|
@@ -429,9 +183,8 @@ class LandsatOliTirs(DataSource):
|
|
|
429
183
|
)
|
|
430
184
|
return self._scene_metadata_to_item(scene_metadata)
|
|
431
185
|
|
|
432
|
-
def deserialize_item(self, serialized_item: Any) -> Item:
|
|
186
|
+
def deserialize_item(self, serialized_item: dict) -> Item:
|
|
433
187
|
"""Deserializes an item from JSON-decoded data."""
|
|
434
|
-
assert isinstance(serialized_item, dict)
|
|
435
188
|
return LandsatOliTirsItem.deserialize(serialized_item)
|
|
436
189
|
|
|
437
190
|
def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
|
|
@@ -291,7 +291,7 @@ class WorldCereal(LocalFiles):
|
|
|
291
291
|
raise ValueError(f"No AEZ files found for {self.band}")
|
|
292
292
|
|
|
293
293
|
super().__init__(
|
|
294
|
-
src_dir=tif_dir,
|
|
294
|
+
src_dir=tif_dir.absolute().as_uri(),
|
|
295
295
|
raster_item_specs=item_specs,
|
|
296
296
|
layer_type=LayerType.RASTER,
|
|
297
297
|
context=context,
|
rslearn/data_sources/worldpop.py
CHANGED
|
@@ -80,7 +80,7 @@ class WorldPop(LocalFiles):
|
|
|
80
80
|
worldpop_upath.mkdir(parents=True, exist_ok=True)
|
|
81
81
|
self.download_worldpop_data(worldpop_upath, timeout)
|
|
82
82
|
super().__init__(
|
|
83
|
-
src_dir=worldpop_upath,
|
|
83
|
+
src_dir=worldpop_upath.absolute().as_uri(),
|
|
84
84
|
layer_type=LayerType.RASTER,
|
|
85
85
|
context=context,
|
|
86
86
|
)
|
|
@@ -19,7 +19,7 @@ from rslearn.config import LayerConfig, QueryConfig
|
|
|
19
19
|
from rslearn.dataset import Window
|
|
20
20
|
from rslearn.dataset.materialize import RasterMaterializer
|
|
21
21
|
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
22
|
-
from rslearn.utils import PixelBounds, Projection, STGeometry
|
|
22
|
+
from rslearn.utils import PixelBounds, Projection, STGeometry, get_global_raster_bounds
|
|
23
23
|
from rslearn.utils.array import copy_spatial_array
|
|
24
24
|
from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
|
|
25
25
|
|
|
@@ -184,7 +184,7 @@ class XyzTiles(DataSource, TileStore):
|
|
|
184
184
|
groups.append(cur_groups)
|
|
185
185
|
return groups
|
|
186
186
|
|
|
187
|
-
def deserialize_item(self, serialized_item: Any) -> Item:
|
|
187
|
+
def deserialize_item(self, serialized_item: dict) -> Item:
|
|
188
188
|
"""Deserializes an item from JSON-decoded data."""
|
|
189
189
|
return Item.deserialize(serialized_item)
|
|
190
190
|
|
|
@@ -278,13 +278,9 @@ class XyzTiles(DataSource, TileStore):
|
|
|
278
278
|
Returns:
|
|
279
279
|
the bounds of the raster in the projection.
|
|
280
280
|
"""
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
int(geom.shp.bounds[1]),
|
|
285
|
-
int(geom.shp.bounds[2]),
|
|
286
|
-
int(geom.shp.bounds[3]),
|
|
287
|
-
)
|
|
281
|
+
# XyzTiles is a global data source, so we return global raster bounds based on
|
|
282
|
+
# the projection.
|
|
283
|
+
return get_global_raster_bounds(projection)
|
|
288
284
|
|
|
289
285
|
def read_raster(
|
|
290
286
|
self,
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
|
+
from einops import rearrange
|
|
6
7
|
|
|
7
8
|
from rslearn.train.model_context import ModelContext
|
|
8
9
|
|
|
@@ -79,7 +80,11 @@ class ConcatenateFeatures(IntermediateComponent):
|
|
|
79
80
|
)
|
|
80
81
|
|
|
81
82
|
add_data = torch.stack(
|
|
82
|
-
[
|
|
83
|
+
[
|
|
84
|
+
rearrange(input_data[self.key].image, "c t h w -> (c t) h w")
|
|
85
|
+
for input_data in context.inputs
|
|
86
|
+
],
|
|
87
|
+
dim=0,
|
|
83
88
|
)
|
|
84
89
|
add_features = self.conv_layers(add_data)
|
|
85
90
|
|