rslearn 0.0.14__tar.gz → 0.0.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173) hide show
  1. {rslearn-0.0.14/rslearn.egg-info → rslearn-0.0.16}/PKG-INFO +1 -1
  2. {rslearn-0.0.14 → rslearn-0.0.16}/pyproject.toml +1 -1
  3. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/config/__init__.py +2 -10
  4. rslearn-0.0.16/rslearn/config/dataset.py +596 -0
  5. rslearn-0.0.16/rslearn/data_sources/__init__.py +28 -0
  6. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/aws_landsat.py +13 -24
  7. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/aws_open_data.py +21 -46
  8. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/aws_sentinel1.py +3 -14
  9. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/climate_data_store.py +21 -40
  10. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/copernicus.py +30 -91
  11. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/data_source.py +26 -0
  12. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/earthdaily.py +13 -38
  13. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/earthdata_srtm.py +14 -32
  14. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/eurocrops.py +5 -9
  15. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/gcp_public_data.py +46 -43
  16. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/google_earth_engine.py +31 -44
  17. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/local_files.py +91 -100
  18. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/openstreetmap.py +21 -51
  19. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/planet.py +12 -30
  20. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/planet_basemap.py +4 -25
  21. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/planetary_computer.py +58 -141
  22. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/usda_cdl.py +15 -26
  23. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/usgs_landsat.py +4 -29
  24. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/utils.py +9 -0
  25. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/worldcereal.py +47 -54
  26. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/worldcover.py +16 -14
  27. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/worldpop.py +15 -18
  28. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/xyz_tiles.py +11 -30
  29. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/dataset.py +6 -6
  30. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/manage.py +28 -26
  31. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/materialize.py +9 -45
  32. rslearn-0.0.16/rslearn/lightning_cli.py +436 -0
  33. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/main.py +3 -3
  34. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/clay/clay.py +14 -1
  35. rslearn-0.0.16/rslearn/models/concatenate_features.py +93 -0
  36. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/croma.py +26 -3
  37. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/satlaspretrain.py +18 -4
  38. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/terramind.py +19 -0
  39. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/tile_stores/__init__.py +0 -11
  40. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/dataset.py +4 -12
  41. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/prediction_writer.py +16 -32
  42. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/classification.py +2 -1
  43. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/fsspec.py +20 -0
  44. rslearn-0.0.16/rslearn/utils/jsonargparse.py +112 -0
  45. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/raster_format.py +1 -41
  46. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/vector_format.py +1 -38
  47. {rslearn-0.0.14 → rslearn-0.0.16/rslearn.egg-info}/PKG-INFO +1 -1
  48. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/SOURCES.txt +1 -2
  49. rslearn-0.0.14/rslearn/config/dataset.py +0 -602
  50. rslearn-0.0.14/rslearn/data_sources/__init__.py +0 -51
  51. rslearn-0.0.14/rslearn/data_sources/geotiff.py +0 -1
  52. rslearn-0.0.14/rslearn/data_sources/raster_source.py +0 -23
  53. rslearn-0.0.14/rslearn/lightning_cli.py +0 -67
  54. rslearn-0.0.14/rslearn/utils/jsonargparse.py +0 -33
  55. {rslearn-0.0.14 → rslearn-0.0.16}/LICENSE +0 -0
  56. {rslearn-0.0.14 → rslearn-0.0.16}/NOTICE +0 -0
  57. {rslearn-0.0.14 → rslearn-0.0.16}/README.md +0 -0
  58. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/__init__.py +0 -0
  59. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/arg_parser.py +0 -0
  60. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/const.py +0 -0
  61. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/data_sources/vector_source.py +0 -0
  62. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/__init__.py +0 -0
  63. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/add_windows.py +0 -0
  64. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/handler_summaries.py +0 -0
  65. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/index.py +0 -0
  66. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/remap.py +0 -0
  67. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/dataset/window.py +0 -0
  68. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/log_utils.py +0 -0
  69. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/__init__.py +0 -0
  70. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/anysat.py +0 -0
  71. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/clay/configs/metadata.yaml +0 -0
  72. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/clip.py +0 -0
  73. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/conv.py +0 -0
  74. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/__init__.py +0 -0
  75. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/box_ops.py +0 -0
  76. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/detr.py +0 -0
  77. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/matcher.py +0 -0
  78. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/position_encoding.py +0 -0
  79. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/transformer.py +0 -0
  80. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/detr/util.py +0 -0
  81. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/dinov3.py +0 -0
  82. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/faster_rcnn.py +0 -0
  83. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/feature_center_crop.py +0 -0
  84. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/fpn.py +0 -0
  85. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/galileo/__init__.py +0 -0
  86. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/galileo/galileo.py +0 -0
  87. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/galileo/single_file_galileo.py +0 -0
  88. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/module_wrapper.py +0 -0
  89. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/molmo.py +0 -0
  90. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/multitask.py +0 -0
  91. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/olmoearth_pretrain/__init__.py +0 -0
  92. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/olmoearth_pretrain/model.py +0 -0
  93. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/olmoearth_pretrain/norm.py +0 -0
  94. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon.py +0 -0
  95. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/drone.yaml +0 -0
  96. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/enmap.yaml +0 -0
  97. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/goes.yaml +0 -0
  98. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/himawari.yaml +0 -0
  99. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/intuition.yaml +0 -0
  100. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/landsat8.yaml +0 -0
  101. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/modis_terra.yaml +0 -0
  102. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +0 -0
  103. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/sentinel1.yaml +0 -0
  104. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/sentinel2.yaml +0 -0
  105. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/superdove.yaml +0 -0
  106. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/panopticon_data/sensors/wv23.yaml +0 -0
  107. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/pick_features.py +0 -0
  108. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/pooling_decoder.py +0 -0
  109. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/presto/__init__.py +0 -0
  110. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/presto/presto.py +0 -0
  111. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/presto/single_file_presto.py +0 -0
  112. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/prithvi.py +0 -0
  113. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/registry.py +0 -0
  114. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/resize_features.py +0 -0
  115. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/sam2_enc.py +0 -0
  116. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/simple_time_series.py +0 -0
  117. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/singletask.py +0 -0
  118. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/ssl4eo_s12.py +0 -0
  119. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/swin.py +0 -0
  120. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/task_embedding.py +0 -0
  121. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/trunk.py +0 -0
  122. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/unet.py +0 -0
  123. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/upsample.py +0 -0
  124. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/models/use_croma.py +0 -0
  125. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/py.typed +0 -0
  126. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/template_params.py +0 -0
  127. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/tile_stores/default.py +0 -0
  128. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/tile_stores/tile_store.py +0 -0
  129. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/__init__.py +0 -0
  130. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/all_patches_dataset.py +0 -0
  131. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/__init__.py +0 -0
  132. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/adapters.py +0 -0
  133. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/freeze_unfreeze.py +0 -0
  134. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/gradients.py +0 -0
  135. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/callbacks/peft.py +0 -0
  136. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/data_module.py +0 -0
  137. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/lightning_module.py +0 -0
  138. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/optimizer.py +0 -0
  139. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/scheduler.py +0 -0
  140. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/__init__.py +0 -0
  141. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/detection.py +0 -0
  142. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/embedding.py +0 -0
  143. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/multi_task.py +0 -0
  144. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/per_pixel_regression.py +0 -0
  145. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/regression.py +0 -0
  146. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/segmentation.py +0 -0
  147. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/tasks/task.py +0 -0
  148. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/__init__.py +0 -0
  149. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/concatenate.py +0 -0
  150. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/crop.py +0 -0
  151. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/flip.py +0 -0
  152. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/mask.py +0 -0
  153. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/normalize.py +0 -0
  154. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/pad.py +0 -0
  155. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/select_bands.py +0 -0
  156. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/sentinel1.py +0 -0
  157. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/train/transforms/transform.py +0 -0
  158. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/__init__.py +0 -0
  159. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/array.py +0 -0
  160. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/feature.py +0 -0
  161. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/geometry.py +0 -0
  162. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/get_utm_ups_crs.py +0 -0
  163. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/grid_index.py +0 -0
  164. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/mp.py +0 -0
  165. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/rtree_index.py +0 -0
  166. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/spatial_index.py +0 -0
  167. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/sqlite_index.py +0 -0
  168. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn/utils/time.py +0 -0
  169. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/dependency_links.txt +0 -0
  170. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/entry_points.txt +0 -0
  171. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/requires.txt +0 -0
  172. {rslearn-0.0.14 → rslearn-0.0.16}/rslearn.egg-info/top_level.txt +0 -0
  173. {rslearn-0.0.14 → rslearn-0.0.16}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.14
