rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72):
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/__init__.py +2 -0
  3. rslearn/data_sources/aws_landsat.py +44 -161
  4. rslearn/data_sources/aws_open_data.py +2 -4
  5. rslearn/data_sources/aws_sentinel1.py +1 -3
  6. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  7. rslearn/data_sources/climate_data_store.py +1 -3
  8. rslearn/data_sources/copernicus.py +1 -2
  9. rslearn/data_sources/data_source.py +1 -1
  10. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  11. rslearn/data_sources/earthdaily.py +52 -155
  12. rslearn/data_sources/earthdatahub.py +425 -0
  13. rslearn/data_sources/eurocrops.py +1 -2
  14. rslearn/data_sources/gcp_public_data.py +1 -2
  15. rslearn/data_sources/google_earth_engine.py +1 -2
  16. rslearn/data_sources/hf_srtm.py +595 -0
  17. rslearn/data_sources/local_files.py +3 -3
  18. rslearn/data_sources/openstreetmap.py +1 -1
  19. rslearn/data_sources/planet.py +1 -2
  20. rslearn/data_sources/planet_basemap.py +1 -2
  21. rslearn/data_sources/planetary_computer.py +183 -186
  22. rslearn/data_sources/soilgrids.py +3 -3
  23. rslearn/data_sources/stac.py +1 -2
  24. rslearn/data_sources/usda_cdl.py +1 -3
  25. rslearn/data_sources/usgs_landsat.py +7 -254
  26. rslearn/data_sources/utils.py +204 -64
  27. rslearn/data_sources/worldcereal.py +1 -1
  28. rslearn/data_sources/worldcover.py +1 -1
  29. rslearn/data_sources/worldpop.py +1 -1
  30. rslearn/data_sources/xyz_tiles.py +5 -9
  31. rslearn/dataset/materialize.py +5 -1
  32. rslearn/models/clay/clay.py +3 -3
  33. rslearn/models/concatenate_features.py +6 -1
  34. rslearn/models/detr/detr.py +4 -1
  35. rslearn/models/dinov3.py +0 -1
  36. rslearn/models/olmoearth_pretrain/model.py +3 -1
  37. rslearn/models/pooling_decoder.py +1 -1
  38. rslearn/models/prithvi.py +0 -1
  39. rslearn/models/simple_time_series.py +97 -35
  40. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  41. rslearn/train/data_module.py +32 -27
  42. rslearn/train/dataset.py +260 -117
  43. rslearn/train/dataset_index.py +156 -0
  44. rslearn/train/lightning_module.py +1 -1
  45. rslearn/train/model_context.py +19 -3
  46. rslearn/train/prediction_writer.py +69 -41
  47. rslearn/train/tasks/classification.py +1 -1
  48. rslearn/train/tasks/detection.py +5 -5
  49. rslearn/train/tasks/per_pixel_regression.py +13 -13
  50. rslearn/train/tasks/regression.py +1 -1
  51. rslearn/train/tasks/segmentation.py +26 -13
  52. rslearn/train/transforms/concatenate.py +17 -27
  53. rslearn/train/transforms/crop.py +8 -19
  54. rslearn/train/transforms/flip.py +4 -10
  55. rslearn/train/transforms/mask.py +9 -15
  56. rslearn/train/transforms/normalize.py +31 -82
  57. rslearn/train/transforms/pad.py +7 -13
  58. rslearn/train/transforms/resize.py +5 -22
  59. rslearn/train/transforms/select_bands.py +16 -36
  60. rslearn/train/transforms/sentinel1.py +4 -16
  61. rslearn/utils/__init__.py +2 -0
  62. rslearn/utils/geometry.py +21 -0
  63. rslearn/utils/m2m_api.py +251 -0
  64. rslearn/utils/retry_session.py +43 -0
  65. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  66. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
  67. rslearn/data_sources/earthdata_srtm.py +0 -282
  68. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  69. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  70. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  71. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  72. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py CHANGED
@@ -236,11 +236,9 @@ class BandSetConfig(BaseModel):
236
236
 
237
237
  warnings.warn(
238
238
  "`format = {'name': ...}` is deprecated; "
239
- "use `{'class_path': '...', 'init_args': {...}}` instead.",
240
- DeprecationWarning,
241
- )
242
- logger.warning(
243
- "BandSet.format uses legacy format; support will be removed after 2026-03-01."
239
+ "use `{'class_path': '...', 'init_args': {...}}` instead. "
240
+ "Support will be removed after 2026-03-01.",
241
+ FutureWarning,
244
242
  )
