rslearn 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rslearn/arg_parser.py ADDED
@@ -0,0 +1,59 @@
1
+ """Custom Lightning ArgumentParser with environment variable substitution support."""
2
+
3
+ import os
4
+ import re
5
+ from typing import Any
6
+
7
+ from jsonargparse import Namespace
8
+ from lightning.pytorch.cli import LightningArgumentParser
9
+
10
+
11
def substitute_env_vars_in_string(content: str) -> str:
    """Substitute environment variables in a string.

    Every ``${VAR_NAME}`` occurrence is replaced with the value of the environment
    variable ``VAR_NAME``; variables that are not set are replaced with an empty
    string. Text that is not of the form ``${<non-empty name>}`` (for example
    ``$VAR`` or ``${}``) is left unchanged. This works on raw string content before
    YAML parsing, so the values are inserted textually into the config.

    Args:
        content: The string content containing template variables.

    Returns:
        The string with environment variables substituted.
    """
    # Everything between "${" and the next "}" is taken as the variable name.
    pattern = r"\$\{([^}]+)\}"

    def replace_variable(match_obj: re.Match[str]) -> str:
        # os.getenv with a default never returns None, so the default is the only
        # fallback needed for unset variables.
        return os.getenv(match_obj.group(1), "")

    return re.sub(pattern, replace_variable, content)
31
+
32
+
33
class RslearnArgumentParser(LightningArgumentParser):
    """LightningArgumentParser that expands ``${VAR_NAME}`` environment variables.

    Before a configuration string reaches the parent parser, each ``${VAR_NAME}``
    token in it is replaced via ``substitute_env_vars_in_string`` (i.e. with
    ``os.getenv(VAR_NAME, "")``). Config files can therefore reference environment
    variables while still passing Lightning's validation.

    NOTE(review): every ``${...}`` token with a non-empty name is rewritten, even
    ones not intended as environment variables, and unset variables silently
    become empty strings -- confirm this is the desired contract.
    """

    def parse_string(
        self,
        cfg_str: str,
        cfg_path: str | os.PathLike = "",
        ext_vars: dict | None = None,
        env: bool | None = None,
        defaults: bool = True,
        with_meta: bool | None = None,
        **kwargs: Any,
    ) -> Namespace:
        """Expand environment variables in ``cfg_str``, then parse it.

        Args:
            cfg_str: The raw configuration content.
            cfg_path: Forwarded unchanged to the parent ``parse_string``.
            ext_vars: Forwarded unchanged to the parent ``parse_string``.
            env: Forwarded unchanged to the parent ``parse_string``.
            defaults: Forwarded unchanged to the parent ``parse_string``.
            with_meta: Forwarded unchanged to the parent ``parse_string``.
            **kwargs: Forwarded unchanged to the parent ``parse_string``.

        Returns:
            The Namespace the parent parser produces for the expanded content.
        """
        expanded_cfg = substitute_env_vars_in_string(cfg_str)
        parent = super()
        return parent.parse_string(
            expanded_cfg,
            cfg_path,
            ext_vars,
            env,
            defaults,
            with_meta,
            **kwargs,
        )
@@ -34,7 +34,7 @@ from rslearn.utils.geometry import (
34
34
  FloatBounds,
35
35
  STGeometry,
36
36
  flatten_shape,
37
- split_shape_at_prime_meridian,
37
+ split_shape_at_antimeridian,
38
38
  )
39
39
  from rslearn.utils.grid_index import GridIndex
40
40
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
@@ -160,7 +160,7 @@ def get_sentinel2_tile_index() -> dict[str, list[FloatBounds]]:
160
160
  # issues where the tile bounds go from -180 to 180 longitude and thus match
161
161
  # with anything at the same latitude.
162
162
  union_shp = shapely.unary_union(shapes)
163
- split_shapes = flatten_shape(split_shape_at_prime_meridian(union_shp))
163
+ split_shapes = flatten_shape(split_shape_at_antimeridian(union_shp))
164
164
  bounds_list: list[FloatBounds] = []
165
165
  for shp in split_shapes:
166
166
  bounds_list.append(shp.bounds)