3
+ Version: 0.0.16
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.14"
3
+ version = "0.0.16"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
6
  { name = "OlmoEarth Team" },
@@ -3,33 +3,25 @@
3
3
  from .dataset import (
4
4
  BandSetConfig,
5
5
  CompositingMethod,
6
+ DatasetConfig,
6
7
  DataSourceConfig,
7
8
  DType,
8
9
  LayerConfig,
9
10
  LayerType,
10
11
  QueryConfig,
11
- RasterFormatConfig,
12
- RasterLayerConfig,
13
12
  SpaceMode,
14
13
  TimeMode,
15
- VectorFormatConfig,
16
- VectorLayerConfig,
17
- load_layer_config,
18
14
  )
19
15
 
20
16
  __all__ = [
21
17
  "BandSetConfig",
22
18
  "CompositingMethod",
19
+ "DatasetConfig",
23
20
  "DataSourceConfig",
24
21
  "DType",
25
22
  "LayerConfig",
26
23
  "LayerType",
27
24
  "QueryConfig",
28
- "RasterFormatConfig",
29
- "RasterLayerConfig",
30
25
  "SpaceMode",
31
26
  "TimeMode",
32
- "VectorFormatConfig",
33
- "VectorLayerConfig",
34
- "load_layer_config",
35
27
  ]
