rslearn 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +1 -1
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/models/concatenate_features.py +6 -1
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +27 -27
- rslearn/train/dataset.py +109 -62
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +3 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/regression.py +1 -1
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/RECORD +49 -45
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/data_sources/hf_srtm.py
ADDED
@@ -0,0 +1,595 @@
```python
"""Global SRTM void-filled elevation data from USGS, mirrored on Hugging Face by AI2.

This module provides:
1. A bulk download utility to fetch SRTM data from USGS EarthExplorer via the M2M API.
   This can be used to initialize a mirror of the data.
2. A data source that pulls from the AI2 Hugging Face mirror.

The SRTM dataset in USGS EarthExplorer is "srtm_v2", which contains void-filled
elevation data from the Shuttle Radar Topography Mission. The bulk download fetches
the highest resolution available: 1 arc-second (~30 m) in the US and 3 arc-second
(~90 m) globally.

See https://www.usgs.gov/centers/eros/science/usgs-eros-archive-digital-elevation-shuttle-radar-topography-mission-srtm
for details.
"""

import functools
import json
import math
import multiprocessing
import os
import re
import shutil
import tempfile
from datetime import timedelta
from typing import Any

import requests
import shapely
from upath import UPath

from rslearn.config import QueryConfig, SpaceMode
from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources import DataSource, DataSourceContext, Item
from rslearn.log_utils import get_logger
from rslearn.tile_stores import TileStoreWithLayer
from rslearn.utils.fsspec import join_upath, open_atomic
from rslearn.utils.geometry import STGeometry
from rslearn.utils.m2m_api import M2MAPIClient
from rslearn.utils.mp import star_imap_unordered
from rslearn.utils.retry_session import create_retry_session

logger = get_logger(__name__)

# SRTM dataset name in the USGS EarthExplorer M2M API.
SRTM_DATASET_NAME = "srtm_v2"

# Product names in order of preference (highest resolution first).
# 1 Arc-second is only available in the US; 3 Arc-second is global.
SRTM_PRODUCT_NAMES = ["GeoTIFF 1 Arc-second", "GeoTIFF 3 Arc-second"]

# SRTM covers latitude -60 to 60.
SRTM_LAT_MIN = -60
SRTM_LAT_MAX = 60

# Cache filename for the scene list.
SCENE_CACHE_FILENAME = "scenes.json"


class SRTM(DataSource):
    """Data source for SRTM elevation data from the AI2 Hugging Face mirror.

    The data is split into 1x1-degree tiles, with filenames like:
        N05/SRTM1N05W163V2.tif (1 arc-second, ~30 m resolution)
        N05/SRTM3N05W163V2.tif (3 arc-second, ~90 m resolution)

    SRTM1 (1 arc-second) is available for some regions (primarily US territories),
    while SRTM3 (3 arc-second) is available globally. By default, SRTM1 is preferred
    when available for its higher resolution. Set always_use_3arcsecond=True to
    always use the lower-resolution SRTM3 data for consistency.

    Items from this data source do not come with a time range. The band name will
    match that specified in the band set, which should have a single band (e.g.
    "dem").
    """

    BASE_URL = (
        "https://huggingface.co/datasets/allenai/srtm-global-void-filled/resolve/main/"
    )
    FILE_LIST_FILENAME = "file_list.json"
    FILENAME_SUFFIX = "V2.tif"

    def __init__(
        self,
        timeout: timedelta = timedelta(seconds=10),
        cache_dir: str | None = None,
        always_use_3arcsecond: bool = False,
        context: DataSourceContext = DataSourceContext(),
    ):
        """Initialize a new SRTM instance.

        Args:
            timeout: timeout for requests.
            cache_dir: optional directory to cache the file list.
            always_use_3arcsecond: if True, always use 3 arc-second (SRTM3) data even
                when 1 arc-second (SRTM1) is available. Defaults to False, which
                prefers SRTM1 for higher resolution when available.
            context: the data source context.
        """
        # Get band name from context if possible, falling back to "dem".
        if context.layer_config is not None:
            if len(context.layer_config.band_sets) != 1:
                raise ValueError("expected a single band set")
            if len(context.layer_config.band_sets[0].bands) != 1:
                raise ValueError("expected band set to have a single band")
            self.band_name = context.layer_config.band_sets[0].bands[0]
        else:
            self.band_name = "dem"

        self.timeout = timeout
        self.session = requests.session()
        self.always_use_3arcsecond = always_use_3arcsecond

        # Set the cache path if a cache_dir is provided.
        self.file_list_cache_path: UPath | None = None
        if cache_dir is not None:
            if context.ds_path is not None:
                cache_root = join_upath(context.ds_path, cache_dir)
            else:
                cache_root = UPath(cache_dir)
            cache_root.mkdir(parents=True, exist_ok=True)
            self.file_list_cache_path = join_upath(cache_root, self.FILE_LIST_FILENAME)

        self._basename_to_item, self._tile_to_item = self._load_file_index()

    def _load_file_index(
        self,
    ) -> tuple[dict[str, Item], dict[tuple[int, int], Item]]:
        """Load the file list and build indices for lookups."""
        file_list = self._load_file_list_json()
        if not isinstance(file_list, list):
            raise ValueError("expected file_list.json to be a list of filenames")

        basename_to_item: dict[str, Item] = {}
        tile_to_item: dict[tuple[int, int], Item] = {}

        for entry in file_list:
            if not isinstance(entry, str):
                raise ValueError(
                    "expected file_list.json to contain only string filenames"
                )
            basename = os.path.basename(entry)

            # Check if this is SRTM1 or SRTM3 based on the filename prefix.
            is_srtm1 = basename.startswith("SRTM1")

            # Skip SRTM1 files if always_use_3arcsecond is enabled.
            if self.always_use_3arcsecond and is_srtm1:
                continue

            lat_min, lon_min = self._parse_tile_basename(basename)
            geometry = STGeometry(
                WGS84_PROJECTION,
                shapely.box(lon_min, lat_min, lon_min + 1, lat_min + 1),
                None,
            )
            item = Item(entry, geometry)

            key = (lon_min, lat_min)

            # For tile_to_item, prefer SRTM1 over SRTM3 when not using
            # always_use_3arcsecond.
            if key in tile_to_item:
                existing_is_srtm1 = os.path.basename(tile_to_item[key].name).startswith(
                    "SRTM1"
                )
                # Only replace if current is SRTM1 and existing is not.
                if is_srtm1 and not existing_is_srtm1:
                    tile_to_item[key] = item
                # Keep existing if it's SRTM1 and current is not.
                elif existing_is_srtm1 and not is_srtm1:
                    pass
                else:
                    # Same type, keep the existing one.
                    pass
            else:
                tile_to_item[key] = item
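
            # For example, if the file list contained both a hypothetical
            # SRTM1N47W123V2.tif and SRTM3N47W123V2.tif, the SRTM1 item would
            # occupy the (lon_min, lat_min) slot unless always_use_3arcsecond
            # is set.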

            if basename not in basename_to_item:
                basename_to_item[basename] = item

        logger.info(
            f"Loaded {len(tile_to_item)} SRTM tiles from Hugging Face file list"
        )
        return basename_to_item, tile_to_item

    def _load_file_list_json(self) -> list[Any]:
        """Load file list JSON, optionally from cache."""
        if self.file_list_cache_path is not None and self.file_list_cache_path.exists():
            with self.file_list_cache_path.open() as f:
                file_list = json.load(f)
            logger.info(f"Loaded SRTM file list cache from {self.file_list_cache_path}")
            return file_list

        response = self.session.get(
            self.BASE_URL + self.FILE_LIST_FILENAME,
            timeout=self.timeout.total_seconds(),
        )
        response.raise_for_status()
        file_list = response.json()
        if self.file_list_cache_path is not None:
            with open_atomic(self.file_list_cache_path, "w") as f:
                json.dump(file_list, f)
        return file_list

    def _parse_tile_basename(self, basename: str) -> tuple[int, int]:
        """Parse a tile basename into (lat_min, lon_min)."""
        match = re.match(
            r"^SRTM\d([NS])(\d{2})([EW])(\d{3})V2\.tif$",
            basename,
        )
        if match is None:
            raise ValueError(f"invalid SRTM tile filename: {basename}")

        lat_sign, lat_str, lon_sign, lon_str = match.groups()
        lat_degrees = int(lat_str)
        lon_degrees = int(lon_str)

        if lat_sign == "N":
            lat_min = lat_degrees
        elif lat_sign == "S":
            lat_min = -lat_degrees
        else:
            raise ValueError(f"invalid SRTM tile latitude for filename: {basename}")

        if lon_sign == "E":
            lon_min = lon_degrees
        elif lon_sign == "W":
            lon_min = -lon_degrees
        else:
            raise ValueError(f"invalid SRTM tile longitude for filename: {basename}")

        return lat_min, lon_min

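    # For example, "SRTM3N05W163V2.tif" (the filename from the class docstring)
    # parses to lat_min=5, lon_min=-163: the 1x1-degree tile spanning 5-6 N and
    # 162-163 W.
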
    def get_item_by_name(self, name: str) -> Item:
        """Gets an item by name.

        Args:
            name: the name of the item to get. For SRTM, the item name is the
                filename of the GeoTIFF tile.

        Returns:
            the Item object
        """
        basename = os.path.basename(name)
        item = self._basename_to_item.get(basename)
        if item is not None:
            return item
        raise ValueError(f"unknown SRTM tile name: {name}")

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.
        """
        # It only makes sense to create a mosaic from the SRTM data, since it is one
        # global layer (but spatially split up into items).
        if query_config.space_mode != SpaceMode.MOSAIC or query_config.max_matches != 1:
            raise ValueError(
                "expected mosaic with max_matches=1 for the query configuration"
            )

        groups = []
        for geometry in geometries:
            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
            shp_bounds = wgs84_geometry.shp.bounds
            cell_bounds = (
                math.floor(shp_bounds[0]),
                math.floor(shp_bounds[1]),
                math.ceil(shp_bounds[2]),
                math.ceil(shp_bounds[3]),
            )
            items = []
            for lon_min in range(cell_bounds[0], cell_bounds[2]):
                for lat_min in range(cell_bounds[1], cell_bounds[3]):
                    if (lon_min, lat_min) not in self._tile_to_item:
                        continue
                    items.append(self._tile_to_item[(lon_min, lat_min)])

            logger.debug(f"Got {len(items)} items (grid cells) for geometry")
            groups.append([items])

        return groups

    def deserialize_item(self, serialized_item: dict) -> Item:
        """Deserializes an item from JSON-decoded data."""
        return Item.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item
        """
        for item in items:
            if tile_store.is_raster_ready(item.name, [self.band_name]):
                continue

            url = self.BASE_URL + item.name
            logger.debug(f"Downloading SRTM data for {item.name} from {url}")
            response = self.session.get(
                url, stream=True, timeout=self.timeout.total_seconds()
            )

            if response.status_code == 404:
                logger.warning(
                    f"Skipping item {item.name} because there is no data at that cell"
                )
                continue
            response.raise_for_status()

            with tempfile.TemporaryDirectory() as tmp_dir:
                local_fname = os.path.join(tmp_dir, "data.tif")
                with open(local_fname, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)

                logger.debug(f"Ingesting data for {item.name}")
                tile_store.write_raster_file(
                    item.name, [self.band_name], UPath(local_fname)
                )
```
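A minimal sketch of using the data source directly, outside the usual rslearn dataset prepare/materialize flow. The `QueryConfig` keyword arguments below are assumptions about its constructor; `get_items()` itself only checks that the space mode is `SpaceMode.MOSAIC` with `max_matches=1`:

```python
# A minimal usage sketch, not from the package itself. The QueryConfig kwargs
# are assumed; everything else uses names defined in hf_srtm.py.
import shapely

from rslearn.config import QueryConfig, SpaceMode
from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources.hf_srtm import SRTM
from rslearn.utils.geometry import STGeometry

ds = SRTM(cache_dir="/tmp/srtm_cache")  # hypothetical cache location
geom = STGeometry(WGS84_PROJECTION, shapely.box(-122.5, 47.0, -121.5, 48.0), None)
query_config = QueryConfig(space_mode=SpaceMode.MOSAIC, max_matches=1)  # assumed kwargs

# One group of mosaic items per input geometry, e.g. "N47/SRTM1N47W123V2.tif".
for item in ds.get_items([geom], query_config)[0][0]:
    print(item.name)
```

The remainder of hf_srtm.py implements the mirror-creation side: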
```python
# The code below is for creating the mirror by downloading SRTM GeoTIFFs from the USGS
# EarthExplorer M2M API.


@functools.cache
def _get_cached_m2m_client(timeout: timedelta) -> M2MAPIClient:
    """Get a cached M2M API client for this process.

    The client is cached per process, so each worker reuses its client across
    multiple download tasks.

    Args:
        timeout: timeout for API requests

    Returns:
        M2M API client
    """
    session = create_retry_session()
    return M2MAPIClient(timeout=timeout, session=session)


def _worker_download_scene(
    scene: dict[str, Any],
    output_path: str,
    timeout: timedelta,
) -> str:
    """Worker function for downloading a single SRTM GeoTIFF file.

    This function gets the download URL from the M2M API and downloads the file.
    The M2M client is cached per worker process to avoid repeated logins.

    Args:
        scene: scene metadata from scene_search
        output_path: path to save the downloaded file
        timeout: timeout for API requests and download

    Returns:
        the output path of the downloaded file
    """
    entity_id = scene["entityId"]
    display_id = scene["displayId"]

    logger.debug(f"Starting download for {display_id}")

    # Ensure output directory exists.
    output_dir = os.path.dirname(output_path)
    os.makedirs(output_dir, exist_ok=True)

    # Get cached M2M client for this worker process.
    client = _get_cached_m2m_client(timeout)

    # Get downloadable products for this scene.
    logger.debug(f"Getting products for {display_id}")
    products = client.get_downloadable_products(SRTM_DATASET_NAME, entity_id)

    # Build a map of available products.
    available_products = {}
    for product in products:
        if product.get("available", False) and product.get("id"):
            available_products[product["productName"]] = product["id"]

    # Try products in order of preference (highest resolution first).
    download_url = None
    for product_name in SRTM_PRODUCT_NAMES:
        if product_name in available_products:
            product_id = available_products[product_name]
            logger.debug(f"Getting download URL for {display_id} ({product_name})")
            download_url = client.get_download_url(entity_id, product_id)
            break

    if download_url is None:
        raise ValueError(
            f"No GeoTIFF product found for scene {display_id}. "
            f"Available products: {list(available_products.keys())}"
        )

    # Download with atomic write using the retry session.
    logger.debug(f"Downloading file for {display_id}")
    with client.session.get(
        download_url, stream=True, timeout=timeout.total_seconds()
    ) as r:
        r.raise_for_status()
        with open_atomic(UPath(output_path), "wb") as f:
            shutil.copyfileobj(r.raw, f)

    return output_path


def _fetch_all_scenes(
    client: M2MAPIClient,
    cache_dir: str,
) -> dict[str, dict[str, Any]]:
    """Fetch all SRTM scenes, using a cached scene list if available.

    SRTM data is organized in 1x1-degree tiles covering latitude -60 to 60.
    We iterate over 10x10-degree boxes to avoid timeout issues with a global search.

    This fetches both SRTM1 (1 arc-second, US only) and SRTM3 (3 arc-second, global)
    scenes. The download function will select the highest resolution available.

    The scene list is cached in scenes.json in the cache directory. If the cache
    exists, it is loaded instead of querying the API.

    Args:
        client: M2M API client
        cache_dir: directory where the cache file is stored

    Returns:
        dict mapping display_id to scene metadata
    """
    cache_path = os.path.join(cache_dir, SCENE_CACHE_FILENAME)
    scenes: dict[str, dict[str, Any]] = {}

    # Try to load from cache.
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            scenes = json.load(f)
        logger.info(f"Loaded {len(scenes)} scenes from cache")
        return scenes

    # Fetch from the API by iterating over 10x10-degree boxes.
    logger.info("No cached scene list found, fetching from API...")
    box_size = 10
    total_boxes = ((SRTM_LAT_MAX - SRTM_LAT_MIN) // box_size) * (360 // box_size)
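    # With box_size = 10, that is ((60 - -60) // 10) * (360 // 10) = 12 * 36 = 432
    # scene searches in total.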

    box_idx = 0
    for lat in range(SRTM_LAT_MIN, SRTM_LAT_MAX, box_size):
        for lon in range(-180, 180, box_size):
            box_idx += 1
            bbox = (lon, lat, lon + box_size, lat + box_size)

            results = client.scene_search(SRTM_DATASET_NAME, bbox=bbox)
            for scene in results:
                display_id = scene["displayId"]
                if display_id not in scenes:
                    scenes[display_id] = scene

            logger.info(
                f"Searched {box_idx}/{total_boxes} boxes, "
                f"found {len(scenes)} unique scenes so far"
            )

    # Save to cache.
    with open(cache_path, "w") as f:
        json.dump(scenes, f)
    logger.info(f"Cached {len(scenes)} scenes to {SCENE_CACHE_FILENAME}")

    return scenes


def bulk_download_srtm(
    output_dir: str,
    num_workers: int = 4,
    timeout: timedelta = timedelta(minutes=5),
) -> None:
    """Bulk download SRTM data from USGS EarthExplorer.

    Downloads all SRTM tiles to the specified output directory. Uses atomic
    rename to ensure partially downloaded files are not included. Files that
    already exist in the output directory are skipped.

    The scene list is cached in scenes.json in the output directory to avoid
    re-querying on subsequent runs.

    Requires M2M_USERNAME and M2M_TOKEN environment variables to be set.

    Args:
        output_dir: directory to save downloaded files
        num_workers: number of parallel download workers
        timeout: timeout for API requests and downloads
    """
    os.makedirs(output_dir, exist_ok=True)

    client = _get_cached_m2m_client(timeout)
    scenes = _fetch_all_scenes(client, output_dir)

    # Filter out scenes that are already downloaded.
    download_tasks = []
    skipped = 0
    for display_id, scene in scenes.items():
        # Use display_id as the filename, with a .tif extension for GeoTIFF.
        output_path = os.path.join(output_dir, f"{display_id}.tif")
        if os.path.exists(output_path):
            logger.debug(f"Skipping {display_id} - already downloaded")
            skipped += 1
        else:
            download_tasks.append(
                {
                    "scene": scene,
                    "output_path": output_path,
                    "timeout": timeout,
                }
            )

    logger.info(
        f"Need to download {len(download_tasks)} scenes ({skipped} already downloaded)"
    )

    if not download_tasks:
        logger.info("All scenes already downloaded!")
        return

    # Download in parallel using multiprocessing.Pool.
    # Each worker creates its own M2M API client to get download URLs.
    logger.info(f"Starting downloads with {num_workers} workers...")
    with multiprocessing.Pool(num_workers) as pool:
        for output_path in star_imap_unordered(
            pool, _worker_download_scene, download_tasks
        ):
            logger.info(f"Downloaded {output_path}")


def main() -> None:
    """Command-line entry point for bulk SRTM download.

    Requires M2M_USERNAME and M2M_TOKEN environment variables to be set.
    """
    import argparse
    import logging

    parser = argparse.ArgumentParser(
        description="Bulk download SRTM data from USGS EarthExplorer. "
        "Requires M2M_USERNAME and M2M_TOKEN environment variables."
    )
    parser.add_argument(
        "output_dir",
        help="Directory to save downloaded SRTM files",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of parallel download workers (default: 4)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=30,
        help="Timeout in seconds for API requests and downloads (default: 30)",
    )

    args = parser.parse_args()

    # Configure logging based on RSLEARN_LOGLEVEL.
    log_level = os.environ.get("RSLEARN_LOGLEVEL", "INFO")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)-6s %(name)s %(message)s",
    )
    # Enable urllib3 logging to see retry information at DEBUG level.
    logging.getLogger("urllib3").setLevel(log_level)

    bulk_download_srtm(
        output_dir=args.output_dir,
        num_workers=args.workers,
        timeout=timedelta(seconds=args.timeout),
    )


if __name__ == "__main__":
    main()
```
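Since the module ends with a `__main__` guard, the bulk download can be run as a script (e.g. `python -m rslearn.data_sources.hf_srtm <output_dir> --workers 8`) or called directly. A minimal sketch, assuming valid M2M credentials; the credential values and output directory below are placeholders:

```python
# A minimal sketch of initializing a mirror, not from the package itself.
import os
from datetime import timedelta

from rslearn.data_sources.hf_srtm import bulk_download_srtm

os.environ["M2M_USERNAME"] = "your-usgs-username"  # placeholder
os.environ["M2M_TOKEN"] = "your-m2m-token"  # placeholder

bulk_download_srtm(
    output_dir="/data/srtm_mirror",  # hypothetical path
    num_workers=8,
    timeout=timedelta(minutes=5),
)
```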
rslearn/data_sources/local_files.py
CHANGED
```diff
@@ -490,7 +490,7 @@ class LocalFiles(DataSource):
             groups.append(cur_groups)
         return groups
 
-    def deserialize_item(self, serialized_item:
+    def deserialize_item(self, serialized_item: dict) -> RasterItem | VectorItem:
         """Deserializes an item from JSON-decoded data."""
         if self.layer_type == LayerType.RASTER:
             return RasterItem.deserialize(serialized_item)
```
rslearn/data_sources/openstreetmap.py
CHANGED
```diff
@@ -461,7 +461,7 @@ class OpenStreetMap(DataSource[OsmItem]):
             groups.append(cur_groups)
         return groups
 
-    def deserialize_item(self, serialized_item:
+    def deserialize_item(self, serialized_item: dict) -> OsmItem:
         """Deserializes an item from JSON-decoded data."""
         return OsmItem.deserialize(serialized_item)
 
```
rslearn/data_sources/planet.py
CHANGED
```diff
@@ -173,9 +173,8 @@ class Planet(DataSource):
         planet_item = asyncio.run(self._get_item_by_name(name))
         return self._wrap_planet_item(planet_item)
 
-    def deserialize_item(self, serialized_item:
+    def deserialize_item(self, serialized_item: dict) -> Item:
         """Deserializes an item from JSON-decoded data."""
-        assert isinstance(serialized_item, dict)
         return Item.deserialize(serialized_item)
 
     async def _download_asset(self, item: Item, tmp_dir: pathlib.Path) -> UPath:
```
rslearn/data_sources/planet_basemap.py
CHANGED
```diff
@@ -234,9 +234,8 @@ class PlanetBasemap(DataSource):
 
         return groups
 
-    def deserialize_item(self, serialized_item:
+    def deserialize_item(self, serialized_item: dict) -> Item:
         """Deserializes an item from JSON-decoded data."""
-        assert isinstance(serialized_item, dict)
         return PlanetItem.deserialize(serialized_item)
 
     def ingest(
```