rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff compares the content of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (72)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/__init__.py +2 -0
  3. rslearn/data_sources/aws_landsat.py +44 -161
  4. rslearn/data_sources/aws_open_data.py +2 -4
  5. rslearn/data_sources/aws_sentinel1.py +1 -3
  6. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  7. rslearn/data_sources/climate_data_store.py +1 -3
  8. rslearn/data_sources/copernicus.py +1 -2
  9. rslearn/data_sources/data_source.py +1 -1
  10. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  11. rslearn/data_sources/earthdaily.py +52 -155
  12. rslearn/data_sources/earthdatahub.py +425 -0
  13. rslearn/data_sources/eurocrops.py +1 -2
  14. rslearn/data_sources/gcp_public_data.py +1 -2
  15. rslearn/data_sources/google_earth_engine.py +1 -2
  16. rslearn/data_sources/hf_srtm.py +595 -0
  17. rslearn/data_sources/local_files.py +3 -3
  18. rslearn/data_sources/openstreetmap.py +1 -1
  19. rslearn/data_sources/planet.py +1 -2
  20. rslearn/data_sources/planet_basemap.py +1 -2
  21. rslearn/data_sources/planetary_computer.py +183 -186
  22. rslearn/data_sources/soilgrids.py +3 -3
  23. rslearn/data_sources/stac.py +1 -2
  24. rslearn/data_sources/usda_cdl.py +1 -3
  25. rslearn/data_sources/usgs_landsat.py +7 -254
  26. rslearn/data_sources/utils.py +204 -64
  27. rslearn/data_sources/worldcereal.py +1 -1
  28. rslearn/data_sources/worldcover.py +1 -1
  29. rslearn/data_sources/worldpop.py +1 -1
  30. rslearn/data_sources/xyz_tiles.py +5 -9
  31. rslearn/dataset/materialize.py +5 -1
  32. rslearn/models/clay/clay.py +3 -3
  33. rslearn/models/concatenate_features.py +6 -1
  34. rslearn/models/detr/detr.py +4 -1
  35. rslearn/models/dinov3.py +0 -1
  36. rslearn/models/olmoearth_pretrain/model.py +3 -1
  37. rslearn/models/pooling_decoder.py +1 -1
  38. rslearn/models/prithvi.py +0 -1
  39. rslearn/models/simple_time_series.py +97 -35
  40. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  41. rslearn/train/data_module.py +32 -27
  42. rslearn/train/dataset.py +260 -117
  43. rslearn/train/dataset_index.py +156 -0
  44. rslearn/train/lightning_module.py +1 -1
  45. rslearn/train/model_context.py +19 -3
  46. rslearn/train/prediction_writer.py +69 -41
  47. rslearn/train/tasks/classification.py +1 -1
  48. rslearn/train/tasks/detection.py +5 -5
  49. rslearn/train/tasks/per_pixel_regression.py +13 -13
  50. rslearn/train/tasks/regression.py +1 -1
  51. rslearn/train/tasks/segmentation.py +26 -13
  52. rslearn/train/transforms/concatenate.py +17 -27
  53. rslearn/train/transforms/crop.py +8 -19
  54. rslearn/train/transforms/flip.py +4 -10
  55. rslearn/train/transforms/mask.py +9 -15
  56. rslearn/train/transforms/normalize.py +31 -82
  57. rslearn/train/transforms/pad.py +7 -13
  58. rslearn/train/transforms/resize.py +5 -22
  59. rslearn/train/transforms/select_bands.py +16 -36
  60. rslearn/train/transforms/sentinel1.py +4 -16
  61. rslearn/utils/__init__.py +2 -0
  62. rslearn/utils/geometry.py +21 -0
  63. rslearn/utils/m2m_api.py +251 -0
  64. rslearn/utils/retry_session.py +43 -0
  65. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  66. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
  67. rslearn/data_sources/earthdata_srtm.py +0 -282
  68. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  69. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  70. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  71. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  72. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
@@ -820,9 +820,8 @@ class Sentinel2(DataSource):
             groups.append(cur_groups)
         return groups

-    def deserialize_item(self, serialized_item: Any) -> Sentinel2Item:
+    def deserialize_item(self, serialized_item: dict) -> Sentinel2Item:
         """Deserializes an item from JSON-decoded data."""
-        assert isinstance(serialized_item, dict)
         return Sentinel2Item.deserialize(serialized_item)

     def retrieve_item(
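
This signature change repeats across the data sources touched in this release (GEE, LocalFiles, OpenStreetMap, Planet, and PlanetBasemap below): the serialized_item parameter is narrowed from Any to dict, and where a runtime isinstance assert existed it is removed, since the same guarantee is now enforced statically. A minimal sketch of the idea, not package code, with a hypothetical Widget class standing in for the item types:

    # Minimal sketch (not from the package): once the parameter is annotated
    # as dict, a type checker such as mypy flags callers that pass non-dict
    # values, so the old runtime isinstance assert is redundant.
    import json

    class Widget:
        def __init__(self, name: str):
            self.name = name

        @staticmethod
        def deserialize(d: dict) -> "Widget":
            return Widget(d["name"])

    def deserialize_item(serialized_item: dict) -> Widget:
        """Deserializes an item from JSON-decoded data."""
        return Widget.deserialize(serialized_item)

    item = deserialize_item(json.loads('{"name": "scene-1"}'))
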
@@ -235,9 +235,8 @@ class GEE(DataSource, TileStore):

         return groups

-    def deserialize_item(self, serialized_item: Any) -> Item:
+    def deserialize_item(self, serialized_item: dict) -> Item:
         """Deserializes an item from JSON-decoded data."""
-        assert isinstance(serialized_item, dict)
         return Item.deserialize(serialized_item)

     def item_to_image(self, item: Item) -> ee.image.Image:
@@ -0,0 +1,595 @@
+"""Global SRTM void-filled elevation data from USGS, mirrored on Hugging Face by AI2.
+
+This module provides:
+1. A bulk download utility to fetch SRTM data from USGS EarthExplorer via the M2M API.
+   This can be used to initialize a mirror of the data.
+2. A data source that will pull from the AI2 Hugging Face mirror.
+
+The SRTM dataset in USGS EarthExplorer is "srtm_v2" which contains void-filled elevation
+data from the Shuttle Radar Topography Mission. The bulk download fetches the highest
+resolution available: 1 arc-second (~30m) in the US and 3 arc-second (~90m) globally.
+
+See https://www.usgs.gov/centers/eros/science/usgs-eros-archive-digital-elevation-shuttle-radar-topography-mission-srtm
+for details.
+"""
+
+import functools
+import json
+import math
+import multiprocessing
+import os
+import re
+import shutil
+import tempfile
+from datetime import timedelta
+from typing import Any
+
+import requests
+import shapely
+from upath import UPath
+
+from rslearn.config import QueryConfig, SpaceMode
+from rslearn.const import WGS84_PROJECTION
+from rslearn.data_sources import DataSource, DataSourceContext, Item
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStoreWithLayer
+from rslearn.utils.fsspec import join_upath, open_atomic
+from rslearn.utils.geometry import STGeometry
+from rslearn.utils.m2m_api import M2MAPIClient
+from rslearn.utils.mp import star_imap_unordered
+from rslearn.utils.retry_session import create_retry_session
+
+logger = get_logger(__name__)
+
+# SRTM dataset name in USGS EarthExplorer M2M API
+SRTM_DATASET_NAME = "srtm_v2"
+
+# Product names in order of preference (highest resolution first)
+# 1 Arc-second is only available in the US, 3 Arc-second is global
+SRTM_PRODUCT_NAMES = ["GeoTIFF 1 Arc-second", "GeoTIFF 3 Arc-second"]
+
+# SRTM covers latitude -60 to 60
+SRTM_LAT_MIN = -60
+SRTM_LAT_MAX = 60
+
+# Cache filename for scene list
+SCENE_CACHE_FILENAME = "scenes.json"
+
+
+class SRTM(DataSource):
+    """Data source for SRTM elevation data from the AI2 Hugging Face mirror.
+
+    The data is split into 1x1-degree tiles, with filenames like:
+        N05/SRTM1N05W163V2.tif (1 arc-second, ~30m resolution)
+        N05/SRTM3N05W163V2.tif (3 arc-second, ~90m resolution)
+
+    SRTM1 (1 arc-second) is available for some regions (primarily US territories),
+    while SRTM3 (3 arc-second) is available globally. By default, SRTM1 is preferred
+    when available for higher resolution. Set always_use_3arcsecond=True to always
+    use the lower resolution SRTM3 data for consistency.
+
+    Items from this data source do not come with a time range. The band name will match
+    that specified in the band set, which should have a single band (e.g. "dem").
+    """
+
+    BASE_URL = (
+        "https://huggingface.co/datasets/allenai/srtm-global-void-filled/resolve/main/"
+    )
+    FILE_LIST_FILENAME = "file_list.json"
+    FILENAME_SUFFIX = "V2.tif"
+
+    def __init__(
+        self,
+        timeout: timedelta = timedelta(seconds=10),
+        cache_dir: str | None = None,
+        always_use_3arcsecond: bool = False,
+        context: DataSourceContext = DataSourceContext(),
+    ):
+        """Initialize a new SRTM instance.
+
+        Args:
+            timeout: timeout for requests.
+            cache_dir: optional directory to cache the file list.
+            always_use_3arcsecond: if True, always use 3 arc-second (SRTM3) data even
+                when 1 arc-second (SRTM1) is available. Defaults to False, which
+                prefers SRTM1 for higher resolution when available.
+            context: the data source context.
+        """
+        # Get band name from context if possible, falling back to "dem".
+        if context.layer_config is not None:
+            if len(context.layer_config.band_sets) != 1:
+                raise ValueError("expected a single band set")
+            if len(context.layer_config.band_sets[0].bands) != 1:
+                raise ValueError("expected band set to have a single band")
+            self.band_name = context.layer_config.band_sets[0].bands[0]
+        else:
+            self.band_name = "dem"
+
+        self.timeout = timeout
+        self.session = requests.session()
+        self.always_use_3arcsecond = always_use_3arcsecond
+
+        # Set the cache path if a cache_dir is provided.
+        self.file_list_cache_path: UPath | None = None
+        if cache_dir is not None:
+            if context.ds_path is not None:
+                cache_root = join_upath(context.ds_path, cache_dir)
+            else:
+                cache_root = UPath(cache_dir)
+            cache_root.mkdir(parents=True, exist_ok=True)
+            self.file_list_cache_path = join_upath(cache_root, self.FILE_LIST_FILENAME)
+
+        self._basename_to_item, self._tile_to_item = self._load_file_index()
+
+    def _load_file_index(
+        self,
+    ) -> tuple[dict[str, Item], dict[tuple[int, int], Item]]:
+        """Load the file list and build indices for lookups."""
+        file_list = self._load_file_list_json()
+        if not isinstance(file_list, list):
+            raise ValueError("expected file_list.json to be a list of filenames")
+
+        basename_to_item: dict[str, Item] = {}
+        tile_to_item: dict[tuple[int, int], Item] = {}
+
+        for entry in file_list:
+            if not isinstance(entry, str):
+                raise ValueError(
+                    "expected file_list.json to contain only string filenames"
+                )
+            basename = os.path.basename(entry)
+
+            # Check if this is SRTM1 or SRTM3 based on filename prefix
+            is_srtm1 = basename.startswith("SRTM1")
+
+            # Skip SRTM1 files if always_use_3arcsecond is enabled
+            if self.always_use_3arcsecond and is_srtm1:
+                continue
+
+            lat_min, lon_min = self._parse_tile_basename(basename)
+            geometry = STGeometry(
+                WGS84_PROJECTION,
+                shapely.box(lon_min, lat_min, lon_min + 1, lat_min + 1),
+                None,
+            )
+            item = Item(entry, geometry)
+
+            key = (lon_min, lat_min)
+
+            # For tile_to_item, prefer SRTM1 over SRTM3 when not using always_use_3arcsecond
+            if key in tile_to_item:
+                existing_is_srtm1 = os.path.basename(tile_to_item[key].name).startswith(
+                    "SRTM1"
+                )
+                # Only replace if current is SRTM1 and existing is not
+                if is_srtm1 and not existing_is_srtm1:
+                    tile_to_item[key] = item
+                # Keep existing if it's SRTM1 and current is not
+                elif existing_is_srtm1 and not is_srtm1:
+                    pass
+                else:
+                    # Same type, keep the existing one
+                    pass
+            else:
+                tile_to_item[key] = item
+
+            if basename not in basename_to_item:
+                basename_to_item[basename] = item
+
+        logger.info(
+            f"Loaded {len(tile_to_item)} SRTM tiles from Hugging Face file list"
+        )
+        return basename_to_item, tile_to_item
+
+    def _load_file_list_json(self) -> list[Any]:
+        """Load file list JSON, optionally from cache."""
+        if self.file_list_cache_path is not None and self.file_list_cache_path.exists():
+            with self.file_list_cache_path.open() as f:
+                file_list = json.load(f)
+            logger.info(f"Loaded SRTM file list cache from {self.file_list_cache_path}")
+            return file_list
+
+        response = self.session.get(
+            self.BASE_URL + self.FILE_LIST_FILENAME,
+            timeout=self.timeout.total_seconds(),
+        )
+        response.raise_for_status()
+        file_list = response.json()
+        if self.file_list_cache_path is not None:
+            with open_atomic(self.file_list_cache_path, "w") as f:
+                json.dump(file_list, f)
+        return file_list
+
+    def _parse_tile_basename(self, basename: str) -> tuple[int, int]:
+        """Parse a tile basename into (lat_min, lon_min)."""
+        match = re.match(
+            r"^SRTM\d([NS])(\d{2})([EW])(\d{3})V2\.tif$",
+            basename,
+        )
+        if match is None:
+            raise ValueError(f"invalid SRTM tile filename: {basename}")
+
+        lat_sign, lat_str, lon_sign, lon_str = match.groups()
+        lat_degrees = int(lat_str)
+        lon_degrees = int(lon_str)
+
+        if lat_sign == "N":
+            lat_min = lat_degrees
+        elif lat_sign == "S":
+            lat_min = -lat_degrees
+        else:
+            raise ValueError(f"invalid SRTM tile latitude for filename: {basename}")
+
+        if lon_sign == "E":
+            lon_min = lon_degrees
+        elif lon_sign == "W":
+            lon_min = -lon_degrees
+        else:
+            raise ValueError(f"invalid SRTM tile longitude for filename: {basename}")
+
+        return lat_min, lon_min
+
+    def get_item_by_name(self, name: str) -> Item:
+        """Gets an item by name.
+
+        Args:
+            name: the name of the item to get. For SRTM, the item name is the filename
+                of the GeoTIFF tile.
+
+        Returns:
+            the Item object
+        """
+        basename = os.path.basename(name)
+        item = self._basename_to_item.get(basename)
+        if item is not None:
+            return item
+        raise ValueError(f"unknown SRTM tile name: {name}")
+
+    def get_items(
+        self, geometries: list[STGeometry], query_config: QueryConfig
+    ) -> list[list[list[Item]]]:
+        """Get a list of items in the data source intersecting the given geometries.
+
+        Args:
+            geometries: the spatiotemporal geometries
+            query_config: the query configuration
+
+        Returns:
+            List of groups of items that should be retrieved for each geometry.
+        """
+        # It only makes sense to create mosaic from the SRTM data since it is one
+        # global layer (but spatially split up into items).
+        if query_config.space_mode != SpaceMode.MOSAIC or query_config.max_matches != 1:
+            raise ValueError(
+                "expected mosaic with max_matches=1 for the query configuration"
+            )
+
+        groups = []
+        for geometry in geometries:
+            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
+            shp_bounds = wgs84_geometry.shp.bounds
+            cell_bounds = (
+                math.floor(shp_bounds[0]),
+                math.floor(shp_bounds[1]),
+                math.ceil(shp_bounds[2]),
+                math.ceil(shp_bounds[3]),
+            )
+            items = []
+            for lon_min in range(cell_bounds[0], cell_bounds[2]):
+                for lat_min in range(cell_bounds[1], cell_bounds[3]):
+                    if (lon_min, lat_min) not in self._tile_to_item:
+                        continue
+                    items.append(self._tile_to_item[(lon_min, lat_min)])
+
+            logger.debug(f"Got {len(items)} items (grid cells) for geometry")
+            groups.append([items])
+
+        return groups
+
+    def deserialize_item(self, serialized_item: dict) -> Item:
+        """Deserializes an item from JSON-decoded data."""
+        return Item.deserialize(serialized_item)
+
+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[Item],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        for item in items:
+            if tile_store.is_raster_ready(item.name, [self.band_name]):
+                continue
+
+            url = self.BASE_URL + item.name
+            logger.debug(f"Downloading SRTM data for {item.name} from {url}")
+            response = self.session.get(
+                url, stream=True, timeout=self.timeout.total_seconds()
+            )
+
+            if response.status_code == 404:
+                logger.warning(
+                    f"Skipping item {item.name} because there is no data at that cell"
+                )
+                continue
+            response.raise_for_status()
+
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                local_fname = os.path.join(tmp_dir, "data.tif")
+                with open(local_fname, "wb") as f:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        f.write(chunk)
+
+                logger.debug(f"Ingesting data for {item.name}")
+                tile_store.write_raster_file(
+                    item.name, [self.band_name], UPath(local_fname)
+                )
+
+
+# The code below is for creating the mirror by downloading SRTM GeoTIFFs from the USGS
+# EarthExplorer M2M API.
+
+
+@functools.cache
+def _get_cached_m2m_client(timeout: timedelta) -> M2MAPIClient:
+    """Get a cached M2M API client for this process.
+
+    The client is cached per process, so each worker reuses its client across
+    multiple download tasks.
+
+    Args:
+        timeout: timeout for API requests
+
+    Returns:
+        M2M API client
+    """
+    session = create_retry_session()
+    return M2MAPIClient(timeout=timeout, session=session)
+
+
+def _worker_download_scene(
+    scene: dict[str, Any],
+    output_path: str,
+    timeout: timedelta,
+) -> str:
+    """Worker function for downloading a single SRTM GeoTIFF file.
+
+    This function gets the download URL from the M2M API and downloads the file.
+    The M2M client is cached per worker process to avoid repeated logins.
+
+    Args:
+        scene: scene metadata from scene_search
+        output_path: path to save the downloaded file
+        timeout: timeout for API requests and download
+
+    Returns:
+        the output path of the downloaded file
+    """
+    entity_id = scene["entityId"]
+    display_id = scene["displayId"]
+
+    logger.debug(f"Starting download for {display_id}")
+
+    # Ensure output directory exists
+    output_dir = os.path.dirname(output_path)
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Get cached M2M client for this worker process
+    client = _get_cached_m2m_client(timeout)
+
+    # Get downloadable products for this scene
+    logger.debug(f"Getting products for {display_id}")
+    products = client.get_downloadable_products(SRTM_DATASET_NAME, entity_id)
+
+    # Build a map of available products
+    available_products = {}
+    for product in products:
+        if product.get("available", False) and product.get("id"):
+            available_products[product["productName"]] = product["id"]
+
+    # Try products in order of preference (highest resolution first)
+    download_url = None
+    for product_name in SRTM_PRODUCT_NAMES:
+        if product_name in available_products:
+            product_id = available_products[product_name]
+            logger.debug(f"Getting download URL for {display_id} ({product_name})")
+            download_url = client.get_download_url(entity_id, product_id)
+            break
+
+    if download_url is None:
+        raise ValueError(
+            f"No GeoTIFF product found for scene {display_id}. "
+            f"Available products: {list(available_products.keys())}"
+        )
+
+    # Download with atomic write using retry session
+    logger.debug(f"Downloading file for {display_id}")
+    with client.session.get(
+        download_url, stream=True, timeout=timeout.total_seconds()
+    ) as r:
+        r.raise_for_status()
+        with open_atomic(UPath(output_path), "wb") as f:
+            shutil.copyfileobj(r.raw, f)
+
+    return output_path
+
+
+def _fetch_all_scenes(
+    client: M2MAPIClient,
+    cache_dir: str,
+) -> dict[str, dict[str, Any]]:
+    """Fetch all SRTM scenes, using a cached scene list if available.
+
+    SRTM data is organized in 1x1 degree tiles covering latitude -60 to 60.
+    We iterate over 10x10 degree boxes to avoid timeout issues with global search.
+
+    This fetches both SRTM1 (1 arc-second, US only) and SRTM3 (3 arc-second, global)
+    scenes. The download function will select the highest resolution available.
+
+    The scene list is cached in scenes.json in the cache directory. If the cache
+    exists, it is loaded instead of querying the API.
+
+    Args:
+        client: M2M API client
+        cache_dir: directory where cache file is stored
+
+    Returns:
+        dict mapping display_id to scene metadata
+    """
+    cache_path = os.path.join(cache_dir, SCENE_CACHE_FILENAME)
+    scenes: dict[str, dict[str, Any]] = {}
+
+    # Try to load from cache
+    if os.path.exists(cache_path):
+        with open(cache_path) as f:
+            scenes = json.load(f)
+        logger.info(f"Loaded {len(scenes)} scenes from cache")
+        return scenes
+
+    # Fetch from API by iterating over 10x10 degree boxes
+    logger.info("No cached scene list found, fetching from API...")
+    box_size = 10
+    total_boxes = ((SRTM_LAT_MAX - SRTM_LAT_MIN) // box_size) * (360 // box_size)
+
+    box_idx = 0
+    for lat in range(SRTM_LAT_MIN, SRTM_LAT_MAX, box_size):
+        for lon in range(-180, 180, box_size):
+            box_idx += 1
+            bbox = (lon, lat, lon + box_size, lat + box_size)
+
+            results = client.scene_search(SRTM_DATASET_NAME, bbox=bbox)
+            for scene in results:
+                display_id = scene["displayId"]
+                if display_id not in scenes:
+                    scenes[display_id] = scene
+
+            logger.info(
+                f"Searched {box_idx}/{total_boxes} boxes, "
+                f"found {len(scenes)} unique scenes so far"
+            )
+
+    # Save to cache
+    with open(cache_path, "w") as f:
+        json.dump(scenes, f)
+    logger.info(f"Cached {len(scenes)} scenes to {SCENE_CACHE_FILENAME}")
+
+    return scenes
+
+
+def bulk_download_srtm(
+    output_dir: str,
+    num_workers: int = 4,
+    timeout: timedelta = timedelta(minutes=5),
+) -> None:
+    """Bulk download SRTM data from USGS EarthExplorer.
+
+    Downloads all SRTM tiles to the specified output directory. Uses atomic
+    rename to ensure partially downloaded files are not included. Files that
+    already exist in the output directory are skipped.
+
+    The scene list is cached in scenes.json in the output directory to avoid
+    re-querying on subsequent runs.
+
+    Requires M2M_USERNAME and M2M_TOKEN environment variables to be set.
+
+    Args:
+        output_dir: directory to save downloaded files
+        num_workers: number of parallel download workers
+        timeout: timeout for API requests and downloads
+    """
+    os.makedirs(output_dir, exist_ok=True)
+
+    client = _get_cached_m2m_client(timeout)
+    scenes = _fetch_all_scenes(client, output_dir)
+
+    # Filter out scenes that are already downloaded
+    download_tasks = []
+    skipped = 0
+    for display_id, scene in scenes.items():
+        # Use display_id as the filename with .tif extension for GeoTIFF
+        output_path = os.path.join(output_dir, f"{display_id}.tif")
+        if os.path.exists(output_path):
+            logger.debug(f"Skipping {display_id} - already downloaded")
+            skipped += 1
+        else:
+            download_tasks.append(
+                {
+                    "scene": scene,
+                    "output_path": output_path,
+                    "timeout": timeout,
+                }
+            )
+
+    logger.info(
+        f"Need to download {len(download_tasks)} scenes ({skipped} already downloaded)"
+    )
+
+    if not download_tasks:
+        logger.info("All scenes already downloaded!")
+        return
+
+    # Download in parallel using multiprocessing.Pool
+    # Each worker creates its own M2M API client to get download URLs
+    logger.info(f"Starting downloads with {num_workers} workers...")
+    with multiprocessing.Pool(num_workers) as pool:
+        for output_path in star_imap_unordered(
+            pool, _worker_download_scene, download_tasks
+        ):
+            logger.info(f"Downloaded {output_path}")
+
+
+def main() -> None:
+    """Command-line entry point for bulk SRTM download.
+
+    Requires M2M_USERNAME and M2M_TOKEN environment variables to be set.
+    """
+    import argparse
+    import logging
+
+    parser = argparse.ArgumentParser(
+        description="Bulk download SRTM data from USGS EarthExplorer. "
+        "Requires M2M_USERNAME and M2M_TOKEN environment variables."
+    )
+    parser.add_argument(
+        "output_dir",
+        help="Directory to save downloaded SRTM files",
+    )
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=4,
+        help="Number of parallel download workers (default: 4)",
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=30,
+        help="Timeout in seconds for API requests and downloads (default: 30)",
+    )
+
+    args = parser.parse_args()
+
+    # Configure logging based on RSLEARN_LOGLEVEL
+    log_level = os.environ.get("RSLEARN_LOGLEVEL", "INFO")
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s %(levelname)-6s %(name)s %(message)s",
+    )
+    # Enable urllib3 logging to see retry information at DEBUG level
+    logging.getLogger("urllib3").setLevel(log_level)
+
+    bulk_download_srtm(
+        output_dir=args.output_dir,
+        num_workers=args.workers,
+        timeout=timedelta(seconds=args.timeout),
+    )
+
+
+if __name__ == "__main__":
+    main()
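
For reference, the tile naming convention that _parse_tile_basename implements can be checked in isolation. This snippet reuses the regex from the diff above; the sample filenames are illustrative:

    # Standalone check of the SRTM tile-name convention, using the regex from
    # _parse_tile_basename above. Sample filenames are illustrative.
    import re

    TILE_RE = r"^SRTM\d([NS])(\d{2})([EW])(\d{3})V2\.tif$"

    for basename in ["SRTM1N05W163V2.tif", "SRTM3S10E020V2.tif"]:
        lat_sign, lat_str, lon_sign, lon_str = re.match(TILE_RE, basename).groups()
        lat_min = int(lat_str) if lat_sign == "N" else -int(lat_str)
        lon_min = int(lon_str) if lon_sign == "E" else -int(lon_str)
        # Each tile spans one degree: lon_min..lon_min+1, lat_min..lat_min+1.
        print(basename, (lat_min, lon_min))
        # -> SRTM1N05W163V2.tif (5, -163)
        # -> SRTM3S10E020V2.tif (-10, 20)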
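
_load_file_index keeps one Item per 1x1-degree cell, preferring the SRTM1 file over the SRTM3 file when the mirror has both (unless always_use_3arcsecond is set, in which case SRTM1 entries are skipped entirely). A condensed illustration of the same preference rule, with a hypothetical file list:

    # Condensed illustration (not package code) of the SRTM1-over-SRTM3
    # preference, using a hypothetical file list with both files for one cell.
    file_list = ["N05/SRTM3N05W163V2.tif", "N05/SRTM1N05W163V2.tif"]

    tile_to_name: dict[tuple[int, int], str] = {}
    for entry in file_list:
        basename = entry.split("/")[-1]
        key = (-163, 5)  # (lon_min, lat_min); rslearn parses this from the basename
        existing = tile_to_name.get(key)
        if existing is None or (
            basename.startswith("SRTM1") and not existing.startswith("SRTM1")
        ):
            tile_to_name[key] = basename

    print(tile_to_name)  # {(-163, 5): 'SRTM1N05W163V2.tif'}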
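
get_items only accepts mosaic queries with max_matches=1, since SRTM is a single global layer split into tiles. Candidate tiles for a geometry are found by snapping its WGS84 bounds outward to whole degrees; a worked example with hypothetical bounds:

    # Worked example of the cell enumeration in SRTM.get_items, with
    # hypothetical WGS84 bounds (lon_min, lat_min, lon_max, lat_max).
    import math

    shp_bounds = (-163.4, 5.2, -162.2, 5.9)
    cell_bounds = (
        math.floor(shp_bounds[0]),  # -164
        math.floor(shp_bounds[1]),  # 5
        math.ceil(shp_bounds[2]),  # -162
        math.ceil(shp_bounds[3]),  # 6
    )
    cells = [
        (lon_min, lat_min)
        for lon_min in range(cell_bounds[0], cell_bounds[2])
        for lat_min in range(cell_bounds[1], cell_bounds[3])
    ]
    print(cells)  # [(-164, 5), (-163, 5)]: two candidate tiles for the mosaic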
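
Putting the pieces together, a standalone usage sketch of the data source follows. The STGeometry construction mirrors the code above; the QueryConfig keyword arguments are inferred from the attribute accesses in get_items, and the cache directory is hypothetical:

    # Usage sketch under assumptions: QueryConfig(space_mode=..., max_matches=...)
    # is inferred, not confirmed API, and the cache_dir path is hypothetical.
    import shapely

    from rslearn.config import QueryConfig, SpaceMode
    from rslearn.const import WGS84_PROJECTION
    from rslearn.data_sources.hf_srtm import SRTM
    from rslearn.utils.geometry import STGeometry

    srtm = SRTM(cache_dir="/tmp/srtm_cache")
    geom = STGeometry(WGS84_PROJECTION, shapely.box(-163.4, 5.2, -162.2, 5.9), None)
    query_config = QueryConfig(space_mode=SpaceMode.MOSAIC, max_matches=1)
    groups = srtm.get_items([geom], query_config)
    # groups[0][0] is the list of tile Items to mosaic for the geometry.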
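
The mirror itself is initialized with bulk_download_srtm, which searches 432 boxes (12 latitude bands times 36 longitude bands of 10x10 degrees) and then downloads scenes in parallel. A sketch of invoking it from Python; the output path is hypothetical, and M2M_USERNAME and M2M_TOKEN must be set in the environment:

    # Sketch of initializing a mirror; requires M2M_USERNAME and M2M_TOKEN.
    from datetime import timedelta

    from rslearn.data_sources.hf_srtm import bulk_download_srtm

    bulk_download_srtm(
        output_dir="/data/srtm_mirror",  # hypothetical destination
        num_workers=8,
        timeout=timedelta(minutes=5),
    )

Since the module defines main() under if __name__ == "__main__", the same run can also be started from the command line with python -m rslearn.data_sources.hf_srtm /data/srtm_mirror --workers 8.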
@@ -236,8 +236,8 @@ class RasterImporter(Importer):
             "windows in the rslearn dataset. When using settings like "
             "max_matches=1 and space_mode=MOSAIC, this may cause windows outside "
             "the geometry’s valid bounds to be materialized from the global raster "
-            "instead of a more appropriate source. Consider using COMPOSITE mode, "
-            "or increasing max_matches if this behavior is unintended."
+            "instead of a more appropriate source. Consider increasing max_matches"
+            "if this behavior is unintended."
         )

         if spec.name:
@@ -490,7 +490,7 @@ class LocalFiles(DataSource):
             groups.append(cur_groups)
         return groups

-    def deserialize_item(self, serialized_item: Any) -> RasterItem | VectorItem:
+    def deserialize_item(self, serialized_item: dict) -> RasterItem | VectorItem:
         """Deserializes an item from JSON-decoded data."""
         if self.layer_type == LayerType.RASTER:
             return RasterItem.deserialize(serialized_item)
@@ -461,7 +461,7 @@ class OpenStreetMap(DataSource[OsmItem]):
             groups.append(cur_groups)
         return groups

-    def deserialize_item(self, serialized_item: Any) -> OsmItem:
+    def deserialize_item(self, serialized_item: dict) -> OsmItem:
         """Deserializes an item from JSON-decoded data."""
         return OsmItem.deserialize(serialized_item)

@@ -173,9 +173,8 @@ class Planet(DataSource):
         planet_item = asyncio.run(self._get_item_by_name(name))
         return self._wrap_planet_item(planet_item)

-    def deserialize_item(self, serialized_item: Any) -> Item:
+    def deserialize_item(self, serialized_item: dict) -> Item:
         """Deserializes an item from JSON-decoded data."""
-        assert isinstance(serialized_item, dict)
         return Item.deserialize(serialized_item)

     async def _download_asset(self, item: Item, tmp_dir: pathlib.Path) -> UPath:
@@ -234,9 +234,8 @@ class PlanetBasemap(DataSource):

         return groups

-    def deserialize_item(self, serialized_item: Any) -> Item:
+    def deserialize_item(self, serialized_item: dict) -> Item:
         """Deserializes an item from JSON-decoded data."""
-        assert isinstance(serialized_item, dict)
         return PlanetItem.deserialize(serialized_item)

     def ingest(