@@ -0,0 +1,596 @@
1
+ """Classes for storing configuration of a dataset."""
2
+
3
+ import copy
4
+ import functools
5
+ import json
6
+ import warnings
7
+ from datetime import timedelta
8
+ from enum import StrEnum
9
+ from typing import TYPE_CHECKING, Annotated, Any
10
+
11
+ import jsonargparse
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+ import pytimeparse
15
+ from pydantic import (
16
+ BaseModel,
17
+ BeforeValidator,
18
+ ConfigDict,
19
+ Field,
20
+ PlainSerializer,
21
+ field_validator,
22
+ model_validator,
23
+ )
24
+ from rasterio.enums import Resampling
25
+ from upath import UPath
26
+
27
+ from rslearn.log_utils import get_logger
28
+ from rslearn.utils import PixelBounds, Projection
29
+ from rslearn.utils.raster_format import RasterFormat
30
+ from rslearn.utils.vector_format import VectorFormat
31
+
32
+ if TYPE_CHECKING:
33
+ from rslearn.data_sources.data_source import DataSource
34
+
35
+ logger = get_logger("__name__")
36
+
37
+
38
+ def ensure_timedelta(v: Any) -> Any:
39
+ """Ensure the value is a timedelta.
40
+
41
+ If the value is a string, we try to parse it with pytimeparse.
42
+
43
+ This function is meant to be used like Annotated[timedelta, BeforeValidator(ensure_timedelta)].
44
+ """
45
+ if isinstance(v, timedelta):
46
+ return v
47
+ if isinstance(v, str):
48
+ return pytimeparse.parse(v)
49
+ raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
50
+
51
+
52
+ def ensure_optional_timedelta(v: Any) -> Any:
53
+ """Like ensure_timedelta, but allows None as a value."""
54
+ if v is None:
55
+ return None
56
+ if isinstance(v, timedelta):
57
+ return v
58
+ if isinstance(v, str):
59
+ return pytimeparse.parse(v)
60
+ raise TypeError(f"Invalid type for timedelta: {type(v).__name__}")
61
+
62
+
63
+ def serialize_optional_timedelta(v: timedelta | None) -> str | None:
64
+ """Serialize an optional timedelta for compatibility with pytimeparse."""
65
+ if v is None:
66
+ return None
67
+ return str(v.total_seconds()) + "s"
68
+
69
+
70
+ class DType(StrEnum):
71
+ """Data type of a raster."""
72
+
73
+ UINT8 = "uint8"
74
+ UINT16 = "uint16"
75
+ UINT32 = "uint32"
76
+ UINT64 = "uint64"
77
+ INT8 = "int8"
78
+ INT16 = "int16"
79
+ INT32 = "int32"
80
+ INT64 = "int64"
81
+ FLOAT32 = "float32"
82
+
83
+ def get_numpy_dtype(self) -> npt.DTypeLike:
84
+ """Returns numpy dtype object corresponding to this DType."""
85
+ if self == DType.UINT8:
86
+ return np.uint8
87
+ elif self == DType.UINT16:
88
+ return np.uint16
89
+ elif self == DType.UINT32:
90
+ return np.uint32
91
+ elif self == DType.UINT64:
92
+ return np.uint64
93
+ elif self == DType.INT8:
94
+ return np.int8
95
+ elif self == DType.INT16:
96
+ return np.int16
97
+ elif self == DType.INT32:
98
+ return np.int32
99
+ elif self == DType.INT64:
100
+ return np.int64
101
+ elif self == DType.FLOAT32:
102
+ return np.float32
103
+ raise ValueError(f"unable to handle numpy dtype {self}")
104
+
105
+
106
+ class ResamplingMethod(StrEnum):
107
+ """An enum representing the rasterio Resampling."""
108
+
109
+ NEAREST = "nearest"
110
+ BILINEAR = "bilinear"
111
+ CUBIC = "cubic"
112
+ CUBIC_SPLINE = "cubic_spline"
113
+
114
+ def get_rasterio_resampling(self) -> Resampling:
115
+ """Get the rasterio Resampling corresponding to this ResamplingMethod."""
116
+ return RESAMPLING_METHODS[self]
117
+
118
+
119
+ RESAMPLING_METHODS = {
120
+ ResamplingMethod.NEAREST: Resampling.nearest,
121
+ ResamplingMethod.BILINEAR: Resampling.bilinear,
122
+ ResamplingMethod.CUBIC: Resampling.cubic,
123
+ ResamplingMethod.CUBIC_SPLINE: Resampling.cubic_spline,
124
+ }
125
+
126
+
127
+ class BandSetConfig(BaseModel):
128
+ """A configuration for a band set in a raster layer.
129
+
130
+ Each band set specifies one or more bands that should be stored together.
131
+ It also specifies the storage format and dtype, the zoom offset, etc. for these
132
+ bands.
133
+ """
134
+
135
+ dtype: DType = Field(description="Pixel value type to store the data under")
136
+ bands: list[str] = Field(
137
+ default_factory=lambda: [],
138
+ description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
139
+ )
140
+ num_bands: int | None = Field(
141
+ default=None,
142
+ description="The number of bands in this band set. The bands will be named B0, B1, B2, etc.",
143
+ )
144
+ format: dict[str, Any] = Field(
145
+ default_factory=lambda: {
146
+ "class_path": "rslearn.utils.raster_format.GeotiffRasterFormat"
147
+ },
148
+ description="jsonargparse configuration for the RasterFormat to store the tiles in.",
149
+ )
150
+
151
+ # Store images at a resolution higher or lower than the window resolution. This
152
+ # enables keeping source data at its native resolution, either to save storage
153
+ # space (for lower resolution data) or to retain details (for higher resolution
154
+ # data). If positive, store data at the window resolution divided by
155
+ # 2^(zoom_offset) (higher resolution). If negative, store data at the window
156
+ # resolution multiplied by 2^(-zoom_offset) (lower resolution).
157
+ zoom_offset: int = Field(
158
+ default=0,
159
+ description="Store data at the window resolution multiplied by 2^(-zoom_offset).",
160
+ )
161
+
162
+ remap: dict[str, Any] | None = Field(
163
+ default=None,
164
+ description="Optional jsonargparse configuration for a Remapper to remap pixel values.",
165
+ )
166
+
167
+ # Optional list of names for the different possible values of each band. The length
168
+ # of this list must equal the number of bands. For example, [["forest", "desert"]]
169
+ # means that it is a single-band raster where values can be 0 (forest) or 1
170
+ # (desert).
171
+ class_names: list[list[str]] | None = Field(
172
+ default=None,
173
+ description="Optional list of names for the different possible values of each band.",
174
+ )
175
+
176
+ # Optional list of nodata values for this band set. This is used during
177
+ # materialization when creating mosaics, to determine which parts of the source
178
+ # images should be copied.
179
+ nodata_vals: list[float] | None = Field(
180
+ default=None, description="Optional nodata value for each band."
181
+ )
182
+
183
+ @model_validator(mode="after")
184
+ def after_validator(self) -> "BandSetConfig":
185
+ """Ensure the BandSetConfig is valid, and handle the num_bands field."""
186
+ if (len(self.bands) == 0 and self.num_bands is None) or (
187
+ len(self.bands) != 0 and self.num_bands is not None
188
+ ):
189
+ raise ValueError("exactly one of bands and num_bands must be specified")
190
+
191
+ if self.num_bands is not None:
192
+ self.bands = [f"B{band_idx}" for band_idx in range(self.num_bands)]
193
+ self.num_bands = None
194
+
195
+ return self
196
+
197
+ def get_final_projection_and_bounds(
198
+ self, projection: Projection, bounds: PixelBounds
199
+ ) -> tuple[Projection, PixelBounds]:
200
+ """Gets the final projection/bounds based on band set config.
201
+
202
+ The band set config may apply a non-zero zoom offset that modifies the window's
203
+ projection.
204
+
205
+ Args:
206
+ projection: the window's projection
207
+ bounds: the window's bounds (optional)
208
+ band_set: band set configuration object
209
+
210
+ Returns:
211
+ tuple of updated projection and bounds with zoom offset applied
212
+ """
213
+ if self.zoom_offset == 0:
214
+ return projection, bounds
215
+ projection = Projection(
216
+ projection.crs,
217
+ projection.x_resolution / (2**self.zoom_offset),
218
+ projection.y_resolution / (2**self.zoom_offset),
219
+ )
220
+ if self.zoom_offset > 0:
221
+ zoom_factor = 2**self.zoom_offset
222
+ bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
223
+ else:
224
+ bounds = tuple(
225
+ x // (2 ** (-self.zoom_offset))
226
+ for x in bounds # type: ignore
227
+ )
228
+ return projection, bounds
229
+
230
+ @field_validator("format", mode="before")
231
+ @classmethod
232
+ def convert_format_from_legacy(cls, v: dict[str, Any]) -> dict[str, Any]:
233
+ """Support legacy format of the RasterFormat.
234
+
235
+ The legacy format sets 'name' instead of 'class_path', and uses custom parsing
236
+ for the init_args.
237
+ """
238
+ if "name" not in v:
239
+ # New version, it is all good.
240
+ return v
241
+
242
+ warnings.warn(
243
+ "`format = {'name': ...}` is deprecated; "
244
+ "use `{'class_path': '...', 'init_args': {...}}` instead.",
245
+ DeprecationWarning,
246
+ )
247
+
248
+ legacy_name_to_class_path = {
249
+ "image_tile": "rslearn.utils.raster_format.ImageTileRasterFormat",
250
+ "geotiff": "rslearn.utils.raster_format.GeotiffRasterFormat",
251
+ "single_image": "rslearn.utils.raster_format.SingleImageRasterFormat",
252
+ }
253
+ if v["name"] not in legacy_name_to_class_path:
254
+ raise ValueError(
255
+ f"could not parse legacy format with unknown raster format {v['name']}"
256
+ )
257
+ init_args = dict(v)
258
+ class_path = legacy_name_to_class_path[init_args.pop("name")]
259
+
260
+ return dict(
261
+ class_path=class_path,
262
+ init_args=init_args,
263
+ )
264
+
265
+ def instantiate_raster_format(self) -> RasterFormat:
266
+ """Instantiate the RasterFormat specified by this BandSetConfig."""
267
+ from rslearn.utils.jsonargparse import init_jsonargparse
268
+
269
+ init_jsonargparse()
270
+ parser = jsonargparse.ArgumentParser()
271
+ parser.add_argument("--raster_format", type=RasterFormat)
272
+ cfg = parser.parse_object({"raster_format": self.format})
273
+ raster_format = parser.instantiate_classes(cfg).raster_format
274
+ return raster_format
275
+
276
+
277
+ class SpaceMode(StrEnum):
278
+ """Spatial matching mode when looking up items corresponding to a window."""
279
+
280
+ CONTAINS = "CONTAINS"
281
+ """Items must contain the entire window."""
282
+
283
+ INTERSECTS = "INTERSECTS"
284
+ """Items must overlap any portion of the window."""
285
+
286
+ MOSAIC = "MOSAIC"
287
+ """Groups of items should be computed that cover the entire window.
288
+
289
+ During materialization, items in each group are merged to form a mosaic in the
290
+ dataset.
291
+ """
292
+
293
+ PER_PERIOD_MOSAIC = "PER_PERIOD_MOSAIC"
294
+ """Create one mosaic per sub-period of the time range.
295
+
296
+ The duration of the sub-periods is controlled by another option in QueryConfig.
297
+ """
298
+
299
+ COMPOSITE = "COMPOSITE"
300
+ """Creates one composite covering the entire window.
301
+
302
+ During querying all items intersecting the window are placed in one group.
303
+ The compositing_method in the rasterlayer config specifies how these items are reduced
304
+ to a single item (e.g MEAN/MEDIAN/FIRST_VALID) during materialization.
305
+ """
306
+
307
+ # TODO add PER_PERIOD_COMPOSITE
308
+
309
+
310
+ class TimeMode(StrEnum):
311
+ """Temporal matching mode when looking up items corresponding to a window."""
312
+
313
+ WITHIN = "WITHIN"
314
+ """Items must be within the window time range."""
315
+
316
+ NEAREST = "NEAREST"
317
+ """Select items closest to the window time range, up to max_matches."""
318
+
319
+ BEFORE = "BEFORE"
320
+ """Select items before the end of the window time range, up to max_matches."""
321
+
322
+ AFTER = "AFTER"
323
+ """Select items after the start of the window time range, up to max_matches."""
324
+
325
+
326
+ class QueryConfig(BaseModel):
327
+ """A configuration for querying items in a data source."""
328
+
329
+ model_config = ConfigDict(frozen=True)
330
+
331
+ space_mode: SpaceMode = Field(
332
+ default=SpaceMode.MOSAIC,
333
+ description="Specifies how items should be matched with windows spatially.",
334
+ )
335
+ time_mode: TimeMode = Field(
336
+ default=TimeMode.WITHIN,
337
+ description="Specifies how items should be matched with windows temporally.",
338
+ )
339
+
340
+ # Minimum number of item groups. If there are fewer than this many matches, then no
341
+ # matches will be returned. This can be used to prevent unnecessary data ingestion
342
+ # if the user plans to discard windows that do not have a sufficient amount of data.
343
+ min_matches: int = Field(
344
+ default=0, description="The minimum number of item groups."
345
+ )
346
+
347
+ max_matches: int = Field(
348
+ default=1, description="The maximum number of item groups."
349
+ )
350
+ period_duration: Annotated[
351
+ timedelta,
352
+ BeforeValidator(ensure_timedelta),
353
+ PlainSerializer(serialize_optional_timedelta),
354
+ ] = Field(
355
+ default=timedelta(days=30),
356
+ description="The duration of the periods, if the space mode is PER_PERIOD_MOSAIC.",
357
+ )
358
+
359
+
360
+ class DataSourceConfig(BaseModel):
361
+ """Configuration for a DataSource in a dataset layer."""
362
+
363
+ model_config = ConfigDict(frozen=True)
364
+
365
+ class_path: str = Field(description="Class path for the data source.")
366
+ init_args: dict[str, Any] = Field(
367
+ default_factory=lambda: {},
368
+ description="jsonargparse init args for the data source.",
369
+ )
370
+ query_config: QueryConfig = Field(
371
+ default_factory=lambda: QueryConfig(),
372
+ description="QueryConfig specifying how to match items with windows.",
373
+ )
374
+ time_offset: Annotated[
375
+ timedelta | None,
376
+ BeforeValidator(ensure_optional_timedelta),
377
+ PlainSerializer(serialize_optional_timedelta),
378
+ ] = Field(
379
+ default=None,
380
+ description="Optional timedelta to add to the window's time range before matching.",
381
+ )
382
+ duration: Annotated[
383
+ timedelta | None,
384
+ BeforeValidator(ensure_optional_timedelta),
385
+ PlainSerializer(serialize_optional_timedelta),
386
+ ] = Field(
387
+ default=None,
388
+ description="Optional, if the window's time range is (t0, t1), then update to (t0, t0 + duration).",
389
+ )
390
+ ingest: bool = Field(
391
+ default=True,
392
+ description="Whether to ingest this layer (default True). If False, it will be directly materialized without ingestion.",
393
+ )
394
+
395
+ @model_validator(mode="before")
396
+ @classmethod
397
+ def convert_from_legacy(cls, d: dict[str, Any]) -> dict[str, Any]:
398
+ """Support legacy format of the DataSourceConfig.
399
+
400
+ The legacy format sets 'name' instead of 'class_path', and mixes the arguments
401
+ for the DataSource in with the DataSourceConfig keys.
402
+ """
403
+ if "name" not in d:
404
+ # New version, it is all good.
405
+ return d
406
+
407
+ warnings.warn(
408
+ "`Data source configuration {'name': ...}` is deprecated; "
409
+ "use `{'class_path': '...', 'init_args': {...}, ...}` instead.",
410
+ DeprecationWarning,
411
+ )
412
+
413
+ # Split the dict into the base config that is in the pydantic model, and the
414
+ # source-specific options that should be moved to init_args dict.
415
+ class_path = d["name"]
416
+ base_config: dict[str, Any] = {}
417
+ ds_init_args: dict[str, Any] = {}
418
+ for k, v in d.items():
419
+ if k == "name":
420
+ continue
421
+ if k in cls.model_fields:
422
+ base_config[k] = v
423
+ else:
424
+ ds_init_args[k] = v
425
+
426
+ # Some legacy configs erroneously specify these keys, which are now caught by
427
+ # validation. But we still want those specific legacy configs to work.
428
+ if (
429
+ class_path == "rslearn.data_sources.planetary_computer.Sentinel2"
430
+ and "max_cloud_cover" in ds_init_args
431
+ ):
432
+ warnings.warn(
433
+ "Data source configuration specifies invalid 'max_cloud_cover' option.",
434
+ DeprecationWarning,
435
+ )
436
+ del ds_init_args["max_cloud_cover"]
437
+
438
+ base_config["class_path"] = class_path
439
+ base_config["init_args"] = ds_init_args
440
+ return base_config
441
+
442
+
443
+ class LayerType(StrEnum):
444
+ """The layer type (raster or vector)."""
445
+
446
+ RASTER = "raster"
447
+ VECTOR = "vector"
448
+
449
+
450
+ class CompositingMethod(StrEnum):
451
+ """Method how to select pixels for the composite from corresponding items of a window."""
452
+
453
+ FIRST_VALID = "FIRST_VALID"
454
+ """Select first valid pixel in order of corresponding items (might be sorted)"""
455
+
456
+ MEAN = "MEAN"
457
+ """Select per-pixel mean value of corresponding items of a window"""
458
+
459
+ MEDIAN = "MEDIAN"
460
+ """Select per-pixel median value of corresponding items of a window"""
461
+
462
+
463
class LayerConfig(BaseModel):
    """Configuration of a layer in a dataset.

    The model is frozen (immutable after validation). It defines ``__hash__`` and
    ``__eq__`` based on its dumped contents so that instances can be used as cache
    keys, which ``instantiate_data_source`` relies on.

    Raster layers must specify at least one band set; this is enforced by
    ``after_validator``.
    """

    model_config = ConfigDict(frozen=True)

    type: LayerType = Field(description="The LayerType (raster or vector).")
    data_source: DataSourceConfig | None = Field(
        default=None,
        description="Optional DataSourceConfig if this layer is retrievable.",
    )
    alias: str | None = Field(
        default=None, description="Alias for this layer to use in the tile store."
    )

    # Raster layer options.
    band_sets: list[BandSetConfig] = Field(
        default_factory=list,
        description="For raster layers, the bands to store in this layer.",
    )
    resampling_method: ResamplingMethod = Field(
        default=ResamplingMethod.BILINEAR,
        description="For raster layers, how to resample rasters (if needed), default bilinear resampling.",
    )
    compositing_method: CompositingMethod = Field(
        default=CompositingMethod.FIRST_VALID,
        description="For raster layers, how to compute pixel values in the composite of each window's items.",
    )

    # Vector layer options.
    vector_format: dict[str, Any] = Field(
        default_factory=lambda: {
            "class_path": "rslearn.utils.vector_format.GeojsonVectorFormat"
        },
        description="For vector layers, the jsonargparse configuration for the VectorFormat.",
    )
    class_property_name: str | None = Field(
        default=None,
        description="Optional metadata field indicating that the GeoJSON features contain a property that corresponds to a class label, and this is the name of that property.",
    )
    class_names: list[str] | None = Field(
        default=None,
        description="The list of classes that the class_property_name property could be set to.",
    )

    @model_validator(mode="after")
    def after_validator(self) -> "LayerConfig":
        """Ensure the LayerConfig is valid.

        Returns:
            this LayerConfig, unchanged.

        Raises:
            ValueError: if this is a raster layer with no band sets.
        """
        if self.type == LayerType.RASTER and not self.band_sets:
            raise ValueError(
                "band sets must be specified and non-empty for raster layers"
            )

        return self

    def __hash__(self) -> int:
        """Return a hash of this LayerConfig.

        The hash is computed from the JSON-mode dump serialized with sorted keys,
        so it depends only on the configuration contents.
        """
        return hash(json.dumps(self.model_dump(mode="json"), sort_keys=True))

    def __eq__(self, other: Any) -> bool:
        """Returns whether other is the same as this LayerConfig.

        Args:
            other: the other object to compare.

        Returns:
            True if other is a LayerConfig whose dumped contents equal this one's,
            False otherwise (including when other is not a LayerConfig).
        """
        if not isinstance(other, LayerConfig):
            return False
        return self.model_dump() == other.model_dump()

    # NOTE: functools.cache keys on (self, ds_path), so it uses the __hash__ and
    # __eq__ defined above. The cache is unbounded: each distinct
    # (LayerConfig, ds_path) pair and its DataSource stay referenced for the
    # lifetime of the process.
    @functools.cache
    def instantiate_data_source(self, ds_path: UPath | None = None) -> "DataSource":
        """Instantiate the data source specified by this config.

        Results are cached per (LayerConfig, ds_path) pair, so repeated calls with
        equal arguments return the same DataSource object.

        Args:
            ds_path: optional dataset path to include in the DataSourceContext.

        Returns:
            the DataSource object.

        Raises:
            ValueError: if this layer does not specify a data source.
        """
        # Imported locally rather than at module level (presumably to avoid a
        # circular import between the config and data source modules).
        from rslearn.data_sources.data_source import DataSource, DataSourceContext
        from rslearn.utils.jsonargparse import (
            data_source_context_serializer,
            init_jsonargparse,
        )

        logger.debug("getting a data source for dataset at %s", ds_path)
        if self.data_source is None:
            raise ValueError("This layer does not specify a data source")

        # Inject the DataSourceContext into the args. The init_args are deep-copied
        # so that adding the context never mutates this (frozen) config.
        context = DataSourceContext(
            ds_path=ds_path,
            layer_config=self,
        )
        ds_config: dict[str, Any] = {
            "class_path": self.data_source.class_path,
            "init_args": copy.deepcopy(self.data_source.init_args),
        }
        ds_config["init_args"]["context"] = data_source_context_serializer(context)

        # Now we can parse with jsonargparse.
        init_jsonargparse()
        parser = jsonargparse.ArgumentParser()
        parser.add_argument("--data_source", type=DataSource)
        cfg = parser.parse_object({"data_source": ds_config})
        return parser.instantiate_classes(cfg).data_source

    def instantiate_vector_format(self) -> VectorFormat:
        """Instantiate the vector format specified by this config.

        Returns:
            the VectorFormat object built from the vector_format configuration.

        Raises:
            ValueError: if this layer is not a vector layer.
        """
        if self.type != LayerType.VECTOR:
            raise ValueError(
                f"cannot instantiate vector format for layer with type {self.type}"
            )

        from rslearn.utils.jsonargparse import init_jsonargparse

        init_jsonargparse()
        parser = jsonargparse.ArgumentParser()
        parser.add_argument("--vector_format", type=VectorFormat)
        cfg = parser.parse_object({"vector_format": self.vector_format})
        return parser.instantiate_classes(cfg).vector_format
