rslearn 0.0.14__tar.gz → 0.0.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.14/rslearn.egg-info → rslearn-0.0.16}/PKG-INFO +1 -1
- {rslearn-0.0.14 → rslearn-0.0.16}/pyproject.toml +1 -1
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/config/__init__.py +2 -10
- rslearn-0.0.16/rslearn/config/dataset.py +596 -0
- rslearn-0.0.16/rslearn/data_sources/__init__.py +28 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/aws_landsat.py +13 -24
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/aws_open_data.py +21 -46
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/aws_sentinel1.py +3 -14
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/climate_data_store.py +21 -40
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/copernicus.py +30 -91
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/data_source.py +26 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/earthdaily.py +13 -38
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/earthdata_srtm.py +14 -32
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/eurocrops.py +5 -9
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/gcp_public_data.py +46 -43
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/google_earth_engine.py +31 -44
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/local_files.py +91 -100
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/openstreetmap.py +21 -51
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/planet.py +12 -30
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/planet_basemap.py +4 -25
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/planetary_computer.py +58 -141
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/usda_cdl.py +15 -26
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/usgs_landsat.py +4 -29
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/utils.py +9 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/worldcereal.py +47 -54
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/worldcover.py +16 -14
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/worldpop.py +15 -18
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/xyz_tiles.py +11 -30
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/dataset.py +6 -6
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/manage.py +28 -26
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/materialize.py +9 -45
- rslearn-0.0.16/rslearn/lightning_cli.py +436 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/main.py +3 -3
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/clay/clay.py +14 -1
- rslearn-0.0.16/rslearn/models/concatenate_features.py +93 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/croma.py +26 -3
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/satlaspretrain.py +18 -4
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/terramind.py +19 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/tile_stores/__init__.py +0 -11
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/dataset.py +4 -12
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/prediction_writer.py +16 -32
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/classification.py +2 -1
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/fsspec.py +20 -0
- rslearn-0.0.16/rslearn/utils/jsonargparse.py +112 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/raster_format.py +1 -41
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.14 → rslearn-0.0.16/rslearn.egg-info}/PKG-INFO +1 -1
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/SOURCES.txt +1 -2
- rslearn-0.0.14/rslearn/config/dataset.py +0 -602
- rslearn-0.0.14/rslearn/data_sources/__init__.py +0 -51
- rslearn-0.0.14/rslearn/data_sources/geotiff.py +0 -1
- rslearn-0.0.14/rslearn/data_sources/raster_source.py +0 -23
- rslearn-0.0.14/rslearn/lightning_cli.py +0 -67
- rslearn-0.0.14/rslearn/utils/jsonargparse.py +0 -33
- {rslearn-0.0.14 → rslearn-0.0.16}/LICENSE +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/NOTICE +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/README.md +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/arg_parser.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/const.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/handler_summaries.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/anysat.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/clay/configs/metadata.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/dinov3.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/feature_center_crop.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/galileo/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/galileo/galileo.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/galileo/single_file_galileo.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/olmoearth_pretrain/model.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/presto/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/presto/presto.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/presto/single_file_presto.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/prithvi.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/resize_features.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/trunk.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/py.typed +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/template_params.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/all_patches_dataset.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/adapters.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/gradients.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/data_module.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/lightning_module.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/scheduler.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/embedding.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/per_pixel_regression.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/regression.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/segmentation.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/mask.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/select_bands.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/sentinel1.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/geometry.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/requires.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.14 → rslearn-0.0.16}/setup.cfg +0 -0
|
@@ -3,33 +3,25 @@
|
|
|
3
3
|
from .dataset import (
|
|
4
4
|
BandSetConfig,
|
|
5
5
|
CompositingMethod,
|
|
6
|
+
DatasetConfig,
|
|
6
7
|
DataSourceConfig,
|
|
7
8
|
DType,
|
|
8
9
|
LayerConfig,
|
|
9
10
|
LayerType,
|
|
10
11
|
QueryConfig,
|
|
11
|
-
RasterFormatConfig,
|
|
12
|
-
RasterLayerConfig,
|
|
13
12
|
SpaceMode,
|
|
14
13
|
TimeMode,
|
|
15
|
-
VectorFormatConfig,
|
|
16
|
-
VectorLayerConfig,
|
|
17
|
-
load_layer_config,
|
|
18
14
|
)
|
|
19
15
|
|
|
20
16
|
__all__ = [
|
|
21
17
|
"BandSetConfig",
|
|
22
18
|
"CompositingMethod",
|
|
19
|
+
"DatasetConfig",
|
|
23
20
|
"DataSourceConfig",
|
|
24
21
|
"DType",
|
|
25
22
|
"LayerConfig",
|
|
26
23
|
"LayerType",
|
|
27
24
|
"QueryConfig",
|
|
28
|
-
"RasterFormatConfig",
|
|
29
|
-
"RasterLayerConfig",
|
|
30
25
|
"SpaceMode",
|
|
31
26
|
"TimeMode",
|
|
32
|
-
"VectorFormatConfig",
|
|
33
|
-
"VectorLayerConfig",
|
|
34
|
-
"load_layer_config",
|
|
35
27
|
]
|
|
@@ -0,0 +1,596 @@
|
|
|
1
|
+
"""Classes for storing configuration of a dataset."""
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import functools
|
|
5
|
+
import json
|
|
6
|
+
import warnings
|
|
7
|
+
from datetime import timedelta
|
|
8
|
+
from enum import StrEnum
|
|
9
|
+
from typing import TYPE_CHECKING, Annotated, Any
|
|
10
|
+
|
|
11
|
+
import jsonargparse
|
|
12
|
+
import numpy as np
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
import pytimeparse
|
|
15
|
+
from pydantic import (
|
|
16
|
+
BaseModel,
|
|
17
|
+
BeforeValidator,
|
|
18
|
+
ConfigDict,
|
|
19
|
+
Field,
|
|
20
|
+
PlainSerializer,
|
|
21
|
+
field_validator,
|
|
22
|
+
model_validator,
|
|
23
|
+
)
|
|
24
|
+
from rasterio.enums import Resampling
|
|
25
|
+
from upath import UPath
|
|
26
|
+
|
|
27
|
+
from rslearn.log_utils import get_logger
|
|
28
|
+
from rslearn.utils import PixelBounds, Projection
|
|
29
|
+
from rslearn.utils.raster_format import RasterFormat
|
|
30
|
+
from rslearn.utils.vector_format import VectorFormat
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from rslearn.data_sources.data_source import DataSource
|
|
34
|
+
|
|
35
|
+
# Name the logger after this module (the __name__ variable, not the literal string
# "__name__") so that log records are attributed to this module.
logger = get_logger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def ensure_timedelta(v: Any) -> Any:
    """Ensure the value is a timedelta.

    If the value is a string, we try to parse it with pytimeparse.

    This function is meant to be used like
    Annotated[timedelta, BeforeValidator(ensure_timedelta)].

    Args:
        v: a timedelta, or a duration string to parse with pytimeparse.

    Returns:
        the value as a timedelta.

    Raises:
        ValueError: if v is a string that pytimeparse cannot parse.
        TypeError: if v is neither a timedelta nor a string.
    """
    if isinstance(v, timedelta):
        return v
    if isinstance(v, str):
        # pytimeparse.parse yields a number of seconds, or None when the string is
        # not a recognized duration, so convert explicitly and fail loudly.
        seconds = pytimeparse.parse(v)
        if seconds is None:
            raise ValueError(f"Could not parse timedelta from string: {v!r}")
        return timedelta(seconds=seconds)
    raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def ensure_optional_timedelta(v: Any) -> Any:
    """Ensure the value is a timedelta, additionally allowing None.

    Strings are parsed with pytimeparse. This is meant to be used like
    Annotated[timedelta | None, BeforeValidator(ensure_optional_timedelta)].

    Args:
        v: None, a timedelta, or a duration string to parse with pytimeparse.

    Returns:
        None if v is None, otherwise the value as a timedelta.

    Raises:
        ValueError: if v is a string that pytimeparse cannot parse.
        TypeError: if v is not None, a timedelta, or a string.
    """
    if v is None or isinstance(v, timedelta):
        return v
    if isinstance(v, str):
        # pytimeparse.parse yields a number of seconds, or None when the string is
        # not a recognized duration. Since None is a legal value here, a malformed
        # string must be rejected rather than passed on as "not set".
        seconds = pytimeparse.parse(v)
        if seconds is None:
            raise ValueError(f"Could not parse timedelta from string: {v!r}")
        return timedelta(seconds=seconds)
    raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def serialize_optional_timedelta(v: timedelta | None) -> str | None:
