rslearn 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. {rslearn-0.0.3/rslearn.egg-info → rslearn-0.0.4}/PKG-INFO +3 -3
  2. {rslearn-0.0.3 → rslearn-0.0.4}/pyproject.toml +3 -7
  3. rslearn-0.0.4/rslearn/arg_parser.py +59 -0
  4. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/copernicus.py +4 -4
  5. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/earthdaily.py +21 -1
  6. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/gcp_public_data.py +3 -3
  7. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/utils.py +1 -17
  8. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/main.py +10 -1
  9. rslearn-0.0.4/rslearn/models/trunk.py +136 -0
  10. rslearn-0.0.4/rslearn/train/callbacks/adapters.py +53 -0
  11. rslearn-0.0.4/rslearn/train/callbacks/freeze_unfreeze.py +410 -0
  12. rslearn-0.0.4/rslearn/train/callbacks/gradients.py +129 -0
  13. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/data_module.py +70 -41
  14. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/dataset.py +232 -54
  15. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/lightning_module.py +4 -0
  16. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/prediction_writer.py +7 -0
  17. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/scheduler.py +15 -0
  18. rslearn-0.0.4/rslearn/train/tasks/per_pixel_regression.py +259 -0
  19. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/regression.py +6 -4
  20. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/segmentation.py +44 -14
  21. rslearn-0.0.4/rslearn/train/transforms/mask.py +69 -0
  22. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/geometry.py +8 -8
  23. {rslearn-0.0.3 → rslearn-0.0.4/rslearn.egg-info}/PKG-INFO +3 -3
  24. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/SOURCES.txt +4 -2
  25. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/requires.txt +1 -1
  26. rslearn-0.0.3/rslearn/models/moe/distributed.py +0 -262
  27. rslearn-0.0.3/rslearn/models/moe/soft.py +0 -676
  28. rslearn-0.0.3/rslearn/models/trunk.py +0 -280
  29. rslearn-0.0.3/rslearn/train/callbacks/freeze_unfreeze.py +0 -91
  30. rslearn-0.0.3/rslearn/train/callbacks/gradients.py +0 -109
  31. {rslearn-0.0.3 → rslearn-0.0.4}/LICENSE +0 -0
  32. {rslearn-0.0.3 → rslearn-0.0.4}/README.md +0 -0
  33. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/__init__.py +0 -0
  34. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/config/__init__.py +0 -0
  35. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/config/dataset.py +0 -0
  36. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/const.py +0 -0
  37. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/__init__.py +0 -0
  38. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/aws_landsat.py +0 -0
  39. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/aws_open_data.py +0 -0
  40. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/aws_sentinel1.py +0 -0
  41. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/climate_data_store.py +0 -0
  42. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/data_source.py +0 -0
  43. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/earthdata_srtm.py +0 -0
  44. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/geotiff.py +0 -0
  45. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/google_earth_engine.py +0 -0
  46. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/local_files.py +0 -0
  47. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/openstreetmap.py +0 -0
  48. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/planet.py +0 -0
  49. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/planet_basemap.py +0 -0
  50. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/planetary_computer.py +0 -0
  51. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/raster_source.py +0 -0
  52. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/usda_cdl.py +0 -0
  53. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/usgs_landsat.py +0 -0
  54. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/vector_source.py +0 -0
  55. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/worldcereal.py +0 -0
  56. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/worldcover.py +0 -0
  57. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/worldpop.py +0 -0
  58. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/xyz_tiles.py +0 -0
  59. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/__init__.py +0 -0
  60. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/add_windows.py +0 -0
  61. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/dataset.py +0 -0
  62. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/index.py +0 -0
  63. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/manage.py +0 -0
  64. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/materialize.py +0 -0
  65. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/remap.py +0 -0
  66. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/window.py +0 -0
  67. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/log_utils.py +0 -0
  68. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/__init__.py +0 -0
  69. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/clip.py +0 -0
  70. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/conv.py +0 -0
  71. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/croma.py +0 -0
  72. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/__init__.py +0 -0
  73. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/box_ops.py +0 -0
  74. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/detr.py +0 -0
  75. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/matcher.py +0 -0
  76. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/position_encoding.py +0 -0
  77. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/transformer.py +0 -0
  78. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/util.py +0 -0
  79. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/faster_rcnn.py +0 -0
  80. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/fpn.py +0 -0
  81. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/module_wrapper.py +0 -0
  82. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/molmo.py +0 -0
  83. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/multitask.py +0 -0
  84. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/pick_features.py +0 -0
  85. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/pooling_decoder.py +0 -0
  86. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/registry.py +0 -0
  87. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/sam2_enc.py +0 -0
  88. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/satlaspretrain.py +0 -0
  89. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/simple_time_series.py +0 -0
  90. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/singletask.py +0 -0
  91. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/ssl4eo_s12.py +0 -0
  92. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/swin.py +0 -0
  93. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/task_embedding.py +0 -0
  94. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/terramind.py +0 -0
  95. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/unet.py +0 -0
  96. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/upsample.py +0 -0
  97. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/use_croma.py +0 -0
  98. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/py.typed +0 -0
  99. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/tile_stores/__init__.py +0 -0
  100. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/tile_stores/default.py +0 -0
  101. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/tile_stores/tile_store.py +0 -0
  102. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/__init__.py +0 -0
  103. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/callbacks/__init__.py +0 -0
  104. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/callbacks/peft.py +0 -0
  105. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/optimizer.py +0 -0
  106. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/__init__.py +0 -0
  107. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/classification.py +0 -0
  108. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/detection.py +0 -0
  109. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/multi_task.py +0 -0
  110. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/task.py +0 -0
  111. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/__init__.py +0 -0
  112. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/concatenate.py +0 -0
  113. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/crop.py +0 -0
  114. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/flip.py +0 -0
  115. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/normalize.py +0 -0
  116. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/pad.py +0 -0
  117. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/transform.py +0 -0
  118. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/__init__.py +0 -0
  119. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/array.py +0 -0
  120. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/feature.py +0 -0
  121. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/fsspec.py +0 -0
  122. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/get_utm_ups_crs.py +0 -0
  123. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/grid_index.py +0 -0
  124. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/jsonargparse.py +0 -0
  125. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/mp.py +0 -0
  126. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/raster_format.py +0 -0
  127. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/rtree_index.py +0 -0
  128. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/spatial_index.py +0 -0
  129. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/sqlite_index.py +0 -0
  130. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/time.py +0 -0
  131. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/vector_format.py +0 -0
  132. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/dependency_links.txt +0 -0
  133. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/entry_points.txt +0 -0
  134. {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/top_level.txt +0 -0
  135. {rslearn-0.0.3 → rslearn-0.0.4}/setup.cfg +0 -0
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: A library for developing remote sensing datasets and models
5
- Author-email: Favyen Bastani <favyenb@allenai.org>, Yawen Zhang <yawenz@allenai.org>, Patrick Beukema <patrickb@allenai.org>, Henry Herzog <henryh@allenai.org>, Piper Wolters <piperw@allenai.org>
5
+ Author: OlmoEarth Team
6
6
  License: Apache License
7
7
  Version 2.0, January 2004
8
8
  http://www.apache.org/licenses/
@@ -227,7 +227,7 @@ Requires-Dist: universal_pathlib>=0.2.6
227
227
  Provides-Extra: extra
228
228
  Requires-Dist: accelerate>=1.10; extra == "extra"
229
229
  Requires-Dist: cdsapi>=0.7.6; extra == "extra"
230
- Requires-Dist: earthdaily[platform]>=1.0.0; extra == "extra"
230
+ Requires-Dist: earthdaily[platform]>=1.0.7; extra == "extra"
231
231
  Requires-Dist: earthengine-api>=1.6.3; extra == "extra"
232
232
  Requires-Dist: einops>=0.8; extra == "extra"
233
233
  Requires-Dist: gcsfs==2025.3.0; extra == "extra"
@@ -1,13 +1,9 @@
1
1
  [project]
2
2
  name = "rslearn"
3
- version = "0.0.3"
3
+ version = "0.0.4"
4
4
  description = "A library for developing remote sensing datasets and models"
5
5
  authors = [
6
- {name = "Favyen Bastani", email = "favyenb@allenai.org"},
7
- {name = "Yawen Zhang", email = "yawenz@allenai.org"},
8
- {name = "Patrick Beukema", email = "patrickb@allenai.org"},
9
- {name = "Henry Herzog", email = "henryh@allenai.org"},
10
- {name = "Piper Wolters", email = "piperw@allenai.org"},
6
+ { name = "OlmoEarth Team" },
11
7
  ]
12
8
  readme = "README.md"
13
9
  license = {file = "LICENSE"}
@@ -39,7 +35,7 @@ dependencies = [
39
35
  extra = [
40
36
  "accelerate>=1.10",
41
37
  "cdsapi>=0.7.6",
42
- "earthdaily[platform]>=1.0.0",
38
+ "earthdaily[platform]>=1.0.7",
43
39
  "earthengine-api>=1.6.3",
44
40
  "einops>=0.8",
45
41
  "gcsfs==2025.3.0",
@@ -0,0 +1,59 @@
1
+ """Custom Lightning ArgumentParser with environment variable substitution support."""
2
+
3
+ import os
4
+ import re
5
+ from typing import Any
6
+
7
+ from jsonargparse import Namespace
8
+ from lightning.pytorch.cli import LightningArgumentParser
9
+
10
+
11
def substitute_env_vars_in_string(content: str) -> str:
    """Substitute environment variables in a string.

    Replaces every ``${VAR_NAME}`` placeholder with ``os.getenv(VAR_NAME, "")``,
    so placeholders naming unset variables become the empty string. This works
    on the raw string content before YAML parsing, so placeholders anywhere in
    the text (including inside YAML comments) are substituted.

    Args:
        content: The string content containing ``${VAR_NAME}`` placeholders.

    Returns:
        The string with environment variables substituted. Text that does not
        match ``${...}`` (e.g. ``$VAR`` or ``${}``) is left unchanged.
    """
    pattern = r"\$\{([^}]+)\}"

    def replace_variable(match_obj: re.Match[str]) -> str:
        # os.getenv with a str default never returns None, so no extra None
        # check is needed. Because this is a function replacement, re.sub
        # inserts the returned value literally (no backslash processing).
        return os.getenv(match_obj.group(1), "")

    return re.sub(pattern, replace_variable, content)
31
+
32
+
33
class RslearnArgumentParser(LightningArgumentParser):
    """LightningArgumentParser that expands ``${VAR_NAME}`` placeholders.

    Before any configuration string is parsed, every ``${VAR_NAME}`` occurrence
    is replaced (via ``substitute_env_vars_in_string``) with the value of the
    corresponding environment variable. The expanded text is then handed to
    the regular Lightning/jsonargparse parsing and validation, so config files
    can reference environment variables without failing validation.
    """

    def parse_string(
        self,
        cfg_str: str,
        cfg_path: str | os.PathLike = "",
        ext_vars: dict | None = None,
        env: bool | None = None,
        defaults: bool = True,
        with_meta: bool | None = None,
        **kwargs: Any,
    ) -> Namespace:
        """Expand environment variables in ``cfg_str``, then parse it normally.

        All arguments other than ``cfg_str`` are forwarded unchanged (and in
        the same positional order) to the parent ``parse_string``.
        """
        forwarded = (cfg_path, ext_vars, env, defaults, with_meta)
        expanded_cfg = substitute_env_vars_in_string(cfg_str)
        return super().parse_string(expanded_cfg, *forwarded, **kwargs)
@@ -34,7 +34,7 @@ from rslearn.utils.geometry import (
34
34
  FloatBounds,
35
35
  STGeometry,
36
36
  flatten_shape,
37
- split_shape_at_prime_meridian,
37
+ split_shape_at_antimeridian,
38
38
  )
39
39
  from rslearn.utils.grid_index import GridIndex
40
40
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
@@ -160,7 +160,7 @@ def get_sentinel2_tile_index() -> dict[str, list[FloatBounds]]:
160
160
  # issues where the tile bounds go from -180 to 180 longitude and thus match
161
161
  # with anything at the same latitude.
162
162
  union_shp = shapely.unary_union(shapes)
163
- split_shapes = flatten_shape(split_shape_at_prime_meridian(union_shp))
163
+ split_shapes = flatten_shape(split_shape_at_antimeridian(union_shp))
164
164
  bounds_list: list[FloatBounds] = []
165
165
  for shp in split_shapes:
166
166
  bounds_list.append(shp.bounds)
@@ -222,10 +222,10 @@ def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
222
222
  """
223
223
  tile_index = load_sentinel2_tile_index(cache_dir)
224
224
  wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
225
- # If the shape is a collection, it could be cutting across prime meridian.
225
+ # If the shape is a collection, it could be cutting across antimeridian.
226
226
  # So we query each component shape separately and collect the results to avoid
227
227
  # issues.
228
- # We assume the caller has already applied split_at_prime_meridian.
228
+ # We assume the caller has already applied split_at_antimeridian.
229
229
  results = set()
230
230
  for shp in flatten_shape(wgs84_geometry.shp):
231
231
  for result in tile_index.query(shp.bounds):
@@ -82,6 +82,8 @@ class EarthDaily(DataSource, TileStore):
82
82
  timeout: timedelta = timedelta(seconds=10),
83
83
  skip_items_missing_assets: bool = False,
84
84
  cache_dir: UPath | None = None,
85
+ max_retries: int = 3,
86
+ retry_backoff_factor: float = 5.0,
85
87
  service_name: Literal["platform"] = "platform",
86
88
  ):
87
89
  """Initialize a new EarthDaily instance.
@@ -99,6 +101,11 @@ class EarthDaily(DataSource, TileStore):
99
101
  cache_dir: optional directory to cache items by name, including asset URLs.
100
102
  If not set, there will be no cache and instead STAC requests will be
101
103
  needed each time.
104
+ max_retries: the maximum number of retry attempts for HTTP requests that fail
105
+ due to transient errors (e.g., 429, 500, 502, 503, 504 status codes).
106
+ retry_backoff_factor: backoff factor for exponential retry delays between HTTP
107
+ request attempts. The delay between retries is calculated using the formula:
108
+ `(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
102
109
  service_name: the service name, only "platform" is supported, the other
103
110
  services "legacy" and "internal" are not supported.
104
111
  """
@@ -110,6 +117,8 @@ class EarthDaily(DataSource, TileStore):
110
117
  self.timeout = timeout
111
118
  self.skip_items_missing_assets = skip_items_missing_assets
112
119
  self.cache_dir = cache_dir
120
+ self.max_retries = max_retries
121
+ self.retry_backoff_factor = retry_backoff_factor
113
122
  self.service_name = service_name
114
123
 
115
124
  if cache_dir is not None:
@@ -139,6 +148,12 @@ class EarthDaily(DataSource, TileStore):
139
148
  if "cache_dir" in d:
140
149
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
141
150
 
151
+ if "max_retries" in d:
152
+ kwargs["max_retries"] = d["max_retries"]
153
+
154
+ if "retry_backoff_factor" in d:
155
+ kwargs["retry_backoff_factor"] = d["retry_backoff_factor"]
156
+
142
157
  simple_optionals = ["query", "sort_by", "sort_ascending"]
143
158
  for k in simple_optionals:
144
159
  if k in d:
@@ -159,7 +174,12 @@ class EarthDaily(DataSource, TileStore):
159
174
  if self.eds_client is not None:
160
175
  return self.eds_client, self.client, self.collection
161
176
 
162
- self.eds_client = EDSClient(EDSConfig())
177
+ self.eds_client = EDSClient(
178
+ EDSConfig(
179
+ max_retries=self.max_retries,
180
+ retry_backoff_factor=self.retry_backoff_factor,
181
+ )
182
+ )
163
183
 
164
184
  if self.service_name == "platform":
165
185
  self.client = self.eds_client.platform.pystac_client
@@ -26,7 +26,7 @@ from rslearn.data_sources.utils import match_candidate_items_to_window
26
26
  from rslearn.log_utils import get_logger
27
27
  from rslearn.tile_stores import TileStoreWithLayer
28
28
  from rslearn.utils.fsspec import join_upath, open_atomic
29
- from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_prime_meridian
29
+ from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_antimeridian
30
30
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
31
31
 
32
32
  from .copernicus import get_harmonize_callback, get_sentinel2_tiles
@@ -358,7 +358,7 @@ class Sentinel2(DataSource):
358
358
  shp = shapely.box(*bounds)
359
359
  sensing_time = row["sensing_time"]
360
360
  geometry = STGeometry(WGS84_PROJECTION, shp, (sensing_time, sensing_time))
361
- geometry = split_at_prime_meridian(geometry)
361
+ geometry = split_at_antimeridian(geometry)
362
362
 
363
363
  cloud_cover = float(row["cloud_cover"])
364
364
 
@@ -511,7 +511,7 @@ class Sentinel2(DataSource):
511
511
 
512
512
  time_range = (product_xml.start_time, product_xml.start_time)
513
513
  geometry = STGeometry(WGS84_PROJECTION, product_xml.shp, time_range)
514
- geometry = split_at_prime_meridian(geometry)
514
+ geometry = split_at_antimeridian(geometry)
515
515
 
516
516
  # Sometimes the geometry is not valid.
517
517
  # We just apply make_valid on it to correct issues.
@@ -256,23 +256,7 @@ def match_candidate_items_to_window(
256
256
  if item_geom.is_global():
257
257
  item_geom = geometry
258
258
  else:
259
- # Windows are usually smaller than items.
260
- # So we first clip the item to the window bounds in the item's
261
- # projection, then re-project the item to the window's projection.
262
- buffered_window_geom = STGeometry(
263
- geometry.projection,
264
- geometry.shp.buffer(1),
265
- geometry.time_range,
266
- )
267
- window_shp_in_item_proj = buffered_window_geom.to_projection(
268
- item_geom.projection
269
- ).shp
270
- clipped_item_geom = STGeometry(
271
- item_geom.projection,
272
- item_geom.shp.intersection(window_shp_in_item_proj),
273
- item_geom.time_range,
274
- )
275
- item_geom = clipped_item_geom.to_projection(geometry.projection)
259
+ item_geom = item_geom.to_projection(geometry.projection)
276
260
  item_shps.append(item_geom.shp)
277
261
 
278
262
  if query_config.space_mode == SpaceMode.CONTAINS:
@@ -13,6 +13,7 @@ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
13
13
  from rasterio.crs import CRS
14
14
  from upath import UPath
15
15
 
16
+ from rslearn.arg_parser import RslearnArgumentParser
16
17
  from rslearn.config import LayerConfig
17
18
  from rslearn.const import WGS84_EPSG
18
19
  from rslearn.data_sources import Item, data_source_from_config
@@ -779,7 +780,7 @@ def dataset_build_index() -> None:
779
780
 
780
781
 
781
782
  class RslearnLightningCLI(LightningCLI):
782
- """LightningCLI that links data.tasks to model.tasks."""
783
+ """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
783
784
 
784
785
  def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
785
786
  """Link data.tasks to model.tasks.
@@ -787,6 +788,7 @@ class RslearnLightningCLI(LightningCLI):
787
788
  Args:
788
789
  parser: the argument parser
789
790
  """
791
+ # Link data.tasks to model.tasks
790
792
  parser.link_arguments(
791
793
  "data.init_args.task", "model.init_args.task", apply_on="instantiate"
792
794
  )
@@ -815,6 +817,12 @@ class RslearnLightningCLI(LightningCLI):
815
817
  # sampler as needed.
816
818
  c.trainer.use_distributed_sampler = False
817
819
 
820
+ # For predict, make sure that return_predictions is False.
821
+ # Otherwise all the predictions would be stored in memory which can lead to
822
+ # high memory consumption.
823
+ if subcommand == "predict":
824
+ c.return_predictions = False
825
+
818
826
 
819
827
  def model_handler() -> None:
820
828
  """Handler for any rslearn model X commands."""
@@ -825,6 +833,7 @@ def model_handler() -> None:
825
833
  subclass_mode_model=True,
826
834
  subclass_mode_data=True,
827
835
  save_config_kwargs={"overwrite": True},
836
+ parser_class=RslearnArgumentParser,
828
837
  )
829
838
 
830
839
 
@@ -0,0 +1,136 @@
1
+ """Trunk module for decoder."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+ import torch
7
+
8
+ from rslearn.log_utils import get_logger
9
+ from rslearn.models.task_embedding import BaseTaskEmbedding
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
+ class DecoderTrunkLayer(torch.nn.Module, ABC):
15
+ """Trunk layer for decoder."""
16
+
17
+ def __init__(self) -> None:
18
+ """Initialize the DecoderTrunkLayer module."""
19
+ super().__init__()
20
+
21
+ @abstractmethod
22
+ def forward(
23
+ self, x: torch.Tensor, task_embedding: torch.Tensor | None = None
24
+ ) -> dict[str, torch.Tensor]:
25
+ """Forward pass.
26
+
27
+ Args:
28
+ x: input tensor of shape (batch_size, seq_len, dim)
29
+ task_embedding: task embedding tensor of shape (batch_size, dim), or None
30
+
31
+ Returns:
32
+ dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
33
+ and optionally other keys.
34
+ """
35
+
36
+ @abstractmethod
37
+ def apply_auxiliary_losses(
38
+ self, trunk_out: dict[str, Any], outs: dict[str, Any]
39
+ ) -> None:
40
+ """Apply auxiliary losses in-place.
41
+
42
+ Args:
43
+ trunk_out: The output of the trunk.
44
+ outs: The output of the decoders, with key "loss_dict" containing the losses.
45
+ """
46
+
47
+
48
class DecoderTrunk(torch.nn.Module):
    """Trunk module for decoder, including arbitrary layers plus an optional task embedding."""

    def __init__(
        self,
        task_embedding: BaseTaskEmbedding | None = None,
        layers: list[DecoderTrunkLayer] | None = None,
    ) -> None:
        """Initialize the DecoderTrunk module.

        Args:
            task_embedding: Task-specific embedding module, or None if not using task embedding.
            layers: List of other shared layers. The first one should expect a
                B x T x C tensor, and the last should output a B x T x C' tensor
                (T must stay equal to H * W; C' may differ from C).
                All layers must output a dict with key "outputs" (the output
                tensor) and optionally other keys.
        """
        super().__init__()
        self.layers = torch.nn.ModuleList(layers or [])
        self.task_embedding = task_embedding

        # Extra (non-"outputs") keys from every layer are merged into one dict in
        # forward, so multiple instances of the same layer class will likely
        # overwrite each other's keys.
        layer_types = [type(layer) for layer in self.layers]
        if len(set(layer_types)) != len(layer_types):
            logger.warning(
                "Multiple instances of the same layer class found in trunk. "
                "Only the keys from the last instance will be used"
            )

    def register_tasks(self, task_names: list[str]) -> None:
        """Register tasks.

        Args:
            task_names: list of task names, forwarded to the task embedding (if any).
        """
        if self.task_embedding is not None:
            self.task_embedding.register_tasks(task_names)

    def forward(
        self,
        features: list[torch.Tensor],
        inputs: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Forward pass.

        Args:
            features: The encoder features, a 1-list of B x C x H x W features
                (only required to have length 1 when trunk layers are present).
            inputs: The original inputs to the encoder.

        Returns:
            dict with key "outputs" and optionally other keys from the layers.
            Without layers, "outputs" is the (possibly task-embedded) feature
            list; otherwise it is a 1-list holding a B x C' x H x W tensor.

        Raises:
            ValueError: if trunk layers are configured and ``features`` does not
                contain exactly one feature map.
        """
        embeds = None
        if self.task_embedding is not None:
            embeds = self.task_embedding.compute_embeds(features, inputs)
            features = self.task_embedding(features, inputs, embeds=embeds)

        if not self.layers:
            return {"outputs": features}

        if len(features) != 1:
            raise ValueError(
                f"DecoderTrunk only supports one feature map, got {len(features)}"
            )
        batch_size, _, height, width = features[0].shape

        x = torch.einsum("bchw->bhwc", features[0])
        x = torch.flatten(x, start_dim=1, end_dim=2)  # B x T x C, T = HW
        out: dict[str, Any] = {}
        for layer in self.layers:
            layer_out = layer(x, task_embedding=embeds)
            x = layer_out.pop("outputs")  # unspecified shape
            out.update(layer_out)

        # Map the B x T x C' tokens back onto the spatial grid. reshape (rather
        # than view) tolerates layer outputs whose memory layout cannot be
        # viewed, and C' is taken from the trunk output instead of assuming it
        # equals the input channel count. The explicit sizes still make this
        # fail loudly if a layer changed the number of tokens.
        out_channels = x.shape[-1]
        x = torch.einsum("btc->bct", x)  # B x C' x T
        x = x.reshape(batch_size, out_channels, height, width)  # B x C' x H x W

        out["outputs"] = [x]
        return out

    def apply_auxiliary_losses(
        self, trunk_out: dict[str, Any], outs: dict[str, Any]
    ) -> None:
        """Apply auxiliary losses in-place.

        Each layer handles its own auxiliary losses, assuming the loss key is `loss_dict`.

        Args:
            trunk_out: The output of the trunk.
            outs: The output of the decoders, with key "loss_dict" containing the losses.
        """
        for layer in self.layers:
            layer.apply_auxiliary_losses(trunk_out, outs)
@@ -0,0 +1,53 @@
1
+ """Callback to activate/deactivate adapter layers."""
2
+
3
+ from typing import Any
4
+
5
+ from lightning.pytorch import LightningModule
6
+ from lightning.pytorch.callbacks import Callback
7
+ from lightning.pytorch.trainer import Trainer
8
+
9
+ from rslearn.log_utils import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
class ActivateLayers(Callback):
    """Activates adapter layers on a given epoch.

    At the start of every training epoch, each module whose qualified name
    contains a selector's "name" substring gets its ``active`` attribute set to
    ``current_epoch >= at_epoch``. Matched modules are therefore inactive before
    ``at_epoch`` and active from then on. Because the epoch only increases, an
    activated adapter layer remains active until the end of training. Modules
    that match no selector are not modified by this callback.
    """

    def __init__(self, selectors: list[dict[str, Any]]) -> None:
        """Initialize the callback.

        Args:
            selectors: List of selectors to activate.
                Each selector is a dictionary with the following keys:
                - "name": Substring selector of modules to activate (str).
                - "at_epoch": The epoch at which to activate (int).
                If several selectors match the same module, the last matching
                selector in this list determines its state.
        """
        self.selectors = selectors

    def on_train_epoch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        """Activate adapter layers on a given epoch.

        Adapter layers are activated/deactivated by setting the `active` attribute.
        A warning is logged for any selector that matches no module, since that
        usually indicates a misspelled module name in the configuration.

        Args:
            trainer: The trainer object.
            pl_module: The LightningModule object.
        """
        current_epoch = trainer.current_epoch
        status: dict[str, str] = {}
        for name, module in pl_module.named_modules():
            for selector in self.selectors:
                if selector["name"] in name:
                    module.active = current_epoch >= selector["at_epoch"]
                    status[selector["name"]] = "active" if module.active else "inactive"

        # Surface selectors that matched nothing instead of silently training
        # without the intended adapter.
        unmatched = [
            selector["name"]
            for selector in self.selectors
            if selector["name"] not in status
        ]
        if unmatched:
            logger.warning("No modules matched adapter selector(s): %s", unmatched)
        logger.info("Updated adapter status: %s", status)