587
+
588
+
589
class DatasetConfig(BaseModel):
    """Top-level configuration for a whole dataset.

    Holds the per-layer configurations together with the jsonargparse
    configuration used for the tile store.
    """

    # Mapping from string key (presumably the layer name) to its LayerConfig.
    layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")

    # When omitted, the tile store falls back to the DefaultTileStore class path.
    tile_store: dict[str, Any] = Field(
        description="jsonargparse configuration for the TileStore.",
        default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
    )
@@ -0,0 +1,28 @@
1
+ """Data sources.
2
+
3
+ A DataSource represents a source from which raster and vector data corresponding to
4
+ spatiotemporal windows can be retrieved.
5
+
6
+ A DataSource consists of items that can be ingested, like Sentinel-2 scenes or
7
+ OpenStreetMap PBF files.
8
+
9
+ Each source supports operations to lookup items that match with spatiotemporal
10
+ geometries, and ingest those items.
11
+ """
12
+
13
+ from .data_source import (
14
+ DataSource,
15
+ DataSourceContext,
16
+ Item,
17
+ ItemLookupDataSource,
18
+ RetrieveItemDataSource,
19
+ )
20
+
21
# Public API of this package. Every name listed here must be bound in this module:
# a name in __all__ that is not defined makes `from <package> import *` raise
# AttributeError. "data_source_from_config" was listed previously but is neither
# imported nor defined here, so it has been dropped.
__all__ = (
    "DataSource",
    "DataSourceContext",
    "Item",
    "ItemLookupDataSource",
    "RetrieveItemDataSource",
)