rslearn 0.0.3__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149) hide show
  1. {rslearn-0.0.3/rslearn.egg-info → rslearn-0.0.5}/PKG-INFO +6 -3
  2. {rslearn-0.0.3 → rslearn-0.0.5}/pyproject.toml +8 -8
  3. rslearn-0.0.5/rslearn/arg_parser.py +59 -0
  4. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/copernicus.py +10 -8
  5. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/earthdaily.py +21 -1
  6. rslearn-0.0.5/rslearn/data_sources/eurocrops.py +246 -0
  7. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/gcp_public_data.py +3 -3
  8. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/local_files.py +11 -0
  9. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/openstreetmap.py +2 -4
  10. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/utils.py +1 -17
  11. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/main.py +10 -1
  12. rslearn-0.0.5/rslearn/models/copernicusfm.py +216 -0
  13. rslearn-0.0.5/rslearn/models/copernicusfm_src/__init__.py +1 -0
  14. rslearn-0.0.5/rslearn/models/copernicusfm_src/aurora/area.py +50 -0
  15. rslearn-0.0.5/rslearn/models/copernicusfm_src/aurora/fourier.py +134 -0
  16. rslearn-0.0.5/rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +523 -0
  17. rslearn-0.0.5/rslearn/models/copernicusfm_src/flexivit/patch_embed.py +260 -0
  18. rslearn-0.0.5/rslearn/models/copernicusfm_src/flexivit/utils.py +69 -0
  19. rslearn-0.0.5/rslearn/models/copernicusfm_src/model_vit.py +348 -0
  20. rslearn-0.0.5/rslearn/models/copernicusfm_src/util/pos_embed.py +216 -0
  21. rslearn-0.0.5/rslearn/models/panopticon.py +167 -0
  22. rslearn-0.0.5/rslearn/models/presto/__init__.py +5 -0
  23. rslearn-0.0.5/rslearn/models/presto/presto.py +247 -0
  24. rslearn-0.0.5/rslearn/models/presto/single_file_presto.py +932 -0
  25. rslearn-0.0.5/rslearn/models/trunk.py +136 -0
  26. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/unet.py +15 -0
  27. rslearn-0.0.5/rslearn/train/callbacks/adapters.py +53 -0
  28. rslearn-0.0.5/rslearn/train/callbacks/freeze_unfreeze.py +410 -0
  29. rslearn-0.0.5/rslearn/train/callbacks/gradients.py +129 -0
  30. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/data_module.py +70 -41
  31. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/dataset.py +232 -54
  32. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/lightning_module.py +4 -0
  33. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/prediction_writer.py +7 -0
  34. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/scheduler.py +15 -0
  35. rslearn-0.0.5/rslearn/train/tasks/per_pixel_regression.py +259 -0
  36. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/tasks/regression.py +6 -4
  37. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/tasks/segmentation.py +44 -14
  38. rslearn-0.0.5/rslearn/train/transforms/mask.py +69 -0
  39. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/geometry.py +8 -8
  40. {rslearn-0.0.3 → rslearn-0.0.5/rslearn.egg-info}/PKG-INFO +6 -3
  41. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn.egg-info/SOURCES.txt +18 -2
  42. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn.egg-info/requires.txt +1 -1
  43. rslearn-0.0.3/rslearn/models/moe/distributed.py +0 -262
  44. rslearn-0.0.3/rslearn/models/moe/soft.py +0 -676
  45. rslearn-0.0.3/rslearn/models/trunk.py +0 -280
  46. rslearn-0.0.3/rslearn/train/callbacks/freeze_unfreeze.py +0 -91
  47. rslearn-0.0.3/rslearn/train/callbacks/gradients.py +0 -109
  48. {rslearn-0.0.3 → rslearn-0.0.5}/LICENSE +0 -0
  49. {rslearn-0.0.3 → rslearn-0.0.5}/README.md +0 -0
  50. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/__init__.py +0 -0
  51. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/config/__init__.py +0 -0
  52. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/config/dataset.py +0 -0
  53. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/const.py +0 -0
  54. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/__init__.py +0 -0
  55. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/aws_landsat.py +0 -0
  56. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/aws_open_data.py +0 -0
  57. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/aws_sentinel1.py +0 -0
  58. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/climate_data_store.py +0 -0
  59. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/data_source.py +0 -0
  60. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/earthdata_srtm.py +0 -0
  61. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/geotiff.py +0 -0
  62. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/google_earth_engine.py +0 -0
  63. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/planet.py +0 -0
  64. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/planet_basemap.py +0 -0
  65. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/planetary_computer.py +0 -0
  66. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/raster_source.py +0 -0
  67. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/usda_cdl.py +0 -0
  68. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/usgs_landsat.py +0 -0
  69. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/vector_source.py +0 -0
  70. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/worldcereal.py +0 -0
  71. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/worldcover.py +0 -0
  72. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/worldpop.py +0 -0
  73. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/data_sources/xyz_tiles.py +0 -0
  74. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/dataset/__init__.py +0 -0
  75. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/dataset/add_windows.py +0 -0
  76. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/dataset/dataset.py +0 -0
  77. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/dataset/index.py +0 -0
  78. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/dataset/manage.py +0 -0
  79. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/dataset/materialize.py +0 -0
  80. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/dataset/remap.py +0 -0
  81. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/dataset/window.py +0 -0
  82. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/log_utils.py +0 -0
  83. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/__init__.py +0 -0
  84. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/clip.py +0 -0
  85. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/conv.py +0 -0
  86. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/croma.py +0 -0
  87. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/detr/__init__.py +0 -0
  88. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/detr/box_ops.py +0 -0
  89. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/detr/detr.py +0 -0
  90. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/detr/matcher.py +0 -0
  91. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/detr/position_encoding.py +0 -0
  92. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/detr/transformer.py +0 -0
  93. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/detr/util.py +0 -0
  94. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/faster_rcnn.py +0 -0
  95. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/fpn.py +0 -0
  96. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/module_wrapper.py +0 -0
  97. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/molmo.py +0 -0
  98. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/multitask.py +0 -0
  99. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/pick_features.py +0 -0
  100. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/pooling_decoder.py +0 -0
  101. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/registry.py +0 -0
  102. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/sam2_enc.py +0 -0
  103. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/satlaspretrain.py +0 -0
  104. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/simple_time_series.py +0 -0
  105. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/singletask.py +0 -0
  106. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/ssl4eo_s12.py +0 -0
  107. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/swin.py +0 -0
  108. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/task_embedding.py +0 -0
  109. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/terramind.py +0 -0
  110. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/upsample.py +0 -0
  111. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/models/use_croma.py +0 -0
  112. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/py.typed +0 -0
  113. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/tile_stores/__init__.py +0 -0
  114. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/tile_stores/default.py +0 -0
  115. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/tile_stores/tile_store.py +0 -0
  116. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/__init__.py +0 -0
  117. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/callbacks/__init__.py +0 -0
  118. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/callbacks/peft.py +0 -0
  119. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/optimizer.py +0 -0
  120. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/tasks/__init__.py +0 -0
  121. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/tasks/classification.py +0 -0
  122. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/tasks/detection.py +0 -0
  123. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/tasks/multi_task.py +0 -0
  124. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/tasks/task.py +0 -0
  125. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/transforms/__init__.py +0 -0
  126. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/transforms/concatenate.py +0 -0
  127. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/transforms/crop.py +0 -0
  128. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/transforms/flip.py +0 -0
  129. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/transforms/normalize.py +0 -0
  130. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/transforms/pad.py +0 -0
  131. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/train/transforms/transform.py +0 -0
  132. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/__init__.py +0 -0
  133. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/array.py +0 -0
  134. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/feature.py +0 -0
  135. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/fsspec.py +0 -0
  136. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/get_utm_ups_crs.py +0 -0
  137. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/grid_index.py +0 -0
  138. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/jsonargparse.py +0 -0
  139. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/mp.py +0 -0
  140. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/raster_format.py +0 -0
  141. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/rtree_index.py +0 -0
  142. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/spatial_index.py +0 -0
  143. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/sqlite_index.py +0 -0
  144. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/time.py +0 -0
  145. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn/utils/vector_format.py +0 -0
  146. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn.egg-info/dependency_links.txt +0 -0
  147. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn.egg-info/entry_points.txt +0 -0
  148. {rslearn-0.0.3 → rslearn-0.0.5}/rslearn.egg-info/top_level.txt +0 -0
  149. {rslearn-0.0.3 → rslearn-0.0.5}/setup.cfg +0 -0
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: A library for developing remote sensing datasets and models
5
- Author-email: Favyen Bastani <favyenb@allenai.org>, Yawen Zhang <yawenz@allenai.org>, Patrick Beukema <patrickb@allenai.org>, Henry Herzog <henryh@allenai.org>, Piper Wolters <piperw@allenai.org>
5
+ Author: OlmoEarth Team
6
6
  License: Apache License