245
243
 
246
244
  legacy_name_to_class_path = {
@@ -294,16 +292,6 @@ class SpaceMode(StrEnum):
294
292
  The duration of the sub-periods is controlled by another option in QueryConfig.
295
293
  """
296
294
 
297
- COMPOSITE = "COMPOSITE"
298
- """Creates one composite covering the entire window.
299
-
300
- During querying all items intersecting the window are placed in one group.
301
- The compositing_method in the rasterlayer config specifies how these items are reduced
302
- to a single item (e.g MEAN/MEDIAN/FIRST_VALID) during materialization.
303
- """
304
-
305
- # TODO add PER_PERIOD_COMPOSITE
306
-
307
295
 
308
296
  class TimeMode(StrEnum):
309
297
  """Temporal matching mode when looking up items corresponding to a window."""
@@ -353,6 +341,20 @@ class QueryConfig(BaseModel):
353
341
  default=timedelta(days=30),
354
342
  description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
355
343
  )
344
+ mosaic_compositing_overlaps: int = Field(
345
+ default=1,
346
+ description="For MOSAIC and PER_PERIOD_MOSAIC modes, the number of overlapping items "
347
+ "wanted within each item group covering the window. Set to 1 for a single coverage "
348
+ "(default mosaic behavior), or higher for compositing multiple overlapping items."
349
+ "with mean or median compositing method.",
350
+ )
351
+ per_period_mosaic_reverse_time_order: bool = Field(
352
+ default=True,
353
+ description="For PER_PERIOD_MOSAIC mode, whether to return item groups in reverse "
354
+ "temporal order (most recent first). Set to False for chronological order (oldest first). "
355
+ "Default True is deprecated and will change to False with error if still unset or set True "
356
+ "after 2026-04-01.",
357
+ )
356
358
 
357
359
 
358
360
  class DataSourceConfig(BaseModel):
@@ -404,11 +406,9 @@ class DataSourceConfig(BaseModel):
404
406
 
405
407
  warnings.warn(
406
408
  "`Data source configuration {'name': ...}` is deprecated; "
407
- "use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
408
- DeprecationWarning,
409
- )
410
- logger.warning(
411
- "Data source configuration uses legacy format; support will be removed after 2026-03-01."
409
+ "use `{'class_path': '...', 'init_args': {...}, ...}` instead. "
410
+ "Support will be removed after 2026-03-01.",
411
+ FutureWarning,
412
412
  )
413
413
 
414
414
  # Split the dict into the base config that is in the pydantic model, and the
@@ -431,8 +431,9 @@ class DataSourceConfig(BaseModel):
431
431
  and "max_cloud_cover" in ds_init_args
432
432
  ):
433
433
  warnings.warn(
434
- "Data source configuration specifies invalid 'max_cloud_cover' option.",
435
- DeprecationWarning,
434
+ "Data source configuration specifies invalid 'max_cloud_cover' option."
435
+ "Support for ignoring this option will be removed after 2026-03-01.",
436
+ FutureWarning,
436
437
  )
437
438
  del ds_init_args["max_cloud_cover"]
438
439
 
@@ -449,7 +450,13 @@ class LayerType(StrEnum):
449
450
 
450
451
 
451
452
  class CompositingMethod(StrEnum):
452
- """Method how to select pixels for the composite from corresponding items of a window."""
453
+ """Method how to select pixels for the composite from corresponding items of a window.
454
+
455
+ For MEAN and MEDIAN modes, mosaic_compositing_overlaps (in the QueryConfig) should
456
+ be set higher than 1 so that rslearn creates item groups during prepare that cover
457
+ the window with multiple overlaps. At each pixel/band, the mean and median can then
458
+ be computed across items in each group that cover that pixel.
459
+ """
453
460
 
454
461
  FIRST_VALID = "FIRST_VALID"
455
462
  """Select first valid pixel in order of corresponding items (might be sorted)"""
@@ -17,6 +17,7 @@ from .data_source import (
17
17
  ItemLookupDataSource,
18
18
  RetrieveItemDataSource,
19
19
  )
20
+ from .direct_materialize_data_source import DirectMaterializeDataSource
20
21
 
21
22
  __all__ = (
22
23
  "DataSource",
@@ -24,5 +25,6 @@ __all__ = (
24
25
  "Item",
25
26
  "ItemLookupDataSource",
26
27
  "RetrieveItemDataSource",
28
+ "DirectMaterializeDataSource",
27
29
  "data_source_from_config",
28
30
  )
@@ -9,32 +9,28 @@ import urllib.request
9
9
  import zipfile
10
10
  from collections.abc import Generator
11
11
  from datetime import datetime
12
- from typing import Any, BinaryIO
12
+ from typing import BinaryIO
13
13
 
14
- import affine
15
14
  import boto3
16
15
  import dateutil.parser
17
16
  import fiona
18
17
  import fiona.transform
19
- import numpy.typing as npt
20
- import rasterio
21
18
  import shapely
22
19
  import shapely.geometry
23
20
  import tqdm
24
- from rasterio.enums import Resampling
25
21
  from upath import UPath
26
22
 
27
23
  import rslearn.data_sources.utils
28
- from rslearn.config import LayerConfig
29
24
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
30
- from rslearn.dataset import Window
31
- from rslearn.dataset.materialize import RasterMaterializer
32
- from rslearn.tile_stores import TileStore, TileStoreWithLayer
25
+ from rslearn.data_sources.direct_materialize_data_source import (
26
+ DirectMaterializeDataSource,
27
+ )
28
+ from rslearn.tile_stores import TileStoreWithLayer
33
29
  from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
34
- from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
30
+ from rslearn.utils.geometry import STGeometry
35
31
  from rslearn.utils.grid_index import GridIndex
36
32
 
37
- from .data_source import DataSource, DataSourceContext, Item, QueryConfig
33
+ from .data_source import DataSourceContext, Item, QueryConfig
38
34
 
39
35
  WRS2_GRID_SIZE = 1.0
40
36
 
@@ -79,7 +75,7 @@ class LandsatOliTirsItem(Item):
79
75
  )
80
76
 
81
77
 
82
- class LandsatOliTirs(DataSource, TileStore):
78
+ class LandsatOliTirs(DirectMaterializeDataSource[LandsatOliTirsItem]):
83
79
  """A data source for Landsat 8/9 OLI-TIRS imagery on AWS.
84
80
 
85
81
  Specifically, uses the usgs-landsat S3 bucket maintained by USGS. The data includes
@@ -91,7 +87,7 @@ class LandsatOliTirs(DataSource, TileStore):
91
87
 
92
88
  bucket_name = "usgs-landsat"
93
89
  bucket_prefix = "collection02/level-1/standard/oli-tirs"
94
- bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9", "B10", "B11"]
90
+ BANDS = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9", "B10", "B11"]
95
91
 
96
92
  wrs2_url = "https://d9-wret.s3.us-west-2.amazonaws.com/assets/palladium/production/s3fs-public/atoms/files/WRS2_descending_0.zip" # noqa
97
93
  """URL to download shapefile specifying polygon of each (path, row)."""
@@ -110,6 +106,10 @@ class LandsatOliTirs(DataSource, TileStore):
110
106
  SpaceMode.WITHIN.
111
107
  context: the data source context.
112
108
  """
109
+ # Each band is a separate single-band asset.
110
+ asset_bands = {band: [band] for band in self.BANDS}
111
+ super().__init__(asset_bands=asset_bands)
112
+
113
113
  # If context is provided, we join the directory with the dataset path,
114
114
  # otherwise we treat it directly as UPath.
115
115
  if context.ds_path is not None:
@@ -342,16 +342,43 @@ class LandsatOliTirs(DataSource, TileStore):
342
342
  return item
343
343
  raise ValueError(f"item {name} not found")
344
344
 
345
- def deserialize_item(self, serialized_item: Any) -> LandsatOliTirsItem:
345
+ # --- DirectMaterializeDataSource implementation ---
346
+
347
+ def get_asset_url(self, item_name: str, asset_key: str) -> str:
348
+ """Get the presigned URL to read the asset for the given item and asset key.
349
+
350
+ Args:
351
+ item_name: the name of the item.
352
+ asset_key: the key identifying which asset to get (the band name).
353
+
354
+ Returns:
355
+ the presigned URL to read the asset from.
356
+ """
357
+ # Get the item since it has the blob path.
358
+ item = self.get_item_by_name(item_name)
359
+
360
+ # For Landsat, the asset_key is the band name (e.g., "B1", "B2", etc.).
361
+ blob_key = item.blob_path + f"{asset_key}.TIF"
362
+ return self.client.generate_presigned_url(
363
+ "get_object",
364
+ Params={
365
+ "Bucket": self.bucket_name,
366
+ "Key": blob_key,
367
+ "RequestPayer": "requester",
368
+ },
369
+ )
370
+
371
+ # --- DataSource implementation ---
372
+
373
+ def deserialize_item(self, serialized_item: dict) -> LandsatOliTirsItem:
346
374
  """Deserializes an item from JSON-decoded data."""
347
- assert isinstance(serialized_item, dict)
348
375
  return LandsatOliTirsItem.deserialize(serialized_item)
349
376
 
350
377
  def retrieve_item(
351
378
  self, item: LandsatOliTirsItem
352
379
  ) -> Generator[tuple[str, BinaryIO], None, None]:
353
380
  """Retrieves the rasters corresponding to an item as file streams."""
354
- for band in self.bands:
381
+ for band in self.BANDS:
355
382
  buf = io.BytesIO()
356
383
  self.bucket.download_fileobj(
357
384
  item.blob_path + f"{band}.TIF",
@@ -376,7 +403,7 @@ class LandsatOliTirs(DataSource, TileStore):
376
403
  geometries: a list of geometries needed for each item
377
404
  """
378
405
  for item, cur_geometries in zip(items, geometries):
379
- for band in self.bands:
406
+ for band in self.BANDS:
380
407
  band_names = [band]
381
408
  if tile_store.is_raster_ready(item.name, band_names):
382
409
  continue
@@ -389,147 +416,3 @@ class LandsatOliTirs(DataSource, TileStore):
389
416
  ExtraArgs={"RequestPayer": "requester"},
390
417
  )