|
|
64
|
+
"""Serialize an optional timedelta for compatibility with pytimeparse."""
|
|
65
|
+
if v is None:
|
|
66
|
+
return None
|
|
67
|
+
return str(v.total_seconds()) + "s"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class DType(StrEnum):
|
|
71
|
+
"""Data type of a raster."""
|
|
72
|
+
|
|
73
|
+
UINT8 = "uint8"
|
|
74
|
+
UINT16 = "uint16"
|
|
75
|
+
UINT32 = "uint32"
|
|
76
|
+
UINT64 = "uint64"
|
|
77
|
+
INT8 = "int8"
|
|
78
|
+
INT16 = "int16"
|
|
79
|
+
INT32 = "int32"
|
|
80
|
+
INT64 = "int64"
|
|
81
|
+
FLOAT32 = "float32"
|
|
82
|
+
|
|
83
|
+
def get_numpy_dtype(self) -> npt.DTypeLike:
|
|
84
|
+
"""Returns numpy dtype object corresponding to this DType."""
|
|
85
|
+
if self == DType.UINT8:
|
|
86
|
+
return np.uint8
|
|
87
|
+
elif self == DType.UINT16:
|
|
88
|
+
return np.uint16
|
|
89
|
+
elif self == DType.UINT32:
|
|
90
|
+
return np.uint32
|
|
91
|
+
elif self == DType.UINT64:
|
|
92
|
+
return np.uint64
|
|
93
|
+
elif self == DType.INT8:
|
|
94
|
+
return np.int8
|
|
95
|
+
elif self == DType.INT16:
|
|
96
|
+
return np.int16
|
|
97
|
+
elif self == DType.INT32:
|
|
98
|
+
return np.int32
|
|
99
|
+
elif self == DType.INT64:
|
|
100
|
+
return np.int64
|
|
101
|
+
elif self == DType.FLOAT32:
|
|
102
|
+
return np.float32
|
|
103
|
+
raise ValueError(f"unable to handle numpy dtype {self}")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ResamplingMethod(StrEnum):
    """Names of the rasterio Resampling options that can be selected in a config."""

    NEAREST = "nearest"
    BILINEAR = "bilinear"
    CUBIC = "cubic"
    CUBIC_SPLINE = "cubic_spline"

    def get_rasterio_resampling(self) -> Resampling:
        """Look up the rasterio Resampling member for this ResamplingMethod."""
        # RESAMPLING_METHODS is defined at module level after this class; it is
        # only read when this method is called.
        rasterio_resampling = RESAMPLING_METHODS[self]
        return rasterio_resampling
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# Maps each ResamplingMethod to its rasterio Resampling member. The value of every
# ResamplingMethod ("nearest", "bilinear", "cubic", "cubic_spline") is the name of
# the matching Resampling member, so the table is built by name lookup.
RESAMPLING_METHODS = {method: Resampling[method.value] for method in ResamplingMethod}
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class BandSetConfig(BaseModel):
    """A configuration for a band set in a raster layer.

    Each band set specifies one or more bands that should be stored together.
    It also specifies the storage format and dtype, the zoom offset, etc. for these
    bands.
    """

    dtype: DType = Field(description="Pixel value type to store the data under")
    bands: list[str] = Field(
        default_factory=list,
        description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
    )
    num_bands: int | None = Field(
        default=None,
        description="The number of bands in this band set. The bands will be named B0, B1, B2, etc.",
    )
    format: dict[str, Any] = Field(
        default_factory=lambda: {
            "class_path": "rslearn.utils.raster_format.GeotiffRasterFormat"
        },
        description="jsonargparse configuration for the RasterFormat to store the tiles in.",
    )

    # Store images at a resolution higher or lower than the window resolution. This
    # enables keeping source data at its native resolution, either to save storage
    # space (for lower resolution data) or to retain details (for higher resolution
    # data). If positive, store data at the window resolution divided by
    # 2^(zoom_offset) (higher resolution). If negative, store data at the window
    # resolution multiplied by 2^(-zoom_offset) (lower resolution).
    zoom_offset: int = Field(
        default=0,
        description="Store data at the window resolution multiplied by 2^(-zoom_offset).",
    )

    remap: dict[str, Any] | None = Field(
        default=None,
        description="Optional jsonargparse configuration for a Remapper to remap pixel values.",
    )

    # Optional list of names for the different possible values of each band. The length
    # of this list must equal the number of bands. For example, [["forest", "desert"]]
    # means that it is a single-band raster where values can be 0 (forest) or 1
    # (desert).
    class_names: list[list[str]] | None = Field(
        default=None,
        description="Optional list of names for the different possible values of each band.",
    )

    # Optional list of nodata values for this band set. This is used during
    # materialization when creating mosaics, to determine which parts of the source
    # images should be copied.
    nodata_vals: list[float] | None = Field(
        default=None, description="Optional nodata value for each band."
    )

    @model_validator(mode="after")
    def after_validator(self) -> "BandSetConfig":
        """Ensure the BandSetConfig is valid, and handle the num_bands field.

        Exactly one of bands and num_bands must be given. When num_bands is given,
        it is expanded into band names B0, B1, ... and then reset to None, so that
        after validation only the bands field carries the band list.

        Raises:
            ValueError: if both or neither of bands and num_bands are specified.
        """
        if (len(self.bands) == 0 and self.num_bands is None) or (
            len(self.bands) != 0 and self.num_bands is not None
        ):
            raise ValueError("exactly one of bands and num_bands must be specified")

        if self.num_bands is not None:
            self.bands = [f"B{band_idx}" for band_idx in range(self.num_bands)]
            self.num_bands = None

        return self

    def get_final_projection_and_bounds(
        self, projection: Projection, bounds: PixelBounds
    ) -> tuple[Projection, PixelBounds]:
        """Gets the final projection/bounds based on band set config.

        The band set config may apply a non-zero zoom offset that modifies the window's
        projection.

        Args:
            projection: the window's projection
            bounds: the window's bounds

        Returns:
            tuple of updated projection and bounds with zoom offset applied
        """
        if self.zoom_offset == 0:
            return projection, bounds
        # A positive offset divides the resolution values (finer pixels); a negative
        # offset divides by a fraction, i.e. multiplies them (coarser pixels).
        projection = Projection(
            projection.crs,
            projection.x_resolution / (2**self.zoom_offset),
            projection.y_resolution / (2**self.zoom_offset),
        )
        if self.zoom_offset > 0:
            zoom_factor = 2**self.zoom_offset
            bounds = tuple(x * zoom_factor for x in bounds)  # type: ignore
        else:
            # NOTE(review): every coordinate is floored here. If bounds holds upper
            # edges that are not multiples of the factor, the scaled bounds may stop
            # short of them -- confirm this is intended.
            bounds = tuple(
                x // (2 ** (-self.zoom_offset))
                for x in bounds  # type: ignore
            )
        return projection, bounds

    @field_validator("format", mode="before")
    @classmethod
    def convert_format_from_legacy(cls, v: Any) -> Any:
        """Support legacy format of the RasterFormat.

        The legacy format sets 'name' instead of 'class_path', and uses custom parsing
        for the init_args.

        Args:
            v: the raw, not yet type-checked value supplied for the format field.

        Returns:
            v unchanged if it is not a legacy dict; otherwise a dict with
            'class_path' and 'init_args' keys, where init_args holds every legacy
            key other than 'name'.

        Raises:
            ValueError: if the legacy 'name' is not a known raster format.
        """
        # This runs before the field's own type check, so v can be anything. Hand
        # non-dict input back untouched so that it is rejected by that type check
        # rather than by an unrelated TypeError from the membership test.
        if not isinstance(v, dict) or "name" not in v:
            # New version (or not a dict at all), nothing to convert.
            return v

        warnings.warn(
            "`format = {'name': ...}` is deprecated; "
            "use `{'class_path': '...', 'init_args': {...}}` instead.",
            DeprecationWarning,
        )

        legacy_name_to_class_path = {
            "image_tile": "rslearn.utils.raster_format.ImageTileRasterFormat",
            "geotiff": "rslearn.utils.raster_format.GeotiffRasterFormat",
            "single_image": "rslearn.utils.raster_format.SingleImageRasterFormat",
        }
        if v["name"] not in legacy_name_to_class_path:
            raise ValueError(
                f"could not parse legacy format with unknown raster format {v['name']}"
            )
        # Copy so the caller's dict is not mutated when 'name' is popped.
        init_args = dict(v)
        class_path = legacy_name_to_class_path[init_args.pop("name")]

        return dict(
            class_path=class_path,
            init_args=init_args,
        )

    def instantiate_raster_format(self) -> RasterFormat:
        """Instantiate the RasterFormat specified by this BandSetConfig."""
        # Imported here rather than at module level, as in the original code.
        from rslearn.utils.jsonargparse import init_jsonargparse

        init_jsonargparse()
        parser = jsonargparse.ArgumentParser()
        parser.add_argument("--raster_format", type=RasterFormat)
        cfg = parser.parse_object({"raster_format": self.format})
        raster_format = parser.instantiate_classes(cfg).raster_format
        return raster_format
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class SpaceMode(StrEnum):
|
|
278
|
+
"""Spatial matching mode when looking up items corresponding to a window."""
|
|
279
|
+
|
|
280
|
+
CONTAINS = "CONTAINS"
|
|
281
|
+
"""Items must contain the entire window."""
|
|
282
|
+
|
|
283
|
+
INTERSECTS = "INTERSECTS"
|
|
284
|
+
"""Items must overlap any portion of the window."""
|
|
285
|
+
|
|
286
|
+
MOSAIC = "MOSAIC"
|
|
287
|
+
"""Groups of items should be computed that cover the entire window.
|
|
288
|
+
|
|
289
|
+
During materialization, items in each group are merged to form a mosaic in the
|
|
290
|
+
dataset.
|
|
291
|
+
"""
|
|
292
|
+
|
|
293
|
+
PER_PERIOD_MOSAIC = "PER_PERIOD_MOSAIC"
|
|
294
|
+
"""Create one mosaic per sub-period of the time range.
|
|
295
|
+
|
|
296
|
+
The duration of the sub-periods is controlled by another option in QueryConfig.
|
|
297
|
+
"""
|
|
298
|
+
|
|
299
|
+
COMPOSITE = "COMPOSITE"
|
|
300
|
+
"""Creates one composite covering the entire window.
|
|
301
|
+
|
|
302
|
+
During querying all items intersecting the window are placed in one group.
|
|
303
|
+
The compositing_method in the rasterlayer config specifies how these items are reduced
|
|
304
|
+
to a single item (e.g MEAN/MEDIAN/FIRST_VALID) during materialization.
|
|
305
|
+
"""
|
|
306
|
+
|
|
307
|
+
# TODO add PER_PERIOD_COMPOSITE
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class TimeMode(StrEnum):
|
|
311
|
+
"""Temporal matching mode when looking up items corresponding to a window."""
|
|
312
|
+
|
|
313
|
+
WITHIN = "WITHIN"
|
|
314
|
+
"""Items must be within the window time range."""
|
|
315
|
+
|
|
316
|
+
NEAREST = "NEAREST"
|
|
317
|
+
"""Select items closest to the window time range, up to max_matches."""
|
|
318
|
+
|
|
319
|
+
BEFORE = "BEFORE"
|
|
320
|
+
"""Select items before the end of the window time range, up to max_matches."""
|
|
321
|
+
|
|
322
|
+
AFTER = "AFTER"
|
|
323
|
+
"""Select items after the start of the window time range, up to max_matches."""
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class QueryConfig(BaseModel):
    """Settings that control how items in a data source are queried for a window."""

    model_config = ConfigDict(frozen=True)

    space_mode: SpaceMode = Field(
        description="Specifies how items should be matched with windows spatially.",
        default=SpaceMode.MOSAIC,
    )
    time_mode: TimeMode = Field(
        description="Specifies how items should be matched with windows temporally.",
        default=TimeMode.WITHIN,
    )

    # Lower bound on the number of item groups. When fewer groups than this match,
    # no matches are returned at all, which avoids ingesting data for windows that
    # the user intends to discard for lack of sufficient data.
    min_matches: int = Field(
        description="The minimum number of item groups.",
        default=0,
    )

    max_matches: int = Field(
        description="The maximum number of item groups.",
        default=1,
    )

    # Strings are parsed into a timedelta before validation, and the value is
    # written back out as a seconds string on serialization.
    period_duration: Annotated[
        timedelta,
        BeforeValidator(ensure_timedelta),
        PlainSerializer(serialize_optional_timedelta),
    ] = Field(
        description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
        default=timedelta(days=30),
    )
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
class DataSourceConfig(BaseModel):
    """Configuration for a DataSource in a dataset layer."""

    model_config = ConfigDict(frozen=True)

    class_path: str = Field(description="Class path for the data source.")
    init_args: dict[str, Any] = Field(
        default_factory=dict,
        description="jsonargparse init args for the data source.",
    )
    query_config: QueryConfig = Field(
        default_factory=QueryConfig,
        description="QueryConfig specifying how to match items with windows.",
    )
    time_offset: Annotated[
        timedelta | None,
        BeforeValidator(ensure_optional_timedelta),
        PlainSerializer(serialize_optional_timedelta),
    ] = Field(
        default=None,
        description="Optional timedelta to add to the window's time range before matching.",
    )
    duration: Annotated[
        timedelta | None,
        BeforeValidator(ensure_optional_timedelta),
        PlainSerializer(serialize_optional_timedelta),
    ] = Field(
        default=None,
        description="Optional, if the window's time range is (t0, t1), then update to (t0, t0 + duration).",
    )
    ingest: bool = Field(
        default=True,
        description="Whether to ingest this layer (default True). If False, it will be directly materialized without ingestion.",
    )

    @model_validator(mode="before")
    @classmethod
    def convert_from_legacy(cls, d: Any) -> Any:
        """Support legacy format of the DataSourceConfig.

        The legacy format sets 'name' instead of 'class_path', and mixes the arguments
        for the DataSource in with the DataSourceConfig keys.

        Args:
            d: the raw, not yet validated input for this model.

        Returns:
            d unchanged if it is not a legacy dict; otherwise a new dict in which
            'name' has become 'class_path', keys that are fields of this model are
            kept at the top level, and all remaining keys are moved to 'init_args'.
        """
        # This runs before validation, so d is not guaranteed to be a dict. Only a
        # dict can be in the legacy layout; anything else is passed through for the
        # regular model validation to accept or reject.
        if not isinstance(d, dict) or "name" not in d:
            # New version (or not a dict at all), nothing to convert.
            return d

        warnings.warn(
            "`Data source configuration {'name': ...}` is deprecated; "
            "use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
            DeprecationWarning,
        )

        # Split the dict into the base config that is in the pydantic model, and the
        # source-specific options that should be moved to init_args dict.
        class_path = d["name"]
        base_config: dict[str, Any] = {}
        ds_init_args: dict[str, Any] = {}
        for k, v in d.items():
            if k == "name":
                continue
            if k in cls.model_fields:
                base_config[k] = v
            else:
                ds_init_args[k] = v

        # Some legacy configs erroneously specify these keys, which are now caught by
        # validation. But we still want those specific legacy configs to work.
        if (
            class_path == "rslearn.data_sources.planetary_computer.Sentinel2"
            and "max_cloud_cover" in ds_init_args
        ):
            warnings.warn(
                "Data source configuration specifies invalid 'max_cloud_cover' option.",
                DeprecationWarning,
            )
            del ds_init_args["max_cloud_cover"]

        base_config["class_path"] = class_path
        base_config["init_args"] = ds_init_args
        return base_config
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class LayerType(StrEnum):
|
|
444
|
+
"""The layer type (raster or vector)."""
|
|
445
|
+
|
|
446
|
+
RASTER = "raster"
|
|
447
|
+
VECTOR = "vector"
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
class CompositingMethod(StrEnum):
|
|
451
|
+
"""Method how to select pixels for the composite from corresponding items of a window."""
|
|
452
|
+
|
|
453
|
+
FIRST_VALID = "FIRST_VALID"
|
|
454
|
+
"""Select first valid pixel in order of corresponding items (might be sorted)"""
|
|
455
|
+
|
|
456
|
+
MEAN = "MEAN"
|
|
457
|
+
"""Select per-pixel mean value of corresponding items of a window"""
|
|
458
|
+
|
|
459
|
+
MEDIAN = "MEDIAN"
|
|
460
|
+
"""Select per-pixel median value of corresponding items of a window"""
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
class LayerConfig(BaseModel):
    """Configuration of a layer in a dataset.

    The model is frozen, and defines content-based ``__hash__`` and ``__eq__`` so
    that instances can be used as cache keys (see ``instantiate_data_source``).
    """

    model_config = ConfigDict(frozen=True)

    type: LayerType = Field(description="The LayerType (raster or vector).")
    data_source: DataSourceConfig | None = Field(
        default=None,
        description="Optional DataSourceConfig if this layer is retrievable.",
    )
    alias: str | None = Field(
        default=None, description="Alias for this layer to use in the tile store."
    )

    # Raster layer options.
    band_sets: list[BandSetConfig] = Field(
        default_factory=list,
        description="For raster layers, the bands to store in this layer.",
    )
    resampling_method: ResamplingMethod = Field(
        default=ResamplingMethod.BILINEAR,
        description="For raster layers, how to resample rasters (if neeed), default bilinear resampling.",
    )
    compositing_method: CompositingMethod = Field(
        default=CompositingMethod.FIRST_VALID,
        description="For raster layers, how to compute pixel values in the composite of each window's items.",
    )

    # Vector layer options.
    vector_format: dict[str, Any] = Field(
        default_factory=lambda: {
            "class_path": "rslearn.utils.vector_format.GeojsonVectorFormat"
        },
        description="For vector layers, the jsonargparse configuration for the VectorFormat.",
    )
    class_property_name: str | None = Field(
        default=None,
        description="Optional metadata field indicating that the GeoJSON features contain a property that corresponds to a class label, and this is the name of that property.",
    )
    class_names: list[str] | None = Field(
        default=None,
        description="The list of classes that the class_property_name property could be set to.",
    )

    @model_validator(mode="after")
    def after_validator(self) -> "LayerConfig":
        """Ensure the LayerConfig is valid.

        Raises:
            ValueError: if this is a raster layer with no band sets.
        """
        if self.type == LayerType.RASTER and not self.band_sets:
            raise ValueError(
                "band sets must be specified and non-empty for raster layers"
            )

        return self

    def __hash__(self) -> int:
        """Return a hash of this LayerConfig.

        The hash is computed from the JSON-mode dump of the model, serialized with
        sorted keys.
        """
        return hash(json.dumps(self.model_dump(mode="json"), sort_keys=True))

    def __eq__(self, other: Any) -> bool:
        """Returns whether other is the same as this LayerConfig.

        Args:
            other: the other object to compare.

        Returns:
            True if other is a LayerConfig whose model dump equals this one's,
            False otherwise (including when other is not a LayerConfig).
        """
        if not isinstance(other, LayerConfig):
            return False
        return self.model_dump() == other.model_dump()

    # NOTE(review): functools.cache on a method keys on (self, ds_path) through the
    # __hash__/__eq__ above and keeps every LayerConfig it has seen alive for the
    # lifetime of the cache (ruff B019). Left as is because callers may rely on
    # equal configs sharing one DataSource.
    @functools.cache
    def instantiate_data_source(self, ds_path: UPath | None = None) -> "DataSource":
        """Instantiate the data source specified by this config.

        Args:
            ds_path: optional dataset path to include in the DataSourceContext.

        Returns:
            the DataSource object.

        Raises:
            ValueError: if this layer does not specify a data source.
        """
        # Imported inside the function rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from rslearn.data_sources.data_source import DataSource, DataSourceContext
        from rslearn.utils.jsonargparse import (
            data_source_context_serializer,
            init_jsonargparse,
        )

        logger.debug("getting a data source for dataset at %s", ds_path)
        if self.data_source is None:
            raise ValueError("This layer does not specify a data source")

        # Inject the DataSourceContext into the args.
        context = DataSourceContext(
            ds_path=ds_path,
            layer_config=self,
        )
        ds_config: dict[str, Any] = {
            "class_path": self.data_source.class_path,
            # Deep copy so that adding the context below does not mutate the
            # init_args stored on this config.
            "init_args": copy.deepcopy(self.data_source.init_args),
        }
        ds_config["init_args"]["context"] = data_source_context_serializer(context)

        # Now we can parse with jsonargparse.
        init_jsonargparse()
        parser = jsonargparse.ArgumentParser()
        parser.add_argument("--data_source", type=DataSource)
        cfg = parser.parse_object({"data_source": ds_config})
        data_source = parser.instantiate_classes(cfg).data_source
        return data_source

    def instantiate_vector_format(self) -> VectorFormat:
        """Instantiate the vector format specified by this config.

        Returns:
            the VectorFormat object built from the vector_format configuration.

        Raises:
            ValueError: if this layer is not a vector layer.
        """
        if self.type != LayerType.VECTOR:
            raise ValueError(
                f"cannot instantiate vector format for layer with type {self.type}"
            )

        from rslearn.utils.jsonargparse import init_jsonargparse

        init_jsonargparse()
        parser = jsonargparse.ArgumentParser()
        parser.add_argument("--vector_format", type=VectorFormat)
        cfg = parser.parse_object({"vector_format": self.vector_format})
        vector_format = parser.instantiate_classes(cfg).vector_format
        return vector_format
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
class DatasetConfig(BaseModel):
    """Top-level configuration for a dataset."""

    # Required: the LayerConfig objects of the dataset, keyed by string.
    layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")

    # Optional: falls back to the DefaultTileStore class path when not given.
    tile_store: dict[str, Any] = Field(
        description="jsonargparse configuration for the TileStore.",
        default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
    )
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Data sources.
|
|
2
|
+
|
|
3
|
+
A DataSource represents a source from which raster and vector data corresponding to
|
|
4
|
+
spatiotemporal windows can be retrieved.
|
|
5
|
+
|
|
6
|
+
A DataSource consists of items that can be ingested, like Sentinel-2 scenes or
|
|
7
|
+
OpenStreetMap PBF files.
|
|
8
|
+
|
|
9
|
+
Each source supports operations to look up items that match with spatiotemporal
|
|
10
|
+
geometries, and ingest those items.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from .data_source import (
|
|
14
|
+
DataSource,
|
|
15
|
+
DataSourceContext,
|
|
16
|
+
Item,
|
|
17
|
+
ItemLookupDataSource,
|
|
18
|
+
RetrieveItemDataSource,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Public API of the package. Every name listed here must be bound in this module;
# otherwise `from rslearn.data_sources import *` raises AttributeError. The stale
# "data_source_from_config" entry was dropped because nothing in this module
# imports or defines it.
__all__ = (
    "DataSource",
    "DataSourceContext",
    "Item",
    "ItemLookupDataSource",
    "RetrieveItemDataSource",
)
|