rslearn 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. rslearn/arg_parser.py +1 -22
  2. rslearn/dataset/dataset.py +4 -1
  3. rslearn/models/anysat.py +207 -0
  4. rslearn/models/clay/clay.py +204 -0
  5. rslearn/models/clay/configs/metadata.yaml +295 -0
  6. rslearn/models/galileo/__init__.py +5 -0
  7. rslearn/models/galileo/galileo.py +517 -0
  8. rslearn/models/galileo/single_file_galileo.py +1672 -0
  9. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  10. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  11. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  12. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  13. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  14. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  15. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  16. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  17. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  18. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  19. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  20. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  21. rslearn/models/presto/presto.py +10 -7
  22. rslearn/models/prithvi.py +1046 -0
  23. rslearn/models/unet.py +17 -11
  24. rslearn/template_params.py +26 -0
  25. rslearn/utils/geometry.py +61 -1
  26. rslearn/utils/vector_format.py +13 -10
  27. {rslearn-0.0.5.dist-info → rslearn-0.0.7.dist-info}/METADATA +145 -15
  28. {rslearn-0.0.5.dist-info → rslearn-0.0.7.dist-info}/RECORD +32 -12
  29. {rslearn-0.0.5.dist-info → rslearn-0.0.7.dist-info}/WHEEL +0 -0
  30. {rslearn-0.0.5.dist-info → rslearn-0.0.7.dist-info}/entry_points.txt +0 -0
  31. {rslearn-0.0.5.dist-info → rslearn-0.0.7.dist-info}/licenses/LICENSE +0 -0
  32. {rslearn-0.0.5.dist-info → rslearn-0.0.7.dist-info}/top_level.txt +0 -0
rslearn/arg_parser.py CHANGED
@@ -1,33 +1,12 @@
  """Custom Lightning ArgumentParser with environment variable substitution support."""
 
  import os
- import re
  from typing import Any
 
  from jsonargparse import Namespace
  from lightning.pytorch.cli import LightningArgumentParser
 
-
- def substitute_env_vars_in_string(content: str) -> str:
-     """Substitute environment variables in a string.
-
-     Replaces ${VAR_NAME} patterns with os.getenv(VAR_NAME, "") values.
-     This works on raw string content before YAML parsing.
-
-     Args:
-         content: The string content containing template variables
-
-     Returns:
-         The string with environment variables substituted
-     """
-     pattern = r"\$\{([^}]+)\}"
-
-     def replace_variable(match_obj: re.Match[str]) -> str:
-         var_name = match_obj.group(1)
-         env_value = os.getenv(var_name, "")
-         return env_value if env_value is not None else ""
-
-     return re.sub(pattern, replace_variable, content)
+ from rslearn.template_params import substitute_env_vars_in_string
 
 
  class RslearnArgumentParser(LightningArgumentParser):
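The helper removed above now lives in the new rslearn/template_params.py module (see the files list). A minimal sketch of its behavior, assuming it keeps the semantics of the removed code, where unset variables expand to the empty string:

import os
from rslearn.template_params import substitute_env_vars_in_string

os.environ["DATASET_ROOT"] = "/data/rslearn"
# ${VAR} patterns are replaced on the raw string, before any YAML/JSON parsing.
print(substitute_env_vars_in_string("root: ${DATASET_ROOT}"))  # root: /data/rslearn
print(substitute_env_vars_in_string("token: ${UNSET_VAR}"))    # token: (empty string)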
rslearn/dataset/dataset.py CHANGED
@@ -8,6 +8,7 @@ from upath import UPath
 
  from rslearn.config import load_layer_config
  from rslearn.log_utils import get_logger
+ from rslearn.template_params import substitute_env_vars_in_string
  from rslearn.tile_stores import TileStore, load_tile_store
 
  from .index import DatasetIndex
@@ -52,7 +53,9 @@ class Dataset:
 
          # Load dataset configuration.
          with (self.path / "config.json").open("r") as f:
-             config = json.load(f)
+             config_content = f.read()
+             config_content = substitute_env_vars_in_string(config_content)
+             config = json.loads(config_content)
          self.layers = {}
          for layer_name, d in config["layers"].items():
              # Layer names must not contain period, since we use period to
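The practical effect of this hunk: a dataset's config.json may embed ${VAR} placeholders, which are expanded from the environment before JSON parsing. A hypothetical fragment mirroring the new read-substitute-parse sequence (the layer fields are illustrative, not taken from the rslearn config schema):

import json
import os
from rslearn.template_params import substitute_env_vars_in_string

os.environ["CACHE_DIR"] = "/tmp/rslearn"
# Hypothetical config.json content with a templated path.
raw = '{"layers": {"sentinel2": {"cache_dir": "${CACHE_DIR}/s2"}}}'
config = json.loads(substitute_env_vars_in_string(raw))
assert config["layers"]["sentinel2"]["cache_dir"] == "/tmp/rslearn/s2"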
rslearn/models/anysat.py ADDED
@@ -0,0 +1,207 @@
+ """AnySat model."""
+
+ from typing import Any
+
+ import torch
+ from einops import rearrange
+
+ # AnySat github: https://github.com/gastruc/AnySat
+ # Modalities and expected resolutions (meters)
+ MODALITY_RESOLUTIONS: dict[str, float] = {
+     "aerial": 0.2,
+     "aerial-flair": 0.2,
+     "spot": 1,
+     "naip": 1.25,
+     "s2": 10,
+     "s1-asc": 10,
+     "s1": 10,
+     "alos": 30,
+     "l7": 30,
+     "l8": 10,  # L8 must be upsampled to 10 m in AnySat
+     "modis": 250,
+ }
+
+ # Modalities and expected band names
+ MODALITY_BANDS: dict[str, list[str]] = {
+     "aerial": ["R", "G", "B", "NiR"],
+     "aerial-flair": ["R", "G", "B", "NiR", "Elevation"],
+     "spot": ["R", "G", "B"],
+     "naip": ["R", "G", "B", "NiR"],
+     "s2": ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8a", "B11", "B12"],
+     "s1-asc": ["VV", "VH"],
+     "s1": ["VV", "VH", "Ratio"],
+     "alos": ["HH", "HV", "Ratio"],
+     "l7": ["B1", "B2", "B3", "B4", "B5", "B7"],
+     "l8": ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"],
+     "modis": ["B1", "B2", "B3", "B4", "B5", "B6", "B7"],
+ }
+
+ # Modalities that require *_dates* input
+ TIME_SERIES_MODALITIES = {"s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"}
+
+
+ class AnySat(torch.nn.Module):
+     """AnySat backbone (outputs one feature map)."""
+
+     def __init__(
+         self,
+         modalities: list[str],
+         patch_size_meters: int,
+         dates: dict[str, list[int]],
+         output: str = "patch",
+         output_modality: str | None = None,
+         hub_repo: str = "gastruc/anysat",
+         pretrained: bool = True,
+         force_reload: bool = False,
+         flash_attn: bool = False,
+     ) -> None:
+         """Initialize an AnySat model.
+
+         Args:
+             modalities: list of modalities to use as input (1 or more).
+             patch_size_meters: patch size in meters (must be a multiple of 10).
+                 Avoid having more than 1024 patches per tile, i.e., the
+                 height/width in meters should be <= 32 * patch_size_meters.
+             dates: dict mapping time-series modalities to lists of dates
+                 (day number in a year, 0-255).
+             output: 'patch' (default) or 'dense'. Use 'patch' for classification
+                 tasks, 'dense' for segmentation tasks.
+             output_modality: required if output='dense', specifies which modality
+                 to use for the dense output (one of the input modalities).
+             hub_repo: torch.hub repository to load AnySat from.
+             pretrained: whether to load pretrained weights.
+             force_reload: whether to force re-download of the model.
+             flash_attn: whether to use flash attention (if available).
+         """
+         super().__init__()
+
+         if not modalities:
+             raise ValueError("At least one modality must be specified.")
+         for m in modalities:
+             if m not in MODALITY_RESOLUTIONS:
+                 raise ValueError(f"Invalid modality: {m}")
+
+         if not all(m in TIME_SERIES_MODALITIES for m in dates.keys()):
+             raise ValueError("`dates` keys must be time-series modalities only.")
+         for m in modalities:
+             if m in TIME_SERIES_MODALITIES and m not in dates:
+                 raise ValueError(
+                     f"Missing required dates for time-series modality '{m}'."
+                 )
+
+         if patch_size_meters % 10 != 0:
+             raise ValueError(
+                 "In AnySat, `patch_size` is in meters and must be a multiple of 10."
+             )
+
+         output = output.lower()
+         if output not in {"patch", "dense"}:
+             raise ValueError("`output` must be 'patch' or 'dense'.")
+         if output == "dense" and output_modality is None:
+             raise ValueError("`output_modality` is required when output='dense'.")
+
+         self.modalities = modalities
+         self.patch_size_meters = int(patch_size_meters)
+         self.dates = dates
+         self.output = output
+         self.output_modality = output_modality
+
+         self.model = torch.hub.load(  # nosec B614
+             hub_repo,
+             "anysat",
+             pretrained=pretrained,
+             force_reload=force_reload,
+             flash_attn=flash_attn,
+         )
+         self._embed_dim = 768  # base width, 'dense' returns 2x
+
+     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+         """Forward pass for the AnySat model.
+
+         Args:
+             inputs: input dicts that must include each modality in
+                 self.modalities as a key.
+
+         Returns:
+             List[torch.Tensor]: Single-scale feature tensors from the encoder.
+         """
+         if not inputs:
+             raise ValueError("empty inputs")
+
+         batch: dict[str, torch.Tensor] = {}
+         spatial_extent: tuple[float, float] | None = None
+
+         for modality in self.modalities:
+             if modality not in inputs[0]:
+                 raise ValueError(f"Modality '{modality}' not present in inputs.")
+
+             cur = torch.stack(
+                 [inp[modality] for inp in inputs], dim=0
+             )  # (B, C, H, W) or (B, T*C, H, W)
+
+             if modality in TIME_SERIES_MODALITIES:
+                 num_dates = len(self.dates[modality])
+                 num_bands = cur.shape[1] // num_dates
+                 cur = rearrange(
+                     cur, "b (t c) h w -> b t c h w", t=num_dates, c=num_bands
+                 )
+                 H, W = cur.shape[-2], cur.shape[-1]
+             else:
+                 num_bands = cur.shape[1]
+                 H, W = cur.shape[-2], cur.shape[-1]
+
+             if num_bands != len(MODALITY_BANDS[modality]):
+                 raise ValueError(
+                     f"Modality '{modality}' expected {len(MODALITY_BANDS[modality])} bands, "
+                     f"got {num_bands} (shape {tuple(cur.shape)})"
+                 )
+
+             batch[modality] = cur
+
+             # Ensure same spatial extent across all modalities (H*res, W*res)
+             extent = (
+                 H * MODALITY_RESOLUTIONS[modality],
+                 W * MODALITY_RESOLUTIONS[modality],
+             )
+             if spatial_extent is None:
+                 spatial_extent = extent
+             elif spatial_extent != extent:
+                 raise ValueError(
+                     "All modalities must share the same spatial extent (H*res, W*res)."
+                 )
+
+         # Add *_dates
+         to_add = {}
+         for modality, x in list(batch.items()):
+             if modality in TIME_SERIES_MODALITIES:
+                 B, T = x.shape[0], x.shape[1]
+                 d = torch.as_tensor(
+                     self.dates[modality], dtype=torch.long, device=x.device
+                 )
+                 if d.ndim != 1 or d.numel() != T:
+                     raise ValueError(
+                         f"dates for '{modality}' must be 1D length {T}, got {tuple(d.shape)}"
+                     )
+                 to_add[f"{modality}_dates"] = d.unsqueeze(0).repeat(B, 1)
+
+         batch.update(to_add)
+
+         kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
+         if self.output == "dense":
+             kwargs["output_modality"] = self.output_modality
+
+         features = self.model(batch, **kwargs)
+         return [rearrange(features, "b h w d -> b d h w")]
+
+     def get_backbone_channels(self) -> list:
+         """Returns the output channels of this model when used as a backbone.
+
+         The output channels is a list of (patch_size, depth) that corresponds
+         to the feature maps that the backbone returns.
+
+         Returns:
+             the output channels of the backbone as a list of (patch_size, depth) tuples.
+         """
+         if self.output == "patch":
+             return [(self.patch_size_meters // 10, 768)]
+         elif self.output == "dense":
+             return [(1, 1536)]
+         else:
+             raise ValueError(f"invalid output type: {self.output}")
rslearn/models/clay/clay.py ADDED
@@ -0,0 +1,204 @@
+ """Clay models."""
+
+ from __future__ import annotations
+
+ import math
+ from enum import Enum
+ from importlib.resources import files
+ from typing import Any
+
+ import torch
+ import yaml
+ from einops import rearrange
+ from huggingface_hub import hf_hub_download
+
+ # from claymodel.module import ClayMAEModule
+ from terratorch.models.backbones.clay_v15.module import ClayMAEModule
+
+ from rslearn.train.transforms.transform import Transform
+
+
+ class ClaySize(str, Enum):
+     """Size of the Clay model."""
+
+     BASE = "base"
+     LARGE = "large"
+
+
+ PATCH_SIZE = 8
+ CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
+ CONFIG_DIR = files("rslearn.models.clay.configs")
+ CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
+
+
+ def get_clay_checkpoint_path(
+     filename: str = "v1.5/clay-v1.5.ckpt",
+     repo_id: str = "made-with-clay/Clay",
+ ) -> str:
+     """Return a cached local path to the Clay ckpt from the Hugging Face Hub."""
+     return hf_hub_download(repo_id=repo_id, filename=filename)  # nosec B615
+
+
+ class Clay(torch.nn.Module):
+     """Clay backbones."""
+
+     def __init__(
+         self,
+         model_size: ClaySize,
+         modality: str = "sentinel-2-l2a",
+         checkpoint_path: str | None = None,
+         metadata_path: str = CLAY_METADATA_PATH,
+     ) -> None:
+         """Initialize the Clay model.
+
+         Args:
+             model_size: The size of the Clay model.
+             modality: The modality to use (subset of CLAY_MODALITIES).
+             checkpoint_path: Path to clay-v1.5.ckpt; if None, fetch from the HF Hub.
+             metadata_path: Path to metadata.yaml.
+         """
+         super().__init__()
+
+         # Clay only supports single-modality input
+         if modality not in CLAY_MODALITIES:
+             raise ValueError(f"Invalid modality: {modality}")
+
+         ckpt = checkpoint_path or get_clay_checkpoint_path()
+         if model_size == ClaySize.LARGE:
+             self.model = ClayMAEModule.load_from_checkpoint(
+                 checkpoint_path=ckpt,
+                 model_size="large",
+                 metadata_path=metadata_path,
+                 dolls=[16, 32, 64, 128, 256, 768, 1024],
+                 doll_weights=[1, 1, 1, 1, 1, 1, 1],
+                 mask_ratio=0.0,
+                 shuffle=False,
+             )
+         elif model_size == ClaySize.BASE:
+             # Failed to load Base model in Clay v1.5
+             raise ValueError("Clay BASE model currently not supported in v1.5.")
+             self.model = ClayMAEModule.load_from_checkpoint(
+                 checkpoint_path=ckpt,
+                 model_size="base",
+                 metadata_path=metadata_path,
+                 dolls=[16, 32, 64, 128, 256, 768],
+                 doll_weights=[1, 1, 1, 1, 1, 1],
+                 mask_ratio=0.0,
+                 shuffle=False,
+             )
+         else:
+             raise ValueError(f"Invalid model size: {model_size}")
+
+         with open(metadata_path) as f:
+             self.metadata = yaml.safe_load(f)
+
+         self.model_size = model_size
+         self.modality = modality
+
+     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+         """Forward pass for the Clay model.
+
+         Args:
+             inputs: input dicts that must include `self.modality` as a key
+
+         Returns:
+             List[torch.Tensor]: Single-scale feature tensors from the encoder.
+         """
+         if self.modality not in inputs[0]:
+             raise ValueError(f"Missing modality {self.modality} in inputs.")
+
+         param = next(self.model.parameters())
+         device = param.device
+
+         chips = torch.stack(
+             [inp[self.modality] for inp in inputs], dim=0
+         )  # (B, C, H, W)
+
+         order = self.metadata[self.modality]["band_order"]
+         wavelengths = []
+         for band in order:
+             wavelengths.append(
+                 self.metadata[self.modality]["bands"]["wavelength"][band] * 1000
+             )  # Convert to nm
+         # Check channel count matches Clay expectation
+         if chips.shape[1] != len(order):
+             raise ValueError(
+                 f"Channel count {chips.shape[1]} does not match expected {len(order)} for {self.modality}"
+             )
+
+         # Time & latlon zeros are valid per the Clay docs
+         # https://clay-foundation.github.io/model/getting-started/basic_use.html
+         datacube = {
+             "platform": self.modality,
+             "time": torch.zeros(chips.shape[0], 4).to(device),
+             "latlon": torch.zeros(chips.shape[0], 4).to(device),
+             "pixels": chips.to(device),
+             "gsd": torch.tensor(self.metadata[self.modality]["gsd"]).to(device),
+             "waves": torch.tensor(wavelengths).to(device),
+         }
+
+         tokens, *_ = self.model.model.encoder(datacube)  # (B, 1 + N, D)
+
+         # Remove CLS token
+         spatial = tokens[:, 1:, :]  # (B, N, D)
+         n_tokens = spatial.shape[1]
+         side = int(math.isqrt(n_tokens))
+         if chips.shape[2] != side * PATCH_SIZE or chips.shape[3] != side * PATCH_SIZE:
+             raise ValueError(
+                 f"Input spatial size {(chips.shape[2], chips.shape[3])} is not compatible with patch size {PATCH_SIZE}"
+             )
+
+         features = rearrange(spatial, "b (h w) d -> b d h w", h=side, w=side)
+         return [features]
+
+     def get_backbone_channels(self) -> list:
+         """Return output channels of this model when used as a backbone."""
+         if self.model_size == ClaySize.LARGE:
+             depth = 1024
+         elif self.model_size == ClaySize.BASE:
+             depth = 768
+         else:
+             raise ValueError(f"Invalid model size: {self.model_size}")
+         return [(PATCH_SIZE, depth)]
+
+
+ class ClayNormalize(Transform):
+     """Normalize inputs using Clay metadata."""
+
+     def __init__(self, metadata_path: str = CLAY_METADATA_PATH) -> None:
+         """Initialize ClayNormalize."""
+         super().__init__()
+         with open(metadata_path) as f:
+             self.metadata = yaml.safe_load(f)
+
+     def apply_image(
+         self, image: torch.Tensor, means: list[float], stds: list[float]
+     ) -> torch.Tensor:
+         """Normalize the specified image with Clay normalization."""
+         x = image.float()
+         if x.shape[0] != len(means):
+             raise ValueError(
+                 f"channel count {x.shape[0]} does not match provided band stats {len(means)}"
+             )
+         for c in range(x.shape[0]):
+             x[c] = (x[c] - means[c]) / stds[c]
+         return x
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Normalize the input images with Clay normalization."""
+         for modality in CLAY_MODALITIES:
+             if modality not in input_dict or modality not in self.metadata:
+                 continue
+             modality_metadata = self.metadata[modality]
+             means = [
+                 modality_metadata["bands"]["mean"][b]
+                 for b in modality_metadata["band_order"]
+             ]
+             stds = [
+                 modality_metadata["bands"]["std"][b]
+                 for b in modality_metadata["band_order"]
+             ]
+             input_dict[modality] = self.apply_image(input_dict[modality], means, stds)
+         return input_dict, target_dict
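A matching usage sketch for the Clay backbone; it assumes the checkpoint can be fetched from made-with-clay/Clay on the HF Hub, and that sentinel-2-l2a lists 10 bands in metadata.yaml (check band_order there; the chip side just has to be a multiple of PATCH_SIZE):

import torch
from rslearn.models.clay.clay import PATCH_SIZE, Clay, ClayNormalize, ClaySize

model = Clay(model_size=ClaySize.LARGE, modality="sentinel-2-l2a")
normalize = ClayNormalize()
# One 128x128 chip; channel count must equal len(band_order) for the modality.
input_dict = {"sentinel-2-l2a": torch.randn(10, 128, 128)}
input_dict, _ = normalize(input_dict, {})
features = model([input_dict])
print(features[0].shape)              # (1, 1024, 16, 16): 128 / PATCH_SIZE = 16
print(model.get_backbone_channels())  # [(8, 1024)]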