7
7
  Version 2.0, January 2004
8
8
  http://www.apache.org/licenses/
@@ -205,6 +205,9 @@ License: Apache License
205
205
  See the License for the specific language governing permissions and
206
206
  limitations under the License.
207
207
 
208
+ Project-URL: homepage, https://github.com/allenai/rslearn
209
+ Project-URL: issues, https://github.com/allenai/rslearn/issues
210
+ Project-URL: repository, https://github.com/allenai/rslearn
208
211
  Requires-Python: >=3.11
209
212
  Description-Content-Type: text/markdown
210
213
  License-File: LICENSE
@@ -227,7 +230,7 @@ Requires-Dist: universal_pathlib>=0.2.6
227
230
  Provides-Extra: extra
228
231
  Requires-Dist: accelerate>=1.10; extra == "extra"
229
232
  Requires-Dist: cdsapi>=0.7.6; extra == "extra"
230
- Requires-Dist: earthdaily[platform]>=1.0.0; extra == "extra"
233
+ Requires-Dist: earthdaily[platform]>=1.0.7; extra == "extra"
231
234
  Requires-Dist: earthengine-api>=1.6.3; extra == "extra"
232
235
  Requires-Dist: einops>=0.8; extra == "extra"
233
236
  Requires-Dist: gcsfs==2025.3.0; extra == "extra"