391
418
  tile_store.write_raster_file(item.name, band_names, UPath(fname))
392
-
393
- # The functions below are to emulate TileStore functionality so we can easily
394
- # support materialization directly from the COGs.
395
- def is_raster_ready(
396
- self, layer_name: str, item_name: str, bands: list[str]
397
- ) -> bool:
398
- """Checks if this raster has been written to the store.
399
-
400
- Args:
401
- layer_name: the layer name or alias.
402
- item_name: the item.
403
- bands: the list of bands identifying which specific raster to read.
404
-
405
- Returns:
406
- whether there is a raster in the store matching the source, item, and
407
- bands.
408
- """
409
- # Always ready since we access it on AWS bucket.
410
- return True
411
-
412
- def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
413
- """Get the sets of bands that have been stored for the specified item.
414
-
415
- Args:
416
- layer_name: the layer name or alias.
417
- item_name: the item.
418
-
419
- Returns:
420
- a list of lists of bands that are in the tile store (with one raster
421
- stored corresponding to each inner list). If no rasters are ready for
422
- this item, returns empty list.
423
- """
424
- return [[band] for band in self.bands]
425
-
426
- def get_raster_bounds(
427
- self, layer_name: str, item_name: str, bands: list[str], projection: Projection
428
- ) -> PixelBounds:
429
- """Get the bounds of the raster in the specified projection.
430
-
431
- Args:
432
- layer_name: the layer name or alias.
433
- item_name: the item to check.
434
- bands: the list of bands identifying which specific raster to read. These
435
- bands must match the bands of a stored raster.
436
- projection: the projection to get the raster's bounds in.
437
-
438
- Returns:
439
- the bounds of the raster in the projection.
440
- """
441
- item = self.get_item_by_name(item_name)
442
- geom = item.geometry.to_projection(projection)
443
- return (
444
- int(geom.shp.bounds[0]),
445
- int(geom.shp.bounds[1]),
446
- int(geom.shp.bounds[2]),
447
- int(geom.shp.bounds[3]),
448
- )
449
-
450
- def read_raster(
451
- self,
452
- layer_name: str,
453
- item_name: str,
454
- bands: list[str],
455
- projection: Projection,
456
- bounds: PixelBounds,
457
- resampling: Resampling = Resampling.bilinear,
458
- ) -> npt.NDArray[Any]:
459
- """Read raster data from the store.
460
-
461
- Args:
462
- layer_name: the layer name or alias.
463
- item_name: the item to read.
464
- bands: the list of bands identifying which specific raster to read. These
465
- bands must match the bands of a stored raster.
466
- projection: the projection to read in.
467
- bounds: the bounds to read.
468
- resampling: the resampling method to use in case reprojection is needed.
469
-
470
- Returns:
471
- the raster data
472
- """
473
- # Landsat assets have single band per asset.
474
- assert len(bands) == 1
475
- band = bands[0]
476
-
477
- # Get the item since it has the blob path.
478
- item = self.get_item_by_name(item_name)
479
-
480
- # Create pre-signed URL for rasterio access.
481
- # We do this because accessing via URL is much faster since rasterio can use
482
- # the URL directly.
483
- blob_key = item.blob_path + f"{band}.TIF"
484
- url = self.client.generate_presigned_url(
485
- "get_object",
486
- Params={
487
- "Bucket": self.bucket_name,
488
- "Key": blob_key,
489
- "RequestPayer": "requester",
490
- },
491
- )
492
-
493
- # Construct the transform to use for the warped dataset.
494
- wanted_transform = affine.Affine(
495
- projection.x_resolution,
496
- 0,
497
- bounds[0] * projection.x_resolution,
498
- 0,
499
- projection.y_resolution,
500
- bounds[1] * projection.y_resolution,
501
- )
502
-
503
- with rasterio.open(url) as src:
504
- with rasterio.vrt.WarpedVRT(
505
- src,
506
- crs=projection.crs,
507
- transform=wanted_transform,
508
- width=bounds[2] - bounds[0],
509
- height=bounds[3] - bounds[1],
510
- resampling=resampling,
511
- ) as vrt:
512
- return vrt.read()
513
-
514
- def materialize(
515
- self,
516
- window: Window,
517
- item_groups: list[list[LandsatOliTirsItem]],
518
- layer_name: str,
519
- layer_cfg: LayerConfig,
520
- ) -> None:
521
- """Materialize data for the window.
522
-
523
- Args:
524
- window: the window to materialize
525
- item_groups: the items from get_items
526
- layer_name: the name of this layer
527
- layer_cfg: the config of this layer
528
- """
529
- RasterMaterializer().materialize(
530
- TileStoreWithLayer(self, layer_name),
531
- window,
532
- layer_name,
533
- layer_cfg,
534
- item_groups,
535
- )
@@ -318,9 +318,8 @@ class Naip(DataSource):
318
318
  groups.append(cur_groups)
