rslearn 0.0.22__tar.gz → 0.0.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. {rslearn-0.0.22/rslearn.egg-info → rslearn-0.0.24}/PKG-INFO +1 -1
  2. {rslearn-0.0.22 → rslearn-0.0.24}/pyproject.toml +1 -1
  3. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/planetary_computer.py +149 -1
  4. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/stac.py +24 -3
  5. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/main.py +4 -1
  6. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/simple_time_series.py +1 -1
  7. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/lightning_module.py +21 -8
  8. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/multi_task.py +8 -5
  9. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/per_pixel_regression.py +1 -1
  10. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/segmentation.py +163 -22
  11. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/raster_format.py +17 -0
  12. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/stac.py +4 -0
  13. {rslearn-0.0.22 → rslearn-0.0.24/rslearn.egg-info}/PKG-INFO +1 -1
  14. {rslearn-0.0.22 → rslearn-0.0.24}/LICENSE +0 -0
  15. {rslearn-0.0.22 → rslearn-0.0.24}/NOTICE +0 -0
  16. {rslearn-0.0.22 → rslearn-0.0.24}/README.md +0 -0
  17. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/__init__.py +0 -0
  18. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/arg_parser.py +0 -0
  19. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/config/__init__.py +0 -0
  20. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/config/dataset.py +0 -0
  21. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/const.py +0 -0
  22. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/__init__.py +0 -0
  23. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/aws_landsat.py +0 -0
  24. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/aws_open_data.py +0 -0
  25. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/aws_sentinel1.py +0 -0
  26. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/aws_sentinel2_element84.py +0 -0
  27. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/climate_data_store.py +0 -0
  28. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/copernicus.py +0 -0
  29. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/data_source.py +0 -0
  30. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/earthdaily.py +0 -0
  31. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/earthdata_srtm.py +0 -0
  32. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/eurocrops.py +0 -0
  33. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/gcp_public_data.py +0 -0
  34. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/google_earth_engine.py +0 -0
  35. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/local_files.py +0 -0
  36. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/openstreetmap.py +0 -0
  37. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/planet.py +0 -0
  38. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/planet_basemap.py +0 -0
  39. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/soilgrids.py +0 -0
  40. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/usda_cdl.py +0 -0
  41. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/usgs_landsat.py +0 -0
  42. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/utils.py +0 -0
  43. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/vector_source.py +0 -0
  44. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/worldcereal.py +0 -0
  45. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/worldcover.py +0 -0
  46. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/worldpop.py +0 -0
  47. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/xyz_tiles.py +0 -0
  48. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/__init__.py +0 -0
  49. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/add_windows.py +0 -0
  50. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/dataset.py +0 -0
  51. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/handler_summaries.py +0 -0
  52. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/manage.py +0 -0
  53. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/materialize.py +0 -0
  54. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/remap.py +0 -0
  55. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/storage/__init__.py +0 -0
  56. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/storage/file.py +0 -0
  57. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/storage/storage.py +0 -0
  58. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/window.py +0 -0
  59. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/lightning_cli.py +0 -0
  60. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/log_utils.py +0 -0
  61. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/__init__.py +0 -0
  62. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/anysat.py +0 -0
  63. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/attention_pooling.py +0 -0
  64. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/clay/clay.py +0 -0
  65. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/clay/configs/metadata.yaml +0 -0
  66. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/clip.py +0 -0
  67. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/component.py +0 -0
  68. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/concatenate_features.py +0 -0
  69. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/conv.py +0 -0
  70. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/croma.py +0 -0
  71. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/__init__.py +0 -0
  72. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/box_ops.py +0 -0
  73. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/detr.py +0 -0
  74. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/matcher.py +0 -0
  75. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/position_encoding.py +0 -0
  76. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/transformer.py +0 -0
  77. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/util.py +0 -0
  78. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/dinov3.py +0 -0
  79. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/faster_rcnn.py +0 -0
  80. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/feature_center_crop.py +0 -0
  81. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/fpn.py +0 -0
  82. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/galileo/__init__.py +0 -0
  83. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/galileo/galileo.py +0 -0
  84. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/galileo/single_file_galileo.py +0 -0
  85. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/module_wrapper.py +0 -0
  86. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/molmo.py +0 -0
  87. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/multitask.py +0 -0
  88. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  89. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/olmoearth_pretrain/model.py +0 -0
  90. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
  91. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon.py +0 -0
  92. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  93. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  94. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  95. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  96. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  97. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  98. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  99. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  100. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  101. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  102. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  103. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  104. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/pick_features.py +0 -0
  105. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/pooling_decoder.py +0 -0
  106. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/presto/__init__.py +0 -0
  107. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/presto/presto.py +0 -0
  108. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/presto/single_file_presto.py +0 -0
  109. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/prithvi.py +0 -0
  110. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/resize_features.py +0 -0
  111. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/sam2_enc.py +0 -0
  112. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/satlaspretrain.py +0 -0
  113. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/singletask.py +0 -0
  114. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/ssl4eo_s12.py +0 -0
  115. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/swin.py +0 -0
  116. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/task_embedding.py +0 -0
  117. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/terramind.py +0 -0
  118. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/trunk.py +0 -0
  119. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/unet.py +0 -0
  120. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/upsample.py +0 -0
  121. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/use_croma.py +0 -0
  122. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/py.typed +0 -0
  123. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/template_params.py +0 -0
  124. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/tile_stores/__init__.py +0 -0
  125. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/tile_stores/default.py +0 -0
  126. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/tile_stores/tile_store.py +0 -0
  127. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/__init__.py +0 -0
  128. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/all_patches_dataset.py +0 -0
  129. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/__init__.py +0 -0
  130. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/adapters.py +0 -0
  131. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  132. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/gradients.py +0 -0
  133. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/peft.py +0 -0
  134. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/data_module.py +0 -0
  135. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/dataset.py +0 -0
  136. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/model_context.py +0 -0
  137. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/optimizer.py +0 -0
  138. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/prediction_writer.py +0 -0
  139. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/scheduler.py +0 -0
  140. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/__init__.py +0 -0
  141. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/classification.py +0 -0
  142. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/detection.py +0 -0
  143. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/embedding.py +0 -0
  144. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/regression.py +0 -0
  145. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/task.py +0 -0
  146. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/__init__.py +0 -0
  147. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/concatenate.py +0 -0
  148. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/crop.py +0 -0
  149. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/flip.py +0 -0
  150. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/mask.py +0 -0
  151. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/normalize.py +0 -0
  152. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/pad.py +0 -0
  153. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/resize.py +0 -0
  154. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/select_bands.py +0 -0
  155. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/sentinel1.py +0 -0
  156. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/transform.py +0 -0
  157. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/__init__.py +0 -0
  158. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/array.py +0 -0
  159. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/feature.py +0 -0
  160. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/fsspec.py +0 -0
  161. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/geometry.py +0 -0
  162. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/get_utm_ups_crs.py +0 -0
  163. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/grid_index.py +0 -0
  164. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/jsonargparse.py +0 -0
  165. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/mp.py +0 -0
  166. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/rtree_index.py +0 -0
  167. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/spatial_index.py +0 -0
  168. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/sqlite_index.py +0 -0
  169. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/time.py +0 -0
  170. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/vector_format.py +0 -0
  171. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/SOURCES.txt +0 -0
  172. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/dependency_links.txt +0 -0
  173. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/entry_points.txt +0 -0
  174. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/requires.txt +0 -0
  175. {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/top_level.txt +0 -0
  176. {rslearn-0.0.22 → rslearn-0.0.24}/setup.cfg +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.22
+Version: 0.0.24
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License

pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "rslearn"
-version = "0.0.22"
+version = "0.0.24"
 description = "A library for developing remote sensing datasets and models"
 authors = [
     { name = "OlmoEarth Team" },

rslearn/data_sources/planetary_computer.py
@@ -3,7 +3,7 @@
 import os
 import tempfile
 import xml.etree.ElementTree as ET
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import Any

 import affine
@@ -12,6 +12,7 @@ import planetary_computer
 import rasterio
 import requests
 from rasterio.enums import Resampling
+from typing_extensions import override
 from upath import UPath

 from rslearn.config import LayerConfig
@@ -24,11 +25,104 @@ from rslearn.tile_stores import TileStore, TileStoreWithLayer
 from rslearn.utils.fsspec import join_upath
 from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
 from rslearn.utils.raster_format import get_raster_projection_and_bounds
+from rslearn.utils.stac import StacClient, StacItem

 from .copernicus import get_harmonize_callback

 logger = get_logger(__name__)

+# Max limit accepted by Planetary Computer API.
+PLANETARY_COMPUTER_LIMIT = 1000
+
+
+class PlanetaryComputerStacClient(StacClient):
+    """A StacClient subclass that handles Planetary Computer's pagination limits.
+
+    The Planetary Computer STAC API does not support standard pagination and has a max
+    limit of 1000. If the initial query returns 1000 items, this client paginates
+    by sorting by ID and using gt (greater than) queries to fetch subsequent pages.
+    """
+
+    @override
+    def search(
+        self,
+        collections: list[str] | None = None,
+        bbox: tuple[float, float, float, float] | None = None,
+        intersects: dict[str, Any] | None = None,
+        date_time: datetime | tuple[datetime, datetime] | None = None,
+        ids: list[str] | None = None,
+        limit: int | None = None,
+        query: dict[str, Any] | None = None,
+        sortby: list[dict[str, str]] | None = None,
+    ) -> list[StacItem]:
+        # We will use sortby for pagination, so the caller must not set it.
+        if sortby is not None:
+            raise ValueError("sortby must not be set for PlanetaryComputerStacClient")
+
+        # First, try a simple query with the PC limit to detect if pagination is needed.
+        # We always use PLANETARY_COMPUTER_LIMIT for the request because PC doesn't
+        # support standard pagination, and we need to detect when we hit the limit
+        # to switch to ID-based pagination.
+        # We could just start sorting by ID here and do pagination, but we treat it as
+        # a special case to avoid sorting since that seems to speed up the query.
+        stac_items = super().search(
+            collections=collections,
+            bbox=bbox,
+            intersects=intersects,
+            date_time=date_time,
+            ids=ids,
+            limit=PLANETARY_COMPUTER_LIMIT,
+            query=query,
+        )
+
+        # If we got fewer than the PC limit, we have all the results.
+        if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
+            return stac_items
+
+        # We hit the limit, so we need to paginate by ID.
+        # Re-fetch with sorting by ID to ensure consistent ordering for pagination.
+        logger.debug(
+            "Initial request returned %d items (at limit), switching to ID pagination",
+            len(stac_items),
+        )
+
+        all_items: list[StacItem] = []
+        last_id: str | None = None
+
+        while True:
+            # Build query with id > last_id if we're paginating.
+            combined_query: dict[str, Any] = dict(query) if query else {}
+            if last_id is not None:
+                combined_query["id"] = {"gt": last_id}
+
+            stac_items = super().search(
+                collections=collections,
+                bbox=bbox,
+                intersects=intersects,
+                date_time=date_time,
+                ids=ids,
+                limit=PLANETARY_COMPUTER_LIMIT,
+                query=combined_query if combined_query else None,
+                sortby=[{"field": "id", "direction": "asc"}],
+            )
+
+            all_items.extend(stac_items)
+
+            # If we got fewer than the limit, we've fetched everything.
+            if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
+                break
+
+            # Otherwise, paginate using the last item's ID.
+            last_id = stac_items[-1].id
+            logger.debug(
+                "Got %d items, paginating with id > %s",
+                len(stac_items),
+                last_id,
+            )
+
+        logger.debug("Total items fetched: %d", len(all_items))
+        return all_items
+

 class PlanetaryComputer(StacDataSource, TileStore):
     """Modality-agnostic data source for data on Microsoft Planetary Computer.
@@ -100,6 +194,10 @@ class PlanetaryComputer(StacDataSource, TileStore):
             required_assets=required_assets,
             cache_dir=cache_upath,
         )
+
+        # Replace the client with PlanetaryComputerStacClient to handle PC's pagination limits.
+        self.client = PlanetaryComputerStacClient(self.STAC_ENDPOINT)
+
         self.asset_bands = asset_bands
         self.timeout = timeout
         self.skip_items_missing_assets = skip_items_missing_assets
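
Note: an illustrative sketch (not part of the diff) of the search parameters the new client builds once it falls back to ID-based pagination; the collection name and item ID below are hypothetical, while the query/sortby shapes mirror the code above.

second_page_kwargs = {
    "collections": ["sentinel-2-l2a"],  # hypothetical collection
    "limit": 1000,  # PLANETARY_COMPUTER_LIMIT
    # Only items whose ID sorts after the last item of the previous page.
    "query": {"id": {"gt": "LAST_ITEM_ID_FROM_PREVIOUS_PAGE"}},
    "sortby": [{"field": "id", "direction": "asc"}],
}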
@@ -567,3 +665,53 @@ class Naip(PlanetaryComputer):
             context=context,
             **kwargs,
         )
+
+
+class CopDemGlo30(PlanetaryComputer):
+    """A data source for Copernicus DEM GLO-30 (30m) on Microsoft Planetary Computer.
+
+    See https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-30.
+    """
+
+    COLLECTION_NAME = "cop-dem-glo-30"
+    DATA_ASSET = "data"
+
+    def __init__(
+        self,
+        band_name: str = "DEM",
+        context: DataSourceContext = DataSourceContext(),
+        **kwargs: Any,
+    ):
+        """Initialize a new CopDemGlo30 instance.
+
+        Args:
+            band_name: band name to use if the layer config is missing from the
+                context.
+            context: the data source context.
+            kwargs: additional arguments to pass to PlanetaryComputer.
+        """
+        if context.layer_config is not None:
+            if len(context.layer_config.band_sets) != 1:
+                raise ValueError("expected a single band set")
+            if len(context.layer_config.band_sets[0].bands) != 1:
+                raise ValueError("expected band set to have a single band")
+            band_name = context.layer_config.band_sets[0].bands[0]
+
+        super().__init__(
+            collection_name=self.COLLECTION_NAME,
+            asset_bands={self.DATA_ASSET: [band_name]},
+            # Skip since all items should have the same asset(s).
+            skip_items_missing_assets=True,
+            context=context,
+            **kwargs,
+        )
+
+    def _stac_item_to_item(self, stac_item: Any) -> SourceItem:
+        # Copernicus DEM is static; ignore item timestamps so it matches any window.
+        item = super()._stac_item_to_item(stac_item)
+        item.geometry = STGeometry(item.geometry.projection, item.geometry.shp, None)
+        return item
+
+    def _get_search_time_range(self, geometry: STGeometry) -> None:
+        # Copernicus DEM is static; do not filter STAC searches by time.
+        return None
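
Note: a minimal usage sketch (not part of the diff), assuming the PlanetaryComputer defaults are sufficient; with no layer config in the context, the band name falls back to "DEM", and the two overrides above let the static DEM items match windows with any time range.

from rslearn.data_sources.planetary_computer import CopDemGlo30

# Hypothetical instantiation relying on defaults; extra kwargs are forwarded to PlanetaryComputer.
dem_source = CopDemGlo30(band_name="DEM")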

rslearn/data_sources/stac.py
@@ -1,6 +1,7 @@
 """A partial data source implementation providing get_items using a STAC API."""

 import json
+from datetime import datetime
 from typing import Any

 import shapely
@@ -11,6 +12,7 @@ from rslearn.const import WGS84_PROJECTION
 from rslearn.data_sources.data_source import Item, ItemLookupDataSource
 from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.log_utils import get_logger
+from rslearn.utils.fsspec import open_atomic
 from rslearn.utils.geometry import STGeometry
 from rslearn.utils.stac import StacClient, StacItem

@@ -132,6 +134,24 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):

         return SourceItem(stac_item.id, geom, asset_urls, properties)

+    def _get_search_time_range(
+        self, geometry: STGeometry
+    ) -> datetime | tuple[datetime, datetime] | None:
+        """Get time range to include in STAC API search.
+
+        By default, we filter STAC searches to the window's time range. Subclasses can
+        override this to disable time filtering for "static" datasets.
+
+        Args:
+            geometry: the geometry we are searching for.
+
+        Returns:
+            the time range (or timestamp) to pass to the STAC search, or None to avoid
+            temporal filtering in the search request.
+        """
+        # Note: StacClient.search accepts either a datetime or a (start, end) tuple.
+        return geometry.time_range
+
     def get_item_by_name(self, name: str) -> SourceItem:
         """Gets an item by name.

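Note: a hypothetical subclass sketch (not part of the diff) showing the new hook overridden with a value other than None, e.g. a source whose items all come from a single fixed campaign year; the class name and dates are made up.

from datetime import datetime, timezone

class FixedCampaignSource(StacDataSource):
    # Hypothetical: always search a fixed year, regardless of the window's time range.
    def _get_search_time_range(
        self, geometry: STGeometry
    ) -> tuple[datetime, datetime]:
        return (
            datetime(2020, 1, 1, tzinfo=timezone.utc),
            datetime(2021, 1, 1, tzinfo=timezone.utc),
        )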
@@ -168,7 +188,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):

         # Finally we cache it if cache_dir is set.
         if cache_fname is not None:
-            with cache_fname.open("w") as f:
+            with open_atomic(cache_fname, "w") as f:
                 json.dump(item.serialize(), f)

         return item
@@ -191,10 +211,11 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
            # for each requested geometry.
            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
            logger.debug("performing STAC search for geometry %s", wgs84_geometry)
+           search_time_range = self._get_search_time_range(wgs84_geometry)
            stac_items = self.client.search(
                collections=[self.collection_name],
                intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
-               date_time=wgs84_geometry.time_range,
+               date_time=search_time_range,
                query=self.query,
                limit=self.limit,
            )
@@ -239,7 +260,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
            cache_fname = self.cache_dir / f"{item.name}.json"
            if cache_fname.exists():
                continue
-           with cache_fname.open("w") as f:
+           with open_atomic(cache_fname, "w") as f:
                json.dump(item.serialize(), f)

        cur_groups = match_candidate_items_to_window(

rslearn/main.py
@@ -2,6 +2,7 @@

 import argparse
 import multiprocessing
+import os
 import random
 import sys
 import time
@@ -45,6 +46,7 @@ handler_registry = {}
 ItemType = TypeVar("ItemType", bound="Item")

 MULTIPROCESSING_CONTEXT = "forkserver"
+MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"


 def register_handler(category: Any, command: str) -> Callable:
@@ -837,7 +839,8 @@ def model_predict() -> None:
 def main() -> None:
     """CLI entrypoint."""
     try:
-        multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)
+        mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
+        multiprocessing.set_start_method(mp_context)
     except RuntimeError as e:
         logger.error(
             f"Multiprocessing context already set to {multiprocessing.get_context()}: "
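
Note: an illustrative sketch (not part of the diff) of overriding the start method through the new environment variable; the same effect can be had by exporting RSLEARN_MULTIPROCESSING_CONTEXT=spawn in the shell before invoking the CLI.

import os

# Hypothetical: request "spawn" instead of the default "forkserver"; main() reads this
# variable before calling multiprocessing.set_start_method().
os.environ["RSLEARN_MULTIPROCESSING_CONTEXT"] = "spawn"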

rslearn/models/simple_time_series.py
@@ -180,7 +180,7 @@ class SimpleTimeSeries(FeatureExtractor):
        # want to pass 2 timesteps to the model.
        # TODO is probably to make this behaviour clearer but lets leave it like
        # this for now to not break things.
-       num_timesteps = images.shape[1] // image_channels
+       num_timesteps = image_channels // images.shape[1]
        batched_timesteps = images.shape[2] // num_timesteps
        images = rearrange(
            images,
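
Note: a worked example (not part of the diff), under one plausible reading of the variables: images.shape[1] as the channel count per timestep of the stacked input, and image_channels as the channel count the wrapped model expects per forward pass.

images_shape_1 = 3   # stand-in for images.shape[1], e.g. 3-band imagery
image_channels = 6   # hypothetical: the wrapped model consumes 6 channels at a time
num_timesteps = image_channels // images_shape_1  # new expression gives 2 timesteps per pass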

rslearn/train/lightning_module.py
@@ -210,11 +210,30 @@ class RslearnLightningModule(L.LightningModule):
             # Fail silently for single-dataset case, which is okay
             pass

+    def on_validation_epoch_end(self) -> None:
+        """Compute and log validation metrics at epoch end.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.val_metrics.compute()
+        self.log_dict(metrics)
+        self.val_metrics.reset()
+
     def on_test_epoch_end(self) -> None:
-        """Optionally save the test metrics to a file."""
+        """Compute and log test metrics at epoch end, optionally save to file.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.test_metrics.compute()
+        self.log_dict(metrics)
+        self.test_metrics.reset()
+
         if self.metrics_file:
             with open(self.metrics_file, "w") as f:
-                metrics = self.test_metrics.compute()
                 metrics_dict = {k: v.item() for k, v in metrics.items()}
                 json.dump(metrics_dict, f, indent=4)
             logger.info(f"Saved metrics to {self.metrics_file}")
@@ -300,9 +319,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.val_metrics.update(outputs, targets)
-        self.log_dict(
-            self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.
@@ -340,9 +356,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.test_metrics.update(outputs, targets)
-        self.log_dict(
-            self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )

        if self.visualize_dir:
            for inp, target, output, metadata in zip(
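
Note: a toy sketch (not part of the diff) of why the epoch-end hooks call compute() themselves: a metric whose compute() returns a dict (like the wrapped per-task metrics produced by MultiTask below) cannot be logged by handing the MetricCollection to log_dict, while the dict returned by MetricCollection.compute() can; the metric here is made up.

import torch
from torchmetrics import Metric, MetricCollection

class DictMetric(Metric):
    # Hypothetical metric whose compute() returns a dict rather than a scalar tensor.
    def __init__(self) -> None:
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value: torch.Tensor) -> None:
        self.total += value.sum()

    def compute(self) -> dict[str, torch.Tensor]:
        return {"sum": self.total, "double": 2 * self.total}

collection = MetricCollection({"toy": DictMetric()})
collection.update(torch.tensor([1.0, 2.0]))
print(collection.compute())  # a flat dict of scalar tensors, safe to pass to log_dict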

rslearn/train/tasks/multi_task.py
@@ -118,13 +118,16 @@ class MultiTask(Task):

     def get_metrics(self) -> MetricCollection:
         """Get metrics for this task."""
-        metrics = []
+        # Flatten metrics into a single dict with task_name/ prefix to avoid nested
+        # MetricCollections. Nested collections cause issues because MetricCollection
+        # has postfix=None which breaks MetricCollection.compute().
+        all_metrics = {}
         for task_name, task in self.tasks.items():
-            cur_metrics = {}
             for metric_name, metric in task.get_metrics().items():
-                cur_metrics[metric_name] = MetricWrapper(task_name, metric)
-            metrics.append(MetricCollection(cur_metrics, prefix=f"{task_name}/"))
-        return MetricCollection(metrics)
+                all_metrics[f"{task_name}/{metric_name}"] = MetricWrapper(
+                    task_name, metric
+                )
+        return MetricCollection(all_metrics)


 class MetricWrapper(Metric):
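
Note: an illustrative sketch (not part of the diff) of the flattened key layout get_metrics() now produces, with hypothetical task and metric names:

# Two hypothetical tasks, "segment" and "detect", each exposing an "accuracy" metric,
# end up in one flat MetricCollection:
#   "segment/accuracy" -> MetricWrapper("segment", <segment accuracy metric>)
#   "detect/accuracy"  -> MetricWrapper("detect", <detect accuracy metric>)
# Previously these lived in nested per-task MetricCollections with prefix=f"{task_name}/".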

rslearn/train/tasks/per_pixel_regression.py
@@ -100,7 +100,7 @@ class PerPixelRegressionTask(BasicTask):
            raise ValueError(
                f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
            )
-       return (raw_output / self.scale_factor).cpu().numpy()
+       return (raw_output[None, :, :] / self.scale_factor).cpu().numpy()

    def visualize(
        self,