@@ -1,18 +1,13 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.3"
3
+ version = "0.0.5"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
- {name = "Favyen Bastani", email = "favyenb@allenai.org"},
7
- {name = "Yawen Zhang", email = "yawenz@allenai.org"},
8
- {name = "Patrick Beukema", email = "patrickb@allenai.org"},
9
- {name = "Henry Herzog", email = "henryh@allenai.org"},
10
- {name = "Piper Wolters", email = "piperw@allenai.org"},
6
+ { name = "OlmoEarth Team" },
11
7
  ]
12
8
  readme = "README.md"
13
9
  license = {file = "LICENSE"}
14
10
  requires-python = ">=3.11"
15
-
16
11
  dependencies = [
17
12
  "boto3>=1.39",
18
13
  "class_registry>=2.1",
@@ -39,7 +34,7 @@ dependencies = [
39
34
  extra = [
40
35
  "accelerate>=1.10",
41
36
  "cdsapi>=0.7.6",
42
- "earthdaily[platform]>=1.0.0",
37
+ "earthdaily[platform]>=1.0.7",
43
38
  "earthengine-api>=1.6.3",
44
39
  "einops>=0.8",
45
40
  "gcsfs==2025.3.0",
@@ -71,6 +66,11 @@ dev = [
71
66
  "pytest-xdist",
72
67
  ]
73
68
 
69
+ [project.urls]
70
+ homepage = "https://github.com/allenai/rslearn"
71
+ issues = "https://github.com/allenai/rslearn/issues"
72
+ repository = "https://github.com/allenai/rslearn"
73
+
74
74
  [build-system]
75
75
  requires = ["setuptools>=61"]
76
76
  build-backend = "setuptools.build_meta"
@@ -0,0 +1,59 @@
1
+ """Custom Lightning ArgumentParser with environment variable substitution support."""
2
+
3
+ import os
4
+ import re
5
+ from typing import Any
6
+
7
+ from jsonargparse import Namespace
8
+ from lightning.pytorch.cli import LightningArgumentParser
9
+
10
+
11
+ def substitute_env_vars_in_string(content: str) -> str:
12
+ """Substitute environment variables in a string.
13
+
14
+ Replaces ${VAR_NAME} patterns with os.getenv(VAR_NAME, "") values.
15
+ This works on raw string content before YAML parsing.
16
+
17
+ Args:
18
+ content: The string content containing template variables
19
+
20
+ Returns:
21
+ The string with environment variables substituted
22
+ """
23
+ pattern = r"\$\{([^}]+)\}"
24
+
25
+ def replace_variable(match_obj: re.Match[str]) -> str:
26
+ var_name = match_obj.group(1)
27
+ env_value = os.getenv(var_name, "")
28
+ return env_value if env_value is not None else ""
29
+
30
+ return re.sub(pattern, replace_variable, content)
31
+
32
+
33
+ class RslearnArgumentParser(LightningArgumentParser):
34
+ """Custom ArgumentParser that substitutes environment variables in config files.
35
+
36
+ This parser extends LightningArgumentParser to automatically substitute
37
+ ${VAR_NAME} patterns with environment variable values before parsing
38
+ configuration content. This allows config files to use environment
39
+ variables while still passing Lightning's validation.
40
+ """
41
+
42
+ def parse_string(
43
+ self,
44
+ cfg_str: str,
45
+ cfg_path: str | os.PathLike = "",
46
+ ext_vars: dict | None = None,
47
+ env: bool | None = None,
48
+ defaults: bool = True,
49
+ with_meta: bool | None = None,
50
+ **kwargs: Any,
51
+ ) -> Namespace:
52
+ """Pre-processes string for environment variable substitution before parsing."""
53
+ # Substitute environment variables in the config string before parsing
54
+ substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
55
+
56
+ # Call the parent method with the substituted config
57
+ return super().parse_string(
58
+ substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
59
+ )
@@ -34,7 +34,7 @@ from rslearn.utils.geometry import (
34
34
  FloatBounds,
35
35
  STGeometry,
36
36
  flatten_shape,
37
- split_shape_at_prime_meridian,
37
+ split_shape_at_antimeridian,
38
38
  )
39
39
  from rslearn.utils.grid_index import GridIndex
40
40
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
@@ -160,7 +160,7 @@ def get_sentinel2_tile_index() -> dict[str, list[FloatBounds]]:
160
160
  # issues where the tile bounds go from -180 to 180 longitude and thus match
161
161
  # with anything at the same latitude.
162
162
  union_shp = shapely.unary_union(shapes)
163
- split_shapes = flatten_shape(split_shape_at_prime_meridian(union_shp))
163
+ split_shapes = flatten_shape(split_shape_at_antimeridian(union_shp))
164
164
  bounds_list: list[FloatBounds] = []
165
165
  for shp in split_shapes:
166
166
  bounds_list.append(shp.bounds)
@@ -222,10 +222,10 @@ def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
222
222
  """
223
223
  tile_index = load_sentinel2_tile_index(cache_dir)
224
224
  wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
225
- # If the shape is a collection, it could be cutting across prime meridian.
225
+ # If the shape is a collection, it could be cutting across antimeridian.
226
226
  # So we query each component shape separately and collect the results to avoid
227
227
  # issues.
228
- # We assume the caller has already applied split_at_prime_meridian.
228
+ # We assume the caller has already applied split_at_antimeridian.
229
229
  results = set()
230
230
  for shp in flatten_shape(wgs84_geometry.shp):
231
231
  for result in tile_index.query(shp.bounds):
@@ -319,7 +319,6 @@ class Copernicus(DataSource):
319
319
  then we attempt to read the username/password from COPERNICUS_USERNAME
320
320
  and COPERNICUS_PASSWORD (this is useful since access tokens are only
321
321
  valid for an hour).
322
- password: set API username/password instead of access token.
323
322
  query_filter: filter string to include when searching for items. This will
324
323
  be appended to other name, geographic, and sensing time filters where
325
324
  applicable. For example, "Collection/Name eq 'SENTINEL-2'". See the API
@@ -368,6 +367,7 @@ class Copernicus(DataSource):
368
367
  "order_by",
369
368
  "sort_by",
370
369
  "sort_desc",
370
+ "timeout",
371
371
  ]
372
372
  for k in simple_optionals:
373
373
  if k in d:
@@ -709,6 +709,8 @@ class Sentinel2(Copernicus):
709
709
  "B12": ["B12"],
710
710
  "B8A": ["B8A"],
711
711
  "TCI": ["R", "G", "B"],
712
+ # L1C-only products.
713
+ "B10": ["B10"],
712
714
  # L2A-only products.
713
715
  "AOT": ["AOT"],
714
716
  "WVP": ["WVP"],
@@ -809,17 +811,16 @@ class Sentinel2(Copernicus):
809
811
 
810
812
  kwargs: dict[str, Any] = dict(
811
813
  assets=list(needed_assets),
814
+ product_type=Sentinel2ProductType[d["product_type"]],
812
815
  )
813
816
 
814
- if "product_type" in d:
815
- kwargs["product_type"] = Sentinel2ProductType(d["product_type"])
816
-
817
817
  simple_optionals = [
818
818
  "harmonize",
819
819
  "access_token",
820
820
  "order_by",
821
821
  "sort_by",
822
822
  "sort_desc",
823
+ "timeout",
823
824
  ]
824
825
  for k in simple_optionals:
825
826
  if k in d:
@@ -965,6 +966,7 @@ class Sentinel1(Copernicus):
965
966
  "order_by",
966
967
  "sort_by",
967
968
  "sort_desc",
969
+ "timeout",
968
970
  ]
969
971
  for k in simple_optionals:
970
972
  if k in d:
@@ -82,6 +82,8 @@ class EarthDaily(DataSource, TileStore):
82
82
  timeout: timedelta = timedelta(seconds=10),
83
83
  skip_items_missing_assets: bool = False,
84
84
  cache_dir: UPath | None = None,
85
+ max_retries: int = 3,
86
+ retry_backoff_factor: float = 5.0,
85
87
  service_name: Literal["platform"] = "platform",
86
88
  ):
87
89
  """Initialize a new EarthDaily instance.
@@ -99,6 +101,11 @@ class EarthDaily(DataSource, TileStore):
99
101
  cache_dir: optional directory to cache items by name, including asset URLs.
100
102
  If not set, there will be no cache and instead STAC requests will be
101
103
  needed each time.
104
+ max_retries: the maximum number of retry attempts for HTTP requests that fail
105
+ due to transient errors (e.g., 429, 500, 502, 503, 504 status codes).
106
+ retry_backoff_factor: backoff factor for exponential retry delays between HTTP
107
+ request attempts. The delay between retries is calculated using the formula:
108
+ `(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
102
109
  service_name: the service name, only "platform" is supported, the other
103
110
  services "legacy" and "internal" are not supported.
104
111
  """
@@ -110,6 +117,8 @@ class EarthDaily(DataSource, TileStore):
110
117
  self.timeout = timeout
111
118
  self.skip_items_missing_assets = skip_items_missing_assets
112
119
  self.cache_dir = cache_dir
120
+ self.max_retries = max_retries
121
+ self.retry_backoff_factor = retry_backoff_factor
113
122
  self.service_name = service_name
114
123
 
115
124
  if cache_dir is not None:
@@ -139,6 +148,12 @@ class EarthDaily(DataSource, TileStore):
139
148
  if "cache_dir" in d:
140
149
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
141
150
 
151
+ if "max_retries" in d:
152
+ kwargs["max_retries"] = d["max_retries"]
153
+
154
+ if "retry_backoff_factor" in d:
155
+ kwargs["retry_backoff_factor"] = d["retry_backoff_factor"]
156
+
142
157
  simple_optionals = ["query", "sort_by", "sort_ascending"]
143
158
  for k in simple_optionals:
144
159
  if k in d:
@@ -159,7 +174,12 @@ class EarthDaily(DataSource, TileStore):
159
174
  if self.eds_client is not None:
160
175
  return self.eds_client, self.client, self.collection
161
176
 
162
- self.eds_client = EDSClient(EDSConfig())
177
+ self.eds_client = EDSClient(
178
+ EDSConfig(
179
+ max_retries=self.max_retries,
180
+ retry_backoff_factor=self.retry_backoff_factor,
181
+ )
182
+ )
163
183
 
164
184
  if self.service_name == "platform":
165
185
  self.client = self.eds_client.platform.pystac_client
@@ -0,0 +1,246 @@
1
+ """Data source for vector EuroCrops crop type data."""
2
+
3
+ import glob
4
+ import os
5
+ import tempfile
6
+ import zipfile
7
+ from datetime import UTC, datetime, timedelta
8
+ from typing import Any
9
+
10
+ import fiona
11
+ import requests
12
+ from rasterio.crs import CRS
13
+ from upath import UPath
14
+
15
+ from rslearn.config import QueryConfig, VectorLayerConfig
16
+ from rslearn.const import WGS84_PROJECTION
17
+ from rslearn.data_sources import DataSource, Item
18
+ from rslearn.data_sources.utils import match_candidate_items_to_window
19
+ from rslearn.log_utils import get_logger
20
+ from rslearn.tile_stores import TileStoreWithLayer
21
+ from rslearn.utils.feature import Feature
22
+ from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
23
+
24
+ logger = get_logger(__name__)
25
+
26
+
27
+ class EuroCropsItem(Item):
28
+ """An item in the EuroCrops data source.
29
+
30
+ For simplicity, we have just one item per year, so each item combines all of the
31
+ country-level files for that year.
32
+ """
33
+
34
+ def __init__(self, name: str, geometry: STGeometry, zip_fnames: list[str]):
35
+ """Creates a new EuroCropsItem.
36
+
37
+ Args:
38
+ name: unique name of the item. It is just the year that this item
39
+ corresponds to.
40
+ geometry: the spatial and temporal extent of the item
41
+ zip_fnames: the filenames of the zip files that contain country-level crop
42
+ type data for this year.
43
+ """
44
+ super().__init__(name, geometry)
45
+ self.zip_fnames = zip_fnames
46
+
47
+ def serialize(self) -> dict:
48
+ """Serializes the item to a JSON-encodable dictionary."""
49
+ d = super().serialize()
50
+ d["zip_fnames"] = self.zip_fnames
51
+ return d
52
+
53
+ @staticmethod
54
+ def deserialize(d: dict) -> "EuroCropsItem":
55
+ """Deserializes an item from a JSON-decoded dictionary."""
56
+ item = super(EuroCropsItem, EuroCropsItem).deserialize(d)
57
+ return EuroCropsItem(
58
+ name=item.name, geometry=item.geometry, zip_fnames=d["zip_fnames"]
59
+ )
60
+
61
+
62
+ class EuroCrops(DataSource[EuroCropsItem]):
63
+ """A data source for EuroCrops vector data (v11).
64
+
65
+ See https://zenodo.org/records/14094196 for details.
66
+
67
+ While the source data is split into country-level files, this data source uses one
68
+ item per year for simplicity. So each item corresponds to all of the country-level
69
+ files for that year.
70
+
71
+ Note that the RO_ny.zip file is not used.
72
+ """
73
+
74
+ BASE_URL = "https://zenodo.org/records/14094196/files/"
75
+ FILENAMES_BY_YEAR = {
76
+ 2018: [
77
+ "FR_2018.zip",
78
+ ],
79
+ 2019: [
80
+ "DK_2019.zip",
81
+ ],
82
+ 2020: [
83
+ "ES_NA_2020.zip",
84
+ "FI_2020.zip",
85
+ "HR_2020.zip",
86
+ "NL_2020.zip",
87
+ ],
88
+ 2021: [
89
+ "AT_2021.zip",
90
+ "BE_VLG_2021.zip",
91
+ "BE_WAL_2021.zip",
92
+ "EE_2021.zip",
93
+ "LT_2021.zip",
94
+ "LV_2021.zip",
95
+ "PT_2021.zip",
96
+ "SE_2021.zip",
97
+ "SI_2021.zip",
98
+ "SK_2021.zip",
99
+ ],
100
+ 2023: [
101
+ "CZ_2023.zip",
102
+ "DE_BB_2023.zip",
103
+ "DE_LS_2021.zip",
104
+ "DE_NRW_2021.zip",
105
+ "ES_2023.zip",
106
+ "IE_2023.zip",
107
+ ],
108
+ }
109
+ TIMEOUT = timedelta(seconds=10)
110
+
111
+ @staticmethod
112
+ def from_config(config: VectorLayerConfig, ds_path: UPath) -> "EuroCrops":
113
+ """Creates a new EuroCrops instance from a configuration dictionary."""
114
+ if config.data_source is None:
115
+ raise ValueError("data_source is required")
116
+ return EuroCrops()
117
+
118
+ def _get_all_items(self) -> list[EuroCropsItem]:
119
+ """Get a list of all available items in the data source."""
120
+ items: list[EuroCropsItem] = []
121
+ for year, fnames in self.FILENAMES_BY_YEAR.items():
122
+ items.append(
123
+ EuroCropsItem(
124
+ str(year),
125
+ get_global_geometry(
126
+ time_range=(
127
+ datetime(year, 1, 1, tzinfo=UTC),
128
+ datetime(year + 1, 1, 1, tzinfo=UTC),
129
+ ),
130
+ ),
131
+ fnames,
132
+ )
133
+ )
134
+ return items
135
+
136
+ def get_items(
137
+ self, geometries: list[STGeometry], query_config: QueryConfig
138
+ ) -> list[list[list[EuroCropsItem]]]:
139
+ """Get a list of items in the data source intersecting the given geometries.
140
+
141
+ Args:
142
+ geometries: the spatiotemporal geometries
143
+ query_config: the query configuration
144
+
145
+ Returns:
146
+ List of groups of items that should be retrieved for each geometry.
147
+ """
148
+ wgs84_geometries = [
149
+ geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
150
+ ]
151
+ all_items = self._get_all_items()
152
+ groups = []
153
+ for geometry in wgs84_geometries:
154
+ cur_groups = match_candidate_items_to_window(
155
+ geometry, all_items, query_config
156
+ )
157
+ groups.append(cur_groups)
158
+ return groups
159
+
160
+ def deserialize_item(self, serialized_item: Any) -> EuroCropsItem:
161
+ """Deserializes an item from JSON-decoded data."""
162
+ return EuroCropsItem.deserialize(serialized_item)
163
+
164
+ def _extract_features(self, fname: str) -> list[Feature]:
165
+ """Download the given zip file, extract shapefile, and return list of features."""
166
+ with tempfile.TemporaryDirectory() as tmp_dir:
167
+ # Download the zip file.
168
+ url = self.BASE_URL + fname
169
+ logger.debug(f"Downloading zip file from {url}")
170
+ response = requests.get(
171
+ url,
172
+ stream=True,
173
+ timeout=self.TIMEOUT.total_seconds(),
174
+ allow_redirects=False,
175
+ )
176
+ response.raise_for_status()
177
+ zip_fname = os.path.join(tmp_dir, "data.zip")
178
+ with open(zip_fname, "wb") as f:
179
+ for chunk in response.iter_content(chunk_size=8192):
180
+ f.write(chunk)
181
+
182
+ # Extract all of the files and look for shapefile filename.
183
+ logger.debug(f"Extracting zip file {fname}")
184
+ with zipfile.ZipFile(zip_fname) as zip_f:
185
+ zip_f.extractall(path=tmp_dir)
186
+
187
+ # The shapefiles or geopackage files can appear at any level in the hierarchy.
188
+ # Most zip files contain one but some contain multiple (one per region).
189
+ shp_fnames = glob.glob(
190
+ "**/*.shp", root_dir=tmp_dir, recursive=True
191
+ ) + glob.glob("**/*.gpkg", root_dir=tmp_dir, recursive=True)
192
+ if len(shp_fnames) == 0:
193
+ tmp_dir_fnames = os.listdir(tmp_dir)
194
+ raise ValueError(
195
+ f"expected {fname} to contain .shp file but none found (matches={shp_fnames}, ls={tmp_dir_fnames})"
196
+ )
197
+
198
+ # Load the features from the shapefile(s).
199
+ features = []
200
+ for shp_fname in shp_fnames:
201
+ logger.debug(f"Loading feature list from {shp_fname}")
202
+ with fiona.open(os.path.join(tmp_dir, shp_fname)) as src:
203
+ crs = CRS.from_wkt(src.crs.to_wkt())
204
+ # Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
205
+ # should be 1 projection unit/pixel.
206
+ projection = Projection(crs, 1, 1)
207
+
208
+ for feat in src:
209
+ features.append(
210
+ Feature.from_geojson(
211
+ projection,
212
+ {
213
+ "type": "Feature",
214
+ "geometry": dict(feat.geometry),
215
+ "properties": dict(feat.properties),
216
+ },
217
+ )
218
+ )
219
+
220
+ return features
221
+
222
+ def ingest(
223
+ self,
224
+ tile_store: TileStoreWithLayer,
225
+ items: list[EuroCropsItem],
226
+ geometries: list[list[STGeometry]],
227
+ ) -> None:
228
+ """Ingest items into the given tile store.
229
+
230
+ Args:
231
+ tile_store: the tile store to ingest into
232
+ items: the items to ingest
233
+ geometries: a list of geometries needed for each item
234
+ """
235
+ for item in items:
236
+ if tile_store.is_vector_ready(item.name):
237
+ continue
238
+
239
+ # Get features across all shapefiles.
240
+ features: list[Feature] = []
241
+ for fname in item.zip_fnames:
242
+ logger.debug(f"Getting features from {fname} for item {item.name}")
243
+ features.extend(self._extract_features(fname))
244
+
245
+ logger.debug(f"Writing features for {item.name} to the tile store")
246
+ tile_store.write_vector(item.name, features)
@@ -26,7 +26,7 @@ from rslearn.data_sources.utils import match_candidate_items_to_window
26
26
  from rslearn.log_utils import get_logger
27
27
  from rslearn.tile_stores import TileStoreWithLayer
28
28
  from rslearn.utils.fsspec import join_upath, open_atomic
29
- from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_prime_meridian
29
+ from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_antimeridian
30
30
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
31
31
 
32
32
  from .copernicus import get_harmonize_callback, get_sentinel2_tiles
@@ -358,7 +358,7 @@ class Sentinel2(DataSource):
358
358
  shp = shapely.box(*bounds)
359
359
  sensing_time = row["sensing_time"]
360
360
  geometry = STGeometry(WGS84_PROJECTION, shp, (sensing_time, sensing_time))
361
- geometry = split_at_prime_meridian(geometry)
361
+ geometry = split_at_antimeridian(geometry)
362
362
 
363
363
  cloud_cover = float(row["cloud_cover"])
364
364
 
@@ -511,7 +511,7 @@ class Sentinel2(DataSource):
511
511
 
512
512
  time_range = (product_xml.start_time, product_xml.start_time)
513
513
  geometry = STGeometry(WGS84_PROJECTION, product_xml.shp, time_range)
514
- geometry = split_at_prime_meridian(geometry)
514
+ geometry = split_at_antimeridian(geometry)
515
515
 
516
516
  # Sometimes the geometry is not valid.
517
517
  # We just apply make_valid on it to correct issues.
@@ -232,6 +232,17 @@ class RasterImporter(Importer):
232
232
  projection = Projection(crs, x_resolution, y_resolution)
233
233
  geometry = STGeometry(projection, shp, None)
234
234
 
235
+ if geometry.is_too_large():
236
+ geometry = get_global_geometry(time_range=None)
237
+ logger.warning(
238
+ "Global geometry detected: this geometry will be matched against all "
239
+ "windows in the rslearn dataset. When using settings like "
240
+ "max_matches=1 and space_mode=MOSAIC, this may cause windows outside "
241
+ "the geometry’s valid bounds to be materialized from the global raster "
242
+ "instead of a more appropriate source. Consider using COMPOSITE mode, "
243
+ "or increasing max_matches if this behavior is unintended."
244
+ )
245
+
235
246
  if spec.name:
236
247
  item_name = spec.name
237
248
  else:
@@ -1,4 +1,4 @@
1
- """Data source for raster data on public Cloud Storage buckets."""
1
+ """Data source for OpenStreetMap vector features."""
2
2
 
3
3
  import json
4
4
  import shutil
@@ -392,7 +392,7 @@ class OpenStreetMap(DataSource[OsmItem]):
392
392
  bounds_fname: UPath,
393
393
  categories: dict[str, Filter],
394
394
  ):
395
- """Initialize a new Sentinel2 instance.
395
+ """Initialize a new OpenStreetMap instance.
396
396
 
397
397
  Args:
398
398
  config: the configuration of this layer.
@@ -508,8 +508,6 @@ class OpenStreetMap(DataSource[OsmItem]):
508
508
  items: the items to ingest
509
509
  geometries: a list of geometries needed for each item
510
510
  """
511
- item_names = [item.name for item in items]
512
- item_names.sort()
513
511
  for cur_item, cur_geometries in zip(items, geometries):
514
512
  if tile_store.is_vector_ready(cur_item.name):
515
513
  continue
@@ -256,23 +256,7 @@ def match_candidate_items_to_window(
256
256
  if item_geom.is_global():
257
257
  item_geom = geometry
258
258
  else:
259
- # Windows are usually smaller than items.
260
- # So we first clip the item to the window bounds in the item's
261
- # projection, then re-project the item to the window's projection.
262
- buffered_window_geom = STGeometry(
263
- geometry.projection,
264
- geometry.shp.buffer(1),
265
- geometry.time_range,
266
- )
267
- window_shp_in_item_proj = buffered_window_geom.to_projection(
268
- item_geom.projection
269
- ).shp
270
- clipped_item_geom = STGeometry(
271
- item_geom.projection,
272
- item_geom.shp.intersection(window_shp_in_item_proj),
273
- item_geom.time_range,
274
- )
275
- item_geom = clipped_item_geom.to_projection(geometry.projection)
259
+ item_geom = item_geom.to_projection(geometry.projection)
276
260
  item_shps.append(item_geom.shp)
277
261
 
278
262
  if query_config.space_mode == SpaceMode.CONTAINS:
@@ -13,6 +13,7 @@ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
13
13
  from rasterio.crs import CRS
14
14
  from upath import UPath
15
15
 
16
+ from rslearn.arg_parser import RslearnArgumentParser
16
17
  from rslearn.config import LayerConfig
17
18
  from rslearn.const import WGS84_EPSG
18
19
  from rslearn.data_sources import Item, data_source_from_config
@@ -779,7 +780,7 @@ def dataset_build_index() -> None:
779
780
 
780
781
 
781
782
  class RslearnLightningCLI(LightningCLI):
782
- """LightningCLI that links data.tasks to model.tasks."""
783
+ """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
783
784
 
784
785
  def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
785
786
  """Link data.tasks to model.tasks.
@@ -787,6 +788,7 @@ class RslearnLightningCLI(LightningCLI):
787
788
  Args:
788
789
  parser: the argument parser
789
790
  """
791
+ # Link data.tasks to model.tasks
790
792
  parser.link_arguments(
791
793
  "data.init_args.task", "model.init_args.task", apply_on="instantiate"
792
794
  )
@@ -815,6 +817,12 @@ class RslearnLightningCLI(LightningCLI):
815
817
  # sampler as needed.
816
818
  c.trainer.use_distributed_sampler = False
817
819
 
820
+ # For predict, make sure that return_predictions is False.
821
+ # Otherwise all the predictions would be stored in memory which can lead to
822
+ # high memory consumption.
823
+ if subcommand == "predict":
824
+ c.return_predictions = False
825
+
818
826
 
819
827
  def model_handler() -> None:
820
828
  """Handler for any rslearn model X commands."""
@@ -825,6 +833,7 @@ def model_handler() -> None:
825
833
  subclass_mode_model=True,
826
834
  subclass_mode_data=True,
827
835
  save_config_kwargs={"overwrite": True},
836
+ parser_class=RslearnArgumentParser,
828
837
  )
829
838
 
830
839