319
319
  return groups
320
320
 
321
- def deserialize_item(self, serialized_item: Any) -> NaipItem:
321
+ def deserialize_item(self, serialized_item: dict) -> NaipItem:
322
322
  """Deserializes an item from JSON-decoded data."""
323
- assert isinstance(serialized_item, dict)
324
323
  return NaipItem.deserialize(serialized_item)
325
324
 
326
325
  def ingest(
@@ -639,9 +638,8 @@ class Sentinel2(
639
638
  return item
640
639
  raise ValueError(f"item {name} not found")
641
640
 
642
- def deserialize_item(self, serialized_item: Any) -> Sentinel2Item:
641
+ def deserialize_item(self, serialized_item: dict) -> Sentinel2Item:
643
642
  """Deserializes an item from JSON-decoded data."""
644
- assert isinstance(serialized_item, dict)
645
643
  return Sentinel2Item.deserialize(serialized_item)
646
644
 
647
645
  def retrieve_item(
@@ -2,7 +2,6 @@
2
2
 
3
3
  import os
4
4
  import tempfile
5
- from typing import Any
6
5
 
7
6
  import boto3
8
7
  from upath import UPath
@@ -78,9 +77,8 @@ class Sentinel1(DataSource, TileStore):
78
77
  """Gets an item by name."""
79
78
  return self.sentinel1.get_item_by_name(name)
80
79
 
81
- def deserialize_item(self, serialized_item: Any) -> CopernicusItem:
80
+ def deserialize_item(self, serialized_item: dict) -> CopernicusItem:
82
81
  """Deserializes an item from JSON-decoded data."""
83
- assert isinstance(serialized_item, dict)
84
82
  return CopernicusItem.deserialize(serialized_item)
85
83
 
86
84
  def ingest(