rslearn 0.0.22__tar.gz → 0.0.24__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.22/rslearn.egg-info → rslearn-0.0.24}/PKG-INFO +1 -1
- {rslearn-0.0.22 → rslearn-0.0.24}/pyproject.toml +1 -1
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/planetary_computer.py +149 -1
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/stac.py +24 -3
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/main.py +4 -1
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/simple_time_series.py +1 -1
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/lightning_module.py +21 -8
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/multi_task.py +8 -5
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/per_pixel_regression.py +1 -1
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/segmentation.py +163 -22
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/raster_format.py +17 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/stac.py +4 -0
- {rslearn-0.0.22 → rslearn-0.0.24/rslearn.egg-info}/PKG-INFO +1 -1
- {rslearn-0.0.22 → rslearn-0.0.24}/LICENSE +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/NOTICE +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/README.md +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/const.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/aws_sentinel2_element84.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/copernicus.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/earthdaily.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/eurocrops.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/gcp_public_data.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/soilgrids.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/utils.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/storage/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/storage/file.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/storage/storage.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/lightning_cli.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/attention_pooling.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/clay/clay.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/component.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/concatenate_features.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/olmoearth_pretrain/model.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/py.typed +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/template_params.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/all_patches_dataset.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/dataset.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/model_context.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/prediction_writer.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/resize.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/SOURCES.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.22 → rslearn-0.0.24}/setup.cfg +0 -0
--- rslearn-0.0.22/rslearn/data_sources/planetary_computer.py
+++ rslearn-0.0.24/rslearn/data_sources/planetary_computer.py
@@ -3,7 +3,7 @@
 import os
 import tempfile
 import xml.etree.ElementTree as ET
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import Any

 import affine
@@ -12,6 +12,7 @@ import planetary_computer
 import rasterio
 import requests
 from rasterio.enums import Resampling
+from typing_extensions import override
 from upath import UPath

 from rslearn.config import LayerConfig
@@ -24,11 +25,104 @@ from rslearn.tile_stores import TileStore, TileStoreWithLayer
 from rslearn.utils.fsspec import join_upath
 from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
 from rslearn.utils.raster_format import get_raster_projection_and_bounds
+from rslearn.utils.stac import StacClient, StacItem

 from .copernicus import get_harmonize_callback

 logger = get_logger(__name__)

+# Max limit accepted by Planetary Computer API.
+PLANETARY_COMPUTER_LIMIT = 1000
+
+
+class PlanetaryComputerStacClient(StacClient):
+    """A StacClient subclass that handles Planetary Computer's pagination limits.
+
+    Planetary Computer STAC API does not support standard pagination and has a max
+    limit of 1000. If the initial query returns 1000 items, this client paginates
+    by sorting by ID and using gt (greater than) queries to fetch subsequent pages.
+    """
+
+    @override
+    def search(
+        self,
+        collections: list[str] | None = None,
+        bbox: tuple[float, float, float, float] | None = None,
+        intersects: dict[str, Any] | None = None,
+        date_time: datetime | tuple[datetime, datetime] | None = None,
+        ids: list[str] | None = None,
+        limit: int | None = None,
+        query: dict[str, Any] | None = None,
+        sortby: list[dict[str, str]] | None = None,
+    ) -> list[StacItem]:
+        # We will use sortby for pagination, so the caller must not set it.
+        if sortby is not None:
+            raise ValueError("sortby must not be set for PlanetaryComputerStacClient")
+
+        # First, try a simple query with the PC limit to detect if pagination is needed.
+        # We always use PLANETARY_COMPUTER_LIMIT for the request because PC doesn't
+        # support standard pagination, and we need to detect when we hit the limit
+        # to switch to ID-based pagination.
+        # We could just start sorting by ID here and do pagination, but we treat it as
+        # a special case to avoid sorting since that seems to speed up the query.
+        stac_items = super().search(
+            collections=collections,
+            bbox=bbox,
+            intersects=intersects,
+            date_time=date_time,
+            ids=ids,
+            limit=PLANETARY_COMPUTER_LIMIT,
+            query=query,
+        )
+
+        # If we got fewer than the PC limit, we have all the results.
+        if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
+            return stac_items
+
+        # We hit the limit, so we need to paginate by ID.
+        # Re-fetch with sorting by ID to ensure consistent ordering for pagination.
+        logger.debug(
+            "Initial request returned %d items (at limit), switching to ID pagination",
+            len(stac_items),
+        )
+
+        all_items: list[StacItem] = []
+        last_id: str | None = None
+
+        while True:
+            # Build query with id > last_id if we're paginating.
+            combined_query: dict[str, Any] = dict(query) if query else {}
+            if last_id is not None:
+                combined_query["id"] = {"gt": last_id}
+
+            stac_items = super().search(
+                collections=collections,
+                bbox=bbox,
+                intersects=intersects,
+                date_time=date_time,
+                ids=ids,
+                limit=PLANETARY_COMPUTER_LIMIT,
+                query=combined_query if combined_query else None,
+                sortby=[{"field": "id", "direction": "asc"}],
+            )
+
+            all_items.extend(stac_items)
+
+            # If we got fewer than the limit, we've fetched everything.
+            if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
+                break
+
+            # Otherwise, paginate using the last item's ID.
+            last_id = stac_items[-1].id
+            logger.debug(
+                "Got %d items, paginating with id > %s",
+                len(stac_items),
+                last_id,
+            )
+
+        logger.debug("Total items fetched: %d", len(all_items))
+        return all_items
+

 class PlanetaryComputer(StacDataSource, TileStore):
     """Modality-agnostic data source for data on Microsoft Planetary Computer.
@@ -100,6 +194,10 @@ class PlanetaryComputer(StacDataSource, TileStore):
             required_assets=required_assets,
             cache_dir=cache_upath,
         )
+
+        # Replace the client with PlanetaryComputerStacClient to handle PC's pagination limits.
+        self.client = PlanetaryComputerStacClient(self.STAC_ENDPOINT)
+
         self.asset_bands = asset_bands
         self.timeout = timeout
         self.skip_items_missing_assets = skip_items_missing_assets
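
The two hunks above add the paginating client and wire it into `PlanetaryComputer.__init__`. Below is a minimal sketch of standalone use; the endpoint URL is assumed to be the public Planetary Computer STAC root (in rslearn it comes from `PlanetaryComputer.STAC_ENDPOINT`), and the collection/bbox values are purely illustrative.

```python
from datetime import datetime, timezone

from rslearn.data_sources.planetary_computer import PlanetaryComputerStacClient

# Hypothetical standalone use; rslearn constructs this client internally as shown above.
client = PlanetaryComputerStacClient("https://planetarycomputer.microsoft.com/api/stac/v1")
items = client.search(
    collections=["sentinel-2-l2a"],
    bbox=(-122.5, 47.4, -122.2, 47.7),
    date_time=(
        datetime(2023, 1, 1, tzinfo=timezone.utc),
        datetime(2023, 6, 30, tzinfo=timezone.utc),
    ),
)
# If the first page comes back with exactly 1000 items, search() re-queries sorted by
# item ID and keeps requesting pages with id > last_id until a page has fewer than 1000.
```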
@@ -567,3 +665,53 @@ class Naip(PlanetaryComputer):
             context=context,
             **kwargs,
         )
+
+
+class CopDemGlo30(PlanetaryComputer):
+    """A data source for Copernicus DEM GLO-30 (30m) on Microsoft Planetary Computer.
+
+    See https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-30.
+    """
+
+    COLLECTION_NAME = "cop-dem-glo-30"
+    DATA_ASSET = "data"
+
+    def __init__(
+        self,
+        band_name: str = "DEM",
+        context: DataSourceContext = DataSourceContext(),
+        **kwargs: Any,
+    ):
+        """Initialize a new CopDemGlo30 instance.
+
+        Args:
+            band_name: band name to use if the layer config is missing from the
+                context.
+            context: the data source context.
+            kwargs: additional arguments to pass to PlanetaryComputer.
+        """
+        if context.layer_config is not None:
+            if len(context.layer_config.band_sets) != 1:
+                raise ValueError("expected a single band set")
+            if len(context.layer_config.band_sets[0].bands) != 1:
+                raise ValueError("expected band set to have a single band")
+            band_name = context.layer_config.band_sets[0].bands[0]
+
+        super().__init__(
+            collection_name=self.COLLECTION_NAME,
+            asset_bands={self.DATA_ASSET: [band_name]},
+            # Skip since all items should have the same asset(s).
+            skip_items_missing_assets=True,
+            context=context,
+            **kwargs,
+        )
+
+    def _stac_item_to_item(self, stac_item: Any) -> SourceItem:
+        # Copernicus DEM is static; ignore item timestamps so it matches any window.
+        item = super()._stac_item_to_item(stac_item)
+        item.geometry = STGeometry(item.geometry.projection, item.geometry.shp, None)
+        return item
+
+    def _get_search_time_range(self, geometry: STGeometry) -> None:
+        # Copernicus DEM is static; do not filter STAC searches by time.
+        return None
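
A brief hedged sketch of using the new DEM source directly; in a dataset configuration the class would instead be referenced by its import path, and the band name would come from the layer's single band set. Direct construction with defaults is hypothetical here.

```python
from rslearn.data_sources.planetary_computer import CopDemGlo30

# Hypothetical direct use; extra keyword arguments pass through to PlanetaryComputer.
dem = CopDemGlo30(band_name="DEM")
# Because _get_search_time_range() returns None and _stac_item_to_item() clears the
# item time range, the static DEM tiles can match windows with any time range.
```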
--- rslearn-0.0.22/rslearn/data_sources/stac.py
+++ rslearn-0.0.24/rslearn/data_sources/stac.py
@@ -1,6 +1,7 @@
 """A partial data source implementation providing get_items using a STAC API."""

 import json
+from datetime import datetime
 from typing import Any

 import shapely
@@ -11,6 +12,7 @@ from rslearn.const import WGS84_PROJECTION
 from rslearn.data_sources.data_source import Item, ItemLookupDataSource
 from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.log_utils import get_logger
+from rslearn.utils.fsspec import open_atomic
 from rslearn.utils.geometry import STGeometry
 from rslearn.utils.stac import StacClient, StacItem

@@ -132,6 +134,24 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):

         return SourceItem(stac_item.id, geom, asset_urls, properties)

+    def _get_search_time_range(
+        self, geometry: STGeometry
+    ) -> datetime | tuple[datetime, datetime] | None:
+        """Get time range to include in STAC API search.
+
+        By default, we filter STAC searches to the window's time range. Subclasses can
+        override this to disable time filtering for "static" datasets.
+
+        Args:
+            geometry: the geometry we are searching for.
+
+        Returns:
+            the time range (or timestamp) to pass to the STAC search, or None to avoid
+            temporal filtering in the search request.
+        """
+        # Note: StacClient.search accepts either a datetime or a (start, end) tuple.
+        return geometry.time_range
+
     def get_item_by_name(self, name: str) -> SourceItem:
         """Gets an item by name.

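
The new hook defaults to the window's time range; a subclass serving a static collection can disable temporal filtering by overriding it, mirroring what `CopDemGlo30` does above. A minimal sketch (the hypothetical subclass omits whatever constructor arguments a real data source would need):

```python
from rslearn.data_sources.stac import StacDataSource
from rslearn.utils.geometry import STGeometry


class StaticCollectionDataSource(StacDataSource):
    """Hypothetical subclass for a collection whose items should match any window."""

    def _get_search_time_range(self, geometry: STGeometry) -> None:
        # Search by geometry only; never send a datetime filter to the STAC API.
        return None
```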
@@ -168,7 +188,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):

         # Finally we cache it if cache_dir is set.
         if cache_fname is not None:
-            with cache_fname
+            with open_atomic(cache_fname, "w") as f:
                 json.dump(item.serialize(), f)

         return item
@@ -191,10 +211,11 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
         # for each requested geometry.
         wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
         logger.debug("performing STAC search for geometry %s", wgs84_geometry)
+        search_time_range = self._get_search_time_range(wgs84_geometry)
         stac_items = self.client.search(
             collections=[self.collection_name],
             intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
-            date_time=
+            date_time=search_time_range,
             query=self.query,
             limit=self.limit,
         )
@@ -239,7 +260,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
             cache_fname = self.cache_dir / f"{item.name}.json"
             if cache_fname.exists():
                 continue
-            with cache_fname
+            with open_atomic(cache_fname, "w") as f:
                 json.dump(item.serialize(), f)

         cur_groups = match_candidate_items_to_window(
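
Both cache-writing hunks above switch the item JSON write to `open_atomic` from `rslearn.utils.fsspec`. A small sketch of the same pattern, assuming `open_atomic` behaves as an atomic-write context manager (staging the write and moving it into place on close, so concurrent readers never see a half-written file); the cache path is hypothetical.

```python
import json

from upath import UPath

from rslearn.utils.fsspec import open_atomic

# Hypothetical cache path; in StacDataSource it is self.cache_dir / f"{item.name}.json".
cache_fname = UPath("/tmp/rslearn_stac_cache/item-123.json")
cache_fname.parent.mkdir(parents=True, exist_ok=True)
if not cache_fname.exists():
    with open_atomic(cache_fname, "w") as f:
        json.dump({"name": "item-123"}, f)
```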
--- rslearn-0.0.22/rslearn/main.py
+++ rslearn-0.0.24/rslearn/main.py
@@ -2,6 +2,7 @@

 import argparse
 import multiprocessing
+import os
 import random
 import sys
 import time
@@ -45,6 +46,7 @@ handler_registry = {}
 ItemType = TypeVar("ItemType", bound="Item")

 MULTIPROCESSING_CONTEXT = "forkserver"
+MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"


 def register_handler(category: Any, command: str) -> Callable:
@@ -837,7 +839,8 @@ def model_predict() -> None:
 def main() -> None:
     """CLI entrypoint."""
     try:
-
+        mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
+        multiprocessing.set_start_method(mp_context)
     except RuntimeError as e:
         logger.error(
             f"Multiprocessing context already set to {multiprocessing.get_context()}: "
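
`main()` now reads the start method from the `RSLEARN_MULTIPROCESSING_CONTEXT` environment variable, falling back to the existing `"forkserver"` default. A minimal sketch of overriding it from Python before the CLI entrypoint runs (setting the variable in the shell environment works the same way):

```python
import os

# Read by rslearn.main.main() before it calls multiprocessing.set_start_method();
# when unset, the existing "forkserver" default is used. Valid values are whatever
# multiprocessing accepts on your platform, e.g. "spawn" or "fork".
os.environ["RSLEARN_MULTIPROCESSING_CONTEXT"] = "spawn"
```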
--- rslearn-0.0.22/rslearn/models/simple_time_series.py
+++ rslearn-0.0.24/rslearn/models/simple_time_series.py
@@ -180,7 +180,7 @@ class SimpleTimeSeries(FeatureExtractor):
         # want to pass 2 timesteps to the model.
         # TODO is probably to make this behaviour clearer but lets leave it like
         # this for now to not break things.
-        num_timesteps = images.shape[1]
+        num_timesteps = image_channels // images.shape[1]
         batched_timesteps = images.shape[2] // num_timesteps
         images = rearrange(
             images,
--- rslearn-0.0.22/rslearn/train/lightning_module.py
+++ rslearn-0.0.24/rslearn/train/lightning_module.py
@@ -210,11 +210,30 @@ class RslearnLightningModule(L.LightningModule):
             # Fail silently for single-dataset case, which is okay
             pass

+    def on_validation_epoch_end(self) -> None:
+        """Compute and log validation metrics at epoch end.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.val_metrics.compute()
+        self.log_dict(metrics)
+        self.val_metrics.reset()
+
     def on_test_epoch_end(self) -> None:
-        """
+        """Compute and log test metrics at epoch end, optionally save to file.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.test_metrics.compute()
+        self.log_dict(metrics)
+        self.test_metrics.reset()
+
         if self.metrics_file:
             with open(self.metrics_file, "w") as f:
-                metrics = self.test_metrics.compute()
                 metrics_dict = {k: v.item() for k, v in metrics.items()}
                 json.dump(metrics_dict, f, indent=4)
                 logger.info(f"Saved metrics to {self.metrics_file}")
@@ -300,9 +319,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.val_metrics.update(outputs, targets)
-        self.log_dict(
-            self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )

     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the test loss and additional metrics.
@@ -340,9 +356,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.test_metrics.update(outputs, targets)
-        self.log_dict(
-            self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )

         if self.visualize_dir:
             for inp, target, output, metadata in zip(
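
The new docstrings above explain the motivation: `log_dict` expects each logged value to be a scalar, while `MetricCollection.compute()` also flattens metrics whose `compute()` returns a dict. A minimal sketch of such a dict-returning metric (the class, key names, and update logic are hypothetical):

```python
import torch
from torchmetrics import Metric, MetricCollection


class PrecisionRecall(Metric):
    """Hypothetical metric whose compute() returns a dict rather than a scalar."""

    def __init__(self) -> None:
        super().__init__()
        self.add_state("tp", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("fp", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("fn", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.tp += ((preds == 1) & (target == 1)).sum()
        self.fp += ((preds == 1) & (target == 0)).sum()
        self.fn += ((preds == 0) & (target == 1)).sum()

    def compute(self) -> dict[str, torch.Tensor]:
        precision = self.tp / (self.tp + self.fp).clamp(min=1)
        recall = self.tp / (self.tp + self.fn).clamp(min=1)
        return {"precision": precision, "recall": recall}


metrics = MetricCollection({"cls": PrecisionRecall()})
metrics.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
flat = metrics.compute()
# flat is a flat dict of scalar tensors (exact key naming depends on the torchmetrics
# version), which self.log_dict(flat) can handle; passing the MetricCollection object
# itself to log_dict would require every metric to return a single scalar.
```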
--- rslearn-0.0.22/rslearn/train/tasks/multi_task.py
+++ rslearn-0.0.24/rslearn/train/tasks/multi_task.py
@@ -118,13 +118,16 @@ class MultiTask(Task):

     def get_metrics(self) -> MetricCollection:
         """Get metrics for this task."""
-        metrics
+        # Flatten metrics into a single dict with task_name/ prefix to avoid nested
+        # MetricCollections. Nested collections cause issues because MetricCollection
+        # has postfix=None which breaks MetricCollection.compute().
+        all_metrics = {}
         for task_name, task in self.tasks.items():
-            cur_metrics = {}
             for metric_name, metric in task.get_metrics().items():
-
-
-
+                all_metrics[f"{task_name}/{metric_name}"] = MetricWrapper(
+                    task_name, metric
+                )
+        return MetricCollection(all_metrics)


 class MetricWrapper(Metric):
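
With this change, `MultiTask.get_metrics()` returns one flat `MetricCollection` keyed as `task/metric` instead of nesting a collection per task. A rough illustration with hypothetical task and metric names; `MetricWrapper(task_name, metric)` matches the call in the hunk above, and each wrapper presumably routes its task's outputs and targets to the underlying metric.

```python
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy

from rslearn.train.tasks.multi_task import MetricWrapper

# Hypothetical tasks "segment" and "detect" with one accuracy metric each.
flat_metrics = MetricCollection(
    {
        "segment/accuracy": MetricWrapper("segment", MulticlassAccuracy(num_classes=2)),
        "detect/accuracy": MetricWrapper("detect", MulticlassAccuracy(num_classes=5)),
    }
)
```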
--- rslearn-0.0.22/rslearn/train/tasks/per_pixel_regression.py
+++ rslearn-0.0.24/rslearn/train/tasks/per_pixel_regression.py
@@ -100,7 +100,7 @@ class PerPixelRegressionTask(BasicTask):
            raise ValueError(
                f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
            )
-        return (raw_output / self.scale_factor).cpu().numpy()
+        return (raw_output[None, :, :] / self.scale_factor).cpu().numpy()

    def visualize(
        self,