rslearn 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- rslearn/dataset/handler_summaries.py +130 -0
- rslearn/dataset/manage.py +157 -22
- rslearn/main.py +60 -8
- rslearn/models/anysat.py +207 -0
- rslearn/models/clay/clay.py +219 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/copernicusfm.py +37 -25
- rslearn/models/dinov3.py +165 -0
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +517 -0
- rslearn/models/galileo/single_file_galileo.py +1672 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/presto/presto.py +10 -7
- rslearn/models/prithvi.py +1122 -0
- rslearn/models/resize_features.py +45 -0
- rslearn/models/simple_time_series.py +65 -10
- rslearn/models/unet.py +17 -11
- rslearn/models/upsample.py +2 -2
- rslearn/tile_stores/default.py +31 -6
- rslearn/train/transforms/normalize.py +34 -5
- rslearn/train/transforms/select_bands.py +67 -0
- rslearn/train/transforms/sentinel1.py +60 -0
- rslearn/utils/geometry.py +61 -1
- rslearn/utils/raster_format.py +7 -1
- rslearn/utils/vector_format.py +13 -10
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/METADATA +144 -15
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/RECORD +42 -18
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/WHEEL +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.8.dist-info}/top_level.txt +0 -0
rslearn/models/anysat.py
ADDED
@@ -0,0 +1,207 @@
+"""AnySat model."""
+
+from typing import Any
+
+import torch
+from einops import rearrange
+
+# AnySat github: https://github.com/gastruc/AnySat
+# Modalities and expected resolutions (meters)
+MODALITY_RESOLUTIONS: dict[str, float] = {
+    "aerial": 0.2,
+    "aerial-flair": 0.2,
+    "spot": 1,
+    "naip": 1.25,
+    "s2": 10,
+    "s1-asc": 10,
+    "s1": 10,
+    "alos": 30,
+    "l7": 30,
+    "l8": 10,  # L8 must be upsampled to 10 m in AnySat
+    "modis": 250,
+}
+
+# Modalities and expected band names
+MODALITY_BANDS: dict[str, list[str]] = {
+    "aerial": ["R", "G", "B", "NiR"],
+    "aerial-flair": ["R", "G", "B", "NiR", "Elevation"],
+    "spot": ["R", "G", "B"],
+    "naip": ["R", "G", "B", "NiR"],
+    "s2": ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8a", "B11", "B12"],
+    "s1-asc": ["VV", "VH"],
+    "s1": ["VV", "VH", "Ratio"],
+    "alos": ["HH", "HV", "Ratio"],
+    "l7": ["B1", "B2", "B3", "B4", "B5", "B7"],
+    "l8": ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"],
+    "modis": ["B1", "B2", "B3", "B4", "B5", "B6", "B7"],
+}
+
+# Modalities that require *_dates* input
+TIME_SERIES_MODALITIES = {"s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"}
+
+
+class AnySat(torch.nn.Module):
+    """AnySat backbone (outputs one feature map)."""
+
+    def __init__(
+        self,
+        modalities: list[str],
+        patch_size_meters: int,
+        dates: dict[str, list[int]],
+        output: str = "patch",
+        output_modality: str | None = None,
+        hub_repo: str = "gastruc/anysat",
+        pretrained: bool = True,
+        force_reload: bool = False,
+        flash_attn: bool = False,
+    ) -> None:
+        """Initialize an AnySat model.
+
+        Args:
+            modalities: list of modalities to use as input (1 or more).
+            patch_size_meters: patch size in meters (must be a multiple of 10).
+                Avoid having more than 1024 patches per tile, i.e., the
+                height/width in meters should be <= 32 * patch_size_meters.
+            dates: dict mapping time-series modalities to list of dates (day
+                number in a year, 0-255).
+            output: 'patch' (default) or 'dense'. Use 'patch' for classification
+                tasks, 'dense' for segmentation tasks.
+            output_modality: required if output='dense', specifies which modality
+                to use for the dense output (one of the input modalities).
+            hub_repo: torch.hub repository to load AnySat from.
+            pretrained: whether to load pretrained weights.
+            force_reload: whether to force re-download of the model.
+            flash_attn: whether to use flash attention (if available).
+        """
+        super().__init__()
+
+        if not modalities:
+            raise ValueError("At least one modality must be specified.")
+        for m in modalities:
+            if m not in MODALITY_RESOLUTIONS:
+                raise ValueError(f"Invalid modality: {m}")
+
+        if not all(m in TIME_SERIES_MODALITIES for m in dates.keys()):
+            raise ValueError("`dates` keys must be time-series modalities only.")
+        for m in modalities:
+            if m in TIME_SERIES_MODALITIES and m not in dates:
+                raise ValueError(
+                    f"Missing required dates for time-series modality '{m}'."
+                )
+
+        if patch_size_meters % 10 != 0:
+            raise ValueError(
+                "In AnySat, `patch_size` is in meters and must be a multiple of 10."
+            )
+
+        output = output.lower()
+        if output not in {"patch", "dense"}:
+            raise ValueError("`output` must be 'patch' or 'dense'.")
+        if output == "dense" and output_modality is None:
+            raise ValueError("`output_modality` is required when output='dense'.")
+
+        self.modalities = modalities
+        self.patch_size_meters = int(patch_size_meters)
+        self.dates = dates
+        self.output = output
+        self.output_modality = output_modality
+
+        self.model = torch.hub.load(  # nosec B614
+            hub_repo,
+            "anysat",
+            pretrained=pretrained,
+            force_reload=force_reload,
+            flash_attn=flash_attn,
+        )
+        self._embed_dim = 768  # base width, 'dense' returns 2x
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Forward pass for the AnySat model.
+
+        Args:
+            inputs: input dicts that must include each modality in
+                self.modalities as a key.
+
+        Returns:
+            List[torch.Tensor]: Single-scale feature tensors from the encoder.
+        """
+        if not inputs:
+            raise ValueError("empty inputs")
+
+        batch: dict[str, torch.Tensor] = {}
+        spatial_extent: tuple[float, float] | None = None
+
+        for modality in self.modalities:
+            if modality not in inputs[0]:
+                raise ValueError(f"Modality '{modality}' not present in inputs.")
+
+            cur = torch.stack(
+                [inp[modality] for inp in inputs], dim=0
+            )  # (B, C, H, W) or (B, T*C, H, W)
+
+            if modality in TIME_SERIES_MODALITIES:
+                num_dates = len(self.dates[modality])
+                num_bands = cur.shape[1] // num_dates
+                cur = rearrange(
+                    cur, "b (t c) h w -> b t c h w", t=num_dates, c=num_bands
+                )
+                H, W = cur.shape[-2], cur.shape[-1]
+            else:
+                num_bands = cur.shape[1]
+                H, W = cur.shape[-2], cur.shape[-1]
+
+            if num_bands != len(MODALITY_BANDS[modality]):
+                raise ValueError(
+                    f"Modality '{modality}' expected {len(MODALITY_BANDS[modality])} bands, "
+                    f"got {num_bands} (shape {tuple(cur.shape)})"
+                )
+
+            batch[modality] = cur
+
+            # Ensure same spatial extent across all modalities (H*res, W*res)
+            extent = (
+                H * MODALITY_RESOLUTIONS[modality],
+                W * MODALITY_RESOLUTIONS[modality],
+            )
+            if spatial_extent is None:
+                spatial_extent = extent
+            elif spatial_extent != extent:
+                raise ValueError(
+                    "All modalities must share the same spatial extent (H*res, W*res)."
+                )
+
+        # Add *_dates
+        to_add = {}
+        for modality, x in list(batch.items()):
+            if modality in TIME_SERIES_MODALITIES:
+                B, T = x.shape[0], x.shape[1]
+                d = torch.as_tensor(
+                    self.dates[modality], dtype=torch.long, device=x.device
+                )
+                if d.ndim != 1 or d.numel() != T:
+                    raise ValueError(
+                        f"dates for '{modality}' must be 1D length {T}, got {tuple(d.shape)}"
+                    )
+                to_add[f"{modality}_dates"] = d.unsqueeze(0).repeat(B, 1)
+
+        batch.update(to_add)
+
+        kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
+        if self.output == "dense":
+            kwargs["output_modality"] = self.output_modality
+
+        features = self.model(batch, **kwargs)
+        return [rearrange(features, "b h w d -> b d h w")]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels are a list of (patch_size, depth) tuples that
+        correspond to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        if self.output == "patch":
+            return [(self.patch_size_meters // 10, 768)]
+        elif self.output == "dense":
+            return [(1, 1536)]
+        else:
+            raise ValueError(f"invalid output type: {self.output}")
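Below is a minimal usage sketch for the new AnySat wrapper (an editor illustration, not part of the diff). The list-of-dicts input convention and the stacked (T*C, H, W) channel layout for time-series modalities follow the forward docstring above; the output shape is inferred from the patch-size logic.

    import torch
    from rslearn.models.anysat import AnySat

    # Sentinel-2 time series: 2 dates x 10 bands stacked on the channel axis.
    model = AnySat(
        modalities=["s2"],
        patch_size_meters=20,  # must be a multiple of 10
        dates={"s2": [15, 45]},  # one day-of-year value per timestep
    )
    # One sample: 24 px at 10 m/px = 240 m extent, under 32 * patch_size_meters.
    inputs = [{"s2": torch.randn(2 * 10, 24, 24)}]
    [features] = model(inputs)  # expected (1, 768, 12, 12): one token per 20 m patch
    print(model.get_backbone_channels())  # [(2, 768)]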
rslearn/models/clay/clay.py
ADDED
@@ -0,0 +1,219 @@
+"""Clay models."""
+
+from __future__ import annotations
+
+import math
+from enum import Enum
+from importlib.resources import files
+from typing import Any
+
+import torch
+import yaml
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+
+# from claymodel.module import ClayMAEModule
+from terratorch.models.backbones.clay_v15.module import ClayMAEModule
+
+from rslearn.train.transforms.normalize import Normalize
+from rslearn.train.transforms.transform import Transform
+
+
+class ClaySize(str, Enum):
+    """Size of the Clay model."""
+
+    BASE = "base"
+    LARGE = "large"
+
+
+PATCH_SIZE = 8
+CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
+CONFIG_DIR = files("rslearn.models.clay.configs")
+CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
+
+
+def get_clay_checkpoint_path(
+    filename: str = "v1.5/clay-v1.5.ckpt",
+    repo_id: str = "made-with-clay/Clay",
+) -> str:
+    """Return a cached local path to the Clay ckpt from the Hugging Face Hub."""
+    return hf_hub_download(repo_id=repo_id, filename=filename)  # nosec B615
+
+
+class Clay(torch.nn.Module):
+    """Clay backbones."""
+
+    def __init__(
+        self,
+        model_size: ClaySize,
+        modality: str = "sentinel-2-l2a",
+        checkpoint_path: str | None = None,
+        metadata_path: str = CLAY_METADATA_PATH,
+    ) -> None:
+        """Initialize the Clay model.
+
+        Args:
+            model_size: The size of the Clay model.
+            modality: The modality to use (subset of CLAY_MODALITIES).
+            checkpoint_path: Path to clay-v1.5.ckpt; if None, fetch from HF Hub.
+            metadata_path: Path to metadata.yaml.
+        """
+        super().__init__()
+
+        # Clay only supports single modality input
+        if modality not in CLAY_MODALITIES:
+            raise ValueError(f"Invalid modality: {modality}")
+
+        ckpt = checkpoint_path or get_clay_checkpoint_path()
+        if model_size == ClaySize.LARGE:
+            self.model = ClayMAEModule.load_from_checkpoint(
+                checkpoint_path=ckpt,
+                model_size="large",
+                metadata_path=metadata_path,
+                dolls=[16, 32, 64, 128, 256, 768, 1024],
+                doll_weights=[1, 1, 1, 1, 1, 1, 1],
+                mask_ratio=0.0,
+                shuffle=False,
+            )
+        elif model_size == ClaySize.BASE:
+            # Failed to load Base model in Clay v1.5
+            raise ValueError("Clay BASE model currently not supported in v1.5.")
+            self.model = ClayMAEModule.load_from_checkpoint(
+                checkpoint_path=ckpt,
+                model_size="base",
+                metadata_path=metadata_path,
+                dolls=[16, 32, 64, 128, 256, 768],
+                doll_weights=[1, 1, 1, 1, 1, 1],
+                mask_ratio=0.0,
+                shuffle=False,
+            )
+        else:
+            raise ValueError(f"Invalid model size: {model_size}")
+
+        with open(metadata_path) as f:
+            self.metadata = yaml.safe_load(f)
+
+        self.model_size = model_size
+        self.modality = modality
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Forward pass for the Clay model.
+
+        Args:
+            inputs: input dicts that must include `self.modality` as a key.
+
+        Returns:
+            List[torch.Tensor]: Single-scale feature tensors from the encoder.
+        """
+        if self.modality not in inputs[0]:
+            raise ValueError(f"Missing modality {self.modality} in inputs.")
+
+        param = next(self.model.parameters())
+        device = param.device
+
+        chips = torch.stack(
+            [inp[self.modality] for inp in inputs], dim=0
+        )  # (B, C, H, W)
+
+        order = self.metadata[self.modality]["band_order"]
+        wavelengths = []
+        for band in self.metadata[self.modality]["band_order"]:
+            wavelengths.append(
+                self.metadata[self.modality]["bands"]["wavelength"][band] * 1000
+            )  # Convert to nm
+        # Check channel count matches Clay expectation
+        if chips.shape[1] != len(order):
+            raise ValueError(
+                f"Channel count {chips.shape[1]} does not match expected {len(order)} for {self.modality}"
+            )
+
+        # Time & latlon zeros are valid per Clay doc
+        # https://clay-foundation.github.io/model/getting-started/basic_use.html
+        datacube = {
+            "platform": self.modality,
+            "time": torch.zeros(chips.shape[0], 4).to(device),
+            "latlon": torch.zeros(chips.shape[0], 4).to(device),
+            "pixels": chips.to(device),
+            "gsd": torch.tensor(self.metadata[self.modality]["gsd"]).to(device),
+            "waves": torch.tensor(wavelengths).to(device),
+        }
+
+        tokens, *_ = self.model.model.encoder(datacube)  # (B, 1 + N, D)
+
+        # Remove CLS token
+        spatial = tokens[:, 1:, :]  # (B, N, D)
+        n_tokens = spatial.shape[1]
+        side = int(math.isqrt(n_tokens))
+        if chips.shape[2] != side * PATCH_SIZE or chips.shape[3] != side * PATCH_SIZE:
+            raise ValueError(
+                f"Input spatial size {(chips.shape[2], chips.shape[3])} is not compatible with patch size {PATCH_SIZE}"
+            )
+
+        features = rearrange(spatial, "b (h w) d -> b d h w", h=side, w=side)
+        return [features]
+
+    def get_backbone_channels(self) -> list:
+        """Return output channels of this model when used as a backbone."""
+        if self.model_size == ClaySize.LARGE:
+            depth = 1024
+        elif self.model_size == ClaySize.BASE:
+            depth = 768
+        else:
+            raise ValueError(f"Invalid model size: {self.model_size}")
+        return [(PATCH_SIZE, depth)]
+
+
+class ClayNormalize(Transform):
+    """Normalize inputs using Clay metadata.
+
+    For Sentinel-1, the intensities should be converted to decibels.
+    """
+
+    def __init__(self, metadata_path: str = CLAY_METADATA_PATH) -> None:
+        """Initialize ClayNormalize."""
+        super().__init__()
+        with open(metadata_path) as f:
+            metadata = yaml.safe_load(f)
+        normalizers = {}
+        for modality in CLAY_MODALITIES:
+            if modality not in metadata:
+                continue
+            modality_metadata = metadata[modality]
+            means = [
+                modality_metadata["bands"]["mean"][b]
+                for b in modality_metadata["band_order"]
+            ]
+            stds = [
+                modality_metadata["bands"]["std"][b]
+                for b in modality_metadata["band_order"]
+            ]
+            normalizers[modality] = Normalize(
+                mean=means,
+                std=stds,
+                selectors=[modality],
+                num_bands=len(means),
+            )
+        self.normalizers = torch.nn.ModuleDict(normalizers)
+
+    def apply_image(
+        self, image: torch.Tensor, means: list[float], stds: list[float]
+    ) -> torch.Tensor:
+        """Normalize the specified image with Clay normalization."""
+        x = image.float()
+        if x.shape[0] != len(means):
+            raise ValueError(
+                f"channel count {x.shape[0]} does not match provided band stats {len(means)}"
+            )
+        for c in range(x.shape[0]):
+            x[c] = (x[c] - means[c]) / stds[c]
+        return x
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply per-modality Clay normalization to the input dict."""
+        for modality, normalizer in self.normalizers.items():
+            if modality not in input_dict:
+                continue
+            input_dict, target_dict = normalizer(input_dict, target_dict)
+        return input_dict, target_dict
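A corresponding sketch for the Clay backbone (again an editor illustration, not part of the diff). The 10-band input matches the sentinel-2-l2a band_order in the metadata.yaml added below, and the 64 px sides are a multiple of PATCH_SIZE = 8 as the forward pass requires; instantiating the model downloads clay-v1.5.ckpt from the Hugging Face Hub unless checkpoint_path is given.

    import torch
    from rslearn.models.clay.clay import Clay, ClaySize

    model = Clay(model_size=ClaySize.LARGE, modality="sentinel-2-l2a")
    inputs = [{"sentinel-2-l2a": torch.randn(10, 64, 64)}]
    [features] = model(inputs)  # expected (1, 1024, 8, 8): an 8x8 grid of 8-px patches
    print(model.get_backbone_channels())  # [(8, 1024)]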
rslearn/models/clay/configs/metadata.yaml
ADDED
@@ -0,0 +1,295 @@
+sentinel-2-l2a:
+  band_order:
+    - blue
+    - green
+    - red
+    - rededge1
+    - rededge2
+    - rededge3
+    - nir
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 2
+    - 1
+    - 0
+  gsd: 10
+  bands:
+    mean:
+      blue: 1105.
+      green: 1355.
+      red: 1552.
+      rededge1: 1887.
+      rededge2: 2422.
+      rededge3: 2630.
+      nir: 2743.
+      nir08: 2785.
+      swir16: 2388.
+      swir22: 1835.
+    std:
+      blue: 1809.
+      green: 1757.
+      red: 1888.
+      rededge1: 1870.
+      rededge2: 1732.
+      rededge3: 1697.
+      nir: 1742.
+      nir08: 1648.
+      swir16: 1470.
+      swir22: 1379.
+    wavelength:
+      blue: 0.493
+      green: 0.56
+      red: 0.665
+      rededge1: 0.704
+      rededge2: 0.74
+      rededge3: 0.783
+      nir: 0.842
+      nir08: 0.865
+      swir16: 1.61
+      swir22: 2.19
+planetscope-sr:
+  band_order:
+    - coastal_blue
+    - blue
+    - green_i
+    - green
+    - yellow
+    - red
+    - rededge
+    - nir
+  rgb_indices:
+    - 5
+    - 3
+    - 1
+  gsd: 5
+  bands:
+    mean:
+      coastal_blue: 1720.
+      blue: 1715.
+      green_i: 1913.
+      green: 2088.
+      yellow: 2274.
+      red: 2290.
+      rededge: 2613.
+      nir: 3970.
+    std:
+      coastal_blue: 747.
+      blue: 698.
+      green_i: 739.
+      green: 768.
+      yellow: 849.
+      red: 868.
+      rededge: 849.
+      nir: 914.
+    wavelength:
+      coastal_blue: 0.443
+      blue: 0.490
+      green_i: 0.531
+      green: 0.565
+      yellow: 0.610
+      red: 0.665
+      rededge: 0.705
+      nir: 0.865
+landsat-c2l1:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 30
+  bands:
+    mean:
+      red: 10678.
+      green: 10563.
+      blue: 11083.
+      nir08: 14792.
+      swir16: 12276.
+      swir22: 10114.
+    std:
+      red: 6025.
+      green: 5411.
+      blue: 5468.
+      nir08: 6746.
+      swir16: 5897.
+      swir22: 4850.
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir08: 0.86
+      swir16: 1.6
+      swir22: 2.2
+landsat-c2l2-sr:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 30
+  bands:
+    mean:
+      red: 13705.
+      green: 13310.
+      blue: 12474.
+      nir08: 17801.
+      swir16: 14615.
+      swir22: 12701.
+    std:
+      red: 9578.
+      green: 9408.
+      blue: 10144.
+      nir08: 8277.
+      swir16: 5300.
+      swir22: 4522.
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir08: 0.86
+      swir16: 1.6
+      swir22: 2.2
+naip:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 1.0
+  bands:
+    mean:
+      red: 110.16
+      green: 115.41
+      blue: 98.15
+      nir: 139.04
+    std:
+      red: 47.23
+      green: 39.82
+      blue: 35.43
+      nir: 49.86
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir: 0.842
+linz:
+  band_order:
+    - red
+    - green
+    - blue
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 0.5
+  bands:
+    mean:
+      red: 89.96
+      green: 99.46
+      blue: 89.51
+    std:
+      red: 41.83
+      green: 36.96
+      blue: 31.45
+    wavelength:
+      red: 0.635
+      green: 0.555
+      blue: 0.465
+sentinel-1-rtc:
+  band_order:
+    - vv
+    - vh
+  gsd: 10
+  bands:
+    mean:
+      vv: -12.113
+      vh: -18.673
+    std:
+      vv: 8.314
+      vh: 8.017
+    wavelength:
+      vv: 3.5
+      vh: 4.0
+modis:
+  band_order:
+    - sur_refl_b01
+    - sur_refl_b02
+    - sur_refl_b03
+    - sur_refl_b04
+    - sur_refl_b05
+    - sur_refl_b06
+    - sur_refl_b07
+  rgb_indices:
+    - 0
+    - 3
+    - 2
+  gsd: 500
+  bands:
+    mean:
+      sur_refl_b01: 1072.
+      sur_refl_b02: 1624.
+      sur_refl_b03: 931.
+      sur_refl_b04: 1023.
+      sur_refl_b05: 1599.
+      sur_refl_b06: 1404.
+      sur_refl_b07: 1051.
+    std:
+      sur_refl_b01: 1643.
+      sur_refl_b02: 1878.
+      sur_refl_b03: 1449.
+      sur_refl_b04: 1538.
+      sur_refl_b05: 1763.
+      sur_refl_b06: 1618.
+      sur_refl_b07: 1396.
+    wavelength:
+      sur_refl_b01: .645
+      sur_refl_b02: .858
+      sur_refl_b03: .469
+      sur_refl_b04: .555
+      sur_refl_b05: 1.240
+      sur_refl_b06: 1.640
+      sur_refl_b07: 2.130
+satellogic-MSI-L1D:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 1.0
+  bands:
+    mean:
+      red: 1451.54
+      green: 1456.54
+      blue: 1543.22
+      nir: 2132.68
+    std:
+      red: 995.48
+      green: 771.29
+      blue: 708.86
+      nir: 1236.71
+    wavelength:
+      red: 0.640
+      green: 0.545
+      blue: 0.480
+      nir: 0.825
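To show how these per-band statistics are consumed, here is a sketch of the ClayNormalize transform from clay.py above (an editor illustration, not part of the diff). It assumes rslearn's Normalize applies the per-band (x - mean) / std seen in apply_image, and that Sentinel-1 inputs are already in decibels, as the class docstring requires.

    import torch
    from rslearn.models.clay.clay import ClayNormalize

    normalize = ClayNormalize()  # one Normalize per modality, built from metadata.yaml
    # dB-scale VV and VH bands for sentinel-1-rtc, each filled with -15.0.
    input_dict = {"sentinel-1-rtc": torch.full((2, 32, 32), -15.0)}
    input_dict, target_dict = normalize(input_dict, {})
    # vv: (-15.0 + 12.113) / 8.314 ≈ -0.35; vh: (-15.0 + 18.673) / 8.017 ≈ 0.46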