@@ -222,10 +222,10 @@ def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
222
222
  """
223
223
  tile_index = load_sentinel2_tile_index(cache_dir)
224
224
  wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
225
- # If the shape is a collection, it could be cutting across prime meridian.
225
+ # If the shape is a collection, it could be cutting across antimeridian.
226
226
  # So we query each component shape separately and collect the results to avoid
227
227
  # issues.
228
- # We assume the caller has already applied split_at_prime_meridian.
228
+ # We assume the caller has already applied split_at_antimeridian.
229
229
  results = set()
230
230
  for shp in flatten_shape(wgs84_geometry.shp):
231
231
  for result in tile_index.query(shp.bounds):
@@ -82,6 +82,8 @@ class EarthDaily(DataSource, TileStore):
82
82
  timeout: timedelta = timedelta(seconds=10),
83
83
  skip_items_missing_assets: bool = False,
84
84
  cache_dir: UPath | None = None,
85
+ max_retries: int = 3,
86
+ retry_backoff_factor: float = 5.0,
85
87
  service_name: Literal["platform"] = "platform",
86
88
  ):
87
89
  """Initialize a new EarthDaily instance.
@@ -99,6 +101,11 @@ class EarthDaily(DataSource, TileStore):
99
101
  cache_dir: optional directory to cache items by name, including asset URLs.
100
102
  If not set, there will be no cache and instead STAC requests will be
101
103
  needed each time.
104
+ max_retries: the maximum number of retry attempts for HTTP requests that fail
105
+ due to transient errors (e.g., 429, 500, 502, 503, 504 status codes).
106
+ retry_backoff_factor: backoff factor for exponential retry delays between HTTP
107
+ request attempts. The delay between retries is calculated using the formula:
108
+ `(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
102
109
  service_name: the service name, only "platform" is supported, the other
103
110
  services "legacy" and "internal" are not supported.
104
111
  """
@@ -110,6 +117,8 @@ class EarthDaily(DataSource, TileStore):
110
117
  self.timeout = timeout
111
118
  self.skip_items_missing_assets = skip_items_missing_assets
112
119
  self.cache_dir = cache_dir
120
+ self.max_retries = max_retries
121
+ self.retry_backoff_factor = retry_backoff_factor
113
122
  self.service_name = service_name
114
123
 
115
124
  if cache_dir is not None:
@@ -139,6 +148,12 @@ class EarthDaily(DataSource, TileStore):
139
148
  if "cache_dir" in d:
140
149
  kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
141
150
 
151
+ if "max_retries" in d:
152
+ kwargs["max_retries"] = d["max_retries"]
153
+
154
+ if "retry_backoff_factor" in d:
155
+ kwargs["retry_backoff_factor"] = d["retry_backoff_factor"]
156
+
142
157
  simple_optionals = ["query", "sort_by", "sort_ascending"]
143
158
  for k in simple_optionals:
144
159
  if k in d:
@@ -159,7 +174,12 @@ class EarthDaily(DataSource, TileStore):
159
174
  if self.eds_client is not None:
160
175
  return self.eds_client, self.client, self.collection
161
176
 
162
- self.eds_client = EDSClient(EDSConfig())
177
+ self.eds_client = EDSClient(
178
+ EDSConfig(
179
+ max_retries=self.max_retries,
180
+ retry_backoff_factor=self.retry_backoff_factor,
181
+ )
182
+ )
163
183
 
164
184
  if self.service_name == "platform":
165
185
  self.client = self.eds_client.platform.pystac_client
@@ -26,7 +26,7 @@ from rslearn.data_sources.utils import match_candidate_items_to_window
26
26
  from rslearn.log_utils import get_logger
27
27
  from rslearn.tile_stores import TileStoreWithLayer
28
28
  from rslearn.utils.fsspec import join_upath, open_atomic
29
- from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_prime_meridian
29
+ from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_antimeridian
30
30
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
31
31
 
32
32
  from .copernicus import get_harmonize_callback, get_sentinel2_tiles
@@ -358,7 +358,7 @@ class Sentinel2(DataSource):
358
358
  shp = shapely.box(*bounds)
359
359
  sensing_time = row["sensing_time"]
360
360
  geometry = STGeometry(WGS84_PROJECTION, shp, (sensing_time, sensing_time))
361
- geometry = split_at_prime_meridian(geometry)
361
+ geometry = split_at_antimeridian(geometry)
362
362
 
363
363
  cloud_cover = float(row["cloud_cover"])
364
364
 
@@ -511,7 +511,7 @@ class Sentinel2(DataSource):
511
511
 
512
512
  time_range = (product_xml.start_time, product_xml.start_time)
513
513
  geometry = STGeometry(WGS84_PROJECTION, product_xml.shp, time_range)
514
- geometry = split_at_prime_meridian(geometry)
514
+ geometry = split_at_antimeridian(geometry)
515
515
 
516
516
  # Sometimes the geometry is not valid.
517
517
  # We just apply make_valid on it to correct issues.
@@ -256,23 +256,7 @@ def match_candidate_items_to_window(
256
256
  if item_geom.is_global():
257
257
  item_geom = geometry
258
258
  else:
259
- # Windows are usually smaller than items.
260
- # So we first clip the item to the window bounds in the item's
261
- # projection, then re-project the item to the window's projection.
262
- buffered_window_geom = STGeometry(
263
- geometry.projection,
264
- geometry.shp.buffer(1),
265
- geometry.time_range,
266
- )
267
- window_shp_in_item_proj = buffered_window_geom.to_projection(
268
- item_geom.projection
269
- ).shp
270
- clipped_item_geom = STGeometry(
271
- item_geom.projection,
272
- item_geom.shp.intersection(window_shp_in_item_proj),
273
- item_geom.time_range,
274
- )
275
- item_geom = clipped_item_geom.to_projection(geometry.projection)
259
+ item_geom = item_geom.to_projection(geometry.projection)
276
260
  item_shps.append(item_geom.shp)
277
261
 
278
262
  if query_config.space_mode == SpaceMode.CONTAINS:
rslearn/main.py CHANGED
@@ -13,6 +13,7 @@ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
13
13
  from rasterio.crs import CRS
14
14
  from upath import UPath
15
15
 
16
+ from rslearn.arg_parser import RslearnArgumentParser
16
17
  from rslearn.config import LayerConfig
17
18
  from rslearn.const import WGS84_EPSG
18
19
  from rslearn.data_sources import Item, data_source_from_config
@@ -779,7 +780,7 @@ def dataset_build_index() -> None:
779
780
 
780
781
 
781
782
  class RslearnLightningCLI(LightningCLI):
782
- """LightningCLI that links data.tasks to model.tasks."""
783
+ """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
783
784
 
784
785
  def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
785
786
  """Link data.tasks to model.tasks.
@@ -787,6 +788,7 @@ class RslearnLightningCLI(LightningCLI):
787
788
  Args:
788
789
  parser: the argument parser
789
790
  """
791
+ # Link data.tasks to model.tasks
790
792
  parser.link_arguments(
791
793
  "data.init_args.task", "model.init_args.task", apply_on="instantiate"
792
794
  )
@@ -815,6 +817,12 @@ class RslearnLightningCLI(LightningCLI):
815
817
  # sampler as needed.
816
818
  c.trainer.use_distributed_sampler = False
817
819
 
820
+ # For predict, make sure that return_predictions is False.
821
+ # Otherwise all the predictions would be stored in memory which can lead to
822
+ # high memory consumption.
823
+ if subcommand == "predict":
824
+ c.return_predictions = False
825
+
818
826
 
819
827
  def model_handler() -> None:
820
828
  """Handler for any rslearn model X commands."""
@@ -825,6 +833,7 @@ def model_handler() -> None:
825
833
  subclass_mode_model=True,
826
834
  subclass_mode_data=True,
827
835
  save_config_kwargs={"overwrite": True},
836
+ parser_class=RslearnArgumentParser,
828
837
  )
829
838
 
830
839
 
rslearn/models/trunk.py CHANGED
@@ -6,7 +6,6 @@ from typing import Any
6
6
  import torch
7
7
 
8
8
  from rslearn.log_utils import get_logger
9
- from rslearn.models.moe.soft import SoftMoE
10
9
  from rslearn.models.task_embedding import BaseTaskEmbedding
11
10
 
12
11
  logger = get_logger(__name__)
@@ -135,146 +134,3 @@ class DecoderTrunk(torch.nn.Module):
135
134
  """
136
135
  for layer in self.layers:
137
136
  layer.apply_auxiliary_losses(trunk_out, outs)
138
-
139
-
140
- class MoETransformer(DecoderTrunkLayer):
141
- """Transformer for decoder trunk."""
142
-
143
- def __init__(
144
- self,
145
- dim: int,
146
- n_layers: int,
147
- n_heads: int,
148
- mlp_dim: int = 512,
149
- dropout: float = 0.1,
150
- task_moe: bool = False,
151
- disable_moe: bool = False,
152
- num_experts: int = 16,
153
- num_slots: int = 256,
154
- expert_mult: int = 4,
155
- load_balance_loss_weight: float = 0.0,
156
- ):
157
- """Standard ViT-style transformer, with soft MoE.
158
-
159
- Since the point of the MoE layers is to deal with task-specific and task-shared
160
- features (and not to route specific tokens), it's probably best to use max_seq_len
161
- as the number of slots, and have at least one expert per task (probably more).
162
-
163
- Args:
164
- dim: dimension of the input and output
165
- n_layers: number of transformer blocks
166
- n_heads: number of attention heads
167
- mlp_dim: dimension of the MLP
168
- dropout: dropout rate
169
- task_moe: if specified, compute dispatch weights given the task embedding
170
- only, and not the token
171
- disable_moe: if True, disable MoE
172
- num_experts: number of experts in soft MoE
173
- num_slots: number of slots in soft MoE
174
- expert_mult: factor by which to multiply mlp_dim in the hidden layer of experts
175
- load_balance_loss_weight: weight of the load balance loss
176
- """
177
- super().__init__()
178
- self.disable_moe = disable_moe
179
- self.num_experts = num_experts
180
- self.num_slots = num_slots
181
- self.task_moe = task_moe
182
- self.load_balance_loss_weight = load_balance_loss_weight
183
- self.norm = torch.nn.LayerNorm(dim)
184
- self.layers = torch.nn.ModuleList([])
185
- for _ in range(n_layers):
186
- mha = torch.nn.MultiheadAttention(
187
- dim, n_heads, dropout=dropout, batch_first=True
188
- )
189
- if not disable_moe:
190
- ffn = SoftMoE(
191
- dim=dim,
192
- num_experts=num_experts,
193
- num_slots=num_slots,
194
- dropout=dropout,
195
- expert_mult=expert_mult,
196
- )
197
- else:
198
- ffn = torch.nn.Sequential(
199
- torch.nn.LayerNorm(dim),
200
- torch.nn.Linear(dim, mlp_dim),
201
- torch.nn.GELU(),
202
- torch.nn.Linear(mlp_dim, dim),
203
- )
204
- drop = torch.nn.Dropout(dropout)
205
- self.layers.append(torch.nn.ModuleList([mha, ffn, drop]))
206
-
207
- def forward(
208
- self, x: torch.Tensor, task_embedding: torch.Tensor | None = None
209
- ) -> dict[str, torch.Tensor]:
210
- """Forward pass.
211
-
212
- Args:
213
- x: input tensor of shape (batch_size, seq_len, dim)
214
- task_embedding: task embedding tensor of shape (batch_size, dim)
215
-
216
- Returns:
217
- dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
218
- and optionally "load_balance_loss", "dispatch_weights", and "combine_weights".
219
- """
220
- # Forward pass through the transformer
221
- infos: list[dict[str, Any]] = []
222
- for mha, ffn, drop in self.layers:
223
- x = mha(x, x, x)[0] + x
224
- if not self.disable_moe:
225
- outs = ffn(x, weight_key=task_embedding if self.task_moe else None)
226
- x_ffn = outs.pop("outputs")
227
- infos.append(outs)
228
- x = drop(x_ffn + x)
229
- else:
230
- x = drop(ffn(x) + x)
231
- x = self.norm(x)
232
- outputs = {"outputs": x}
233
-
234
- # If using MoE, collect expert weights and auxiliary losses
235
- # Don't call detach because we will use this later on in the loss collation
236
- if not self.disable_moe:
237
- collated: dict[str, list[torch.Tensor]] = {
238
- "load_balance_loss": [],
239
- "dispatch_weights": [],
240
- "combine_weights": [],
241
- }
242
- for info in infos:
243
- for k, v in info.items():
244
- if k == "dispatch_weights":
245
- # each weight is [batch, seq_len, num_experts, num_slots]
246
- # compute avg weight per token across slot/batch/expert
247
- # NOTE: this is probably about the same across all tokens,
248
- # assuming all tokens get looked at by a few experts
249
- collated["dispatch_weights"].append(v.mean((0, 2, 3)))
250
-
251
- elif k == "combine_weights":
252
- # each weight is [batch, seq_len, num_experts * num_slots]
253
- # compute avg weight per expert (slot group) across batch/seq
254
- v = v.unflatten(-1, (self.num_experts, self.num_slots))
255
- v = v.sum(-1) # [batch, seq_len, num_experts (softmax)]
256
- collated["combine_weights"].append(v.mean((0, 1)))
257
-
258
- elif k == "load_balance_loss":
259
- # each load balance loss per layer is a scalar
260
- collated["load_balance_loss"].append(v)
261
- outputs.update(collated)
262
-
263
- return outputs
264
-
265
- def apply_auxiliary_losses(
266
- self, trunk_out: dict[str, Any], outs: dict[str, Any]
267
- ) -> None:
268
- """Apply auxiliary losses in-place.
269
-
270
- Just move the load balance loss to the loss dict, where it will eventually be summed.
271
-
272
- Args:
273
- trunk_out: The output of the trunk.
274
- outs: The output of the decoders, with key "loss_dict" containing the losses.
275
- """
276
- if "load_balance_loss" in trunk_out and self.load_balance_loss_weight > 0.0:
277
- total_aux_loss = torch.stack(trunk_out["load_balance_loss"]).mean()
278
- outs["loss_dict"]["load_balance_loss"] = (
279
- self.load_balance_loss_weight * total_aux_loss
280
- )
@@ -0,0 +1,53 @@
1
+ """Callback to activate/deactivate adapter layers."""
2
+
3
+ from typing import Any
4
+
5
+ from lightning.pytorch import LightningModule
6
+ from lightning.pytorch.callbacks import Callback
7
+ from lightning.pytorch.trainer import Trainer
8
+
9
+ from rslearn.log_utils import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
class ActivateLayers(Callback):
    """Activates adapter layers on a given epoch.

    At the start of every training epoch, each module of the LightningModule whose
    qualified name (as yielded by ``named_modules()``) contains a selector's
    ``name`` substring gets its ``active`` attribute set to
    ``trainer.current_epoch >= at_epoch``. A selected module is therefore inactive
    before its ``at_epoch`` and remains active from then until the end of training.

    Modules that match no selector are not modified. If a module matches several
    selectors, the selector appearing last in the list determines its state.
    """

    def __init__(self, selectors: list[dict[str, Any]]) -> None:
        """Initialize the callback.

        Args:
            selectors: List of selectors to activate.
                Each selector is a dictionary with the following keys:
                - "name": Substring selector of modules to activate (str).
                - "at_epoch": The epoch at which to activate (int).

        Raises:
            ValueError: if a selector is missing the "name" or "at_epoch" key.
        """
        # Validate up front so a malformed config fails at construction time
        # instead of with a KeyError at the first training epoch.
        for idx, selector in enumerate(selectors):
            missing = [key for key in ("name", "at_epoch") if key not in selector]
            if missing:
                raise ValueError(
                    f"ActivateLayers selector {idx} is missing required keys {missing}"
                )
        self.selectors = selectors
        # Selector names already reported as matching no module (warn only once).
        self._warned_unmatched: set[str] = set()

    def on_train_epoch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        """Activate adapter layers on a given epoch.

        Adapter layers are activated/deactivated by setting the `active` attribute.

        Args:
            trainer: The trainer object.
            pl_module: The LightningModule object.
        """
        # Whether a selector is active depends only on the epoch, so compute it once
        # per selector. The list keeps selector order so the last match still wins.
        selector_states = [
            (selector["name"], trainer.current_epoch >= selector["at_epoch"])
            for selector in self.selectors
        ]

        status: dict[str, str] = {}
        for name, module in pl_module.named_modules():
            for substring, is_active in selector_states:
                if substring in name:
                    module.active = is_active
                    # Read the attribute back so the log reflects the module's state.
                    status[substring] = "active" if module.active else "inactive"

        # A selector that matches nothing is most likely a typo in the module name.
        for substring, _ in selector_states:
            if substring not in status and substring not in self._warned_unmatched:
                self._warned_unmatched.add(substring)
                logger.warning(
                    "ActivateLayers selector %r did not match any module", substring
                )

        logger.info("Updated adapter status: %s", status)