rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/anysat.py
ADDED
@@ -0,0 +1,215 @@

"""AnySat model.

This code loads the AnySat model from torch hub. See
https://github.com/gastruc/AnySat for applicable license and copyright information.
"""

from datetime import datetime

import torch
from einops import rearrange

from rslearn.train.model_context import ModelContext

from .component import FeatureExtractor, FeatureMaps

# AnySat github: https://github.com/gastruc/AnySat
# Modalities and expected resolutions (meters)
MODALITY_RESOLUTIONS: dict[str, float] = {
    "aerial": 0.2,
    "aerial-flair": 0.2,
    "spot": 1,
    "naip": 1.25,
    "s2": 10,
    "s1-asc": 10,
    "s1": 10,
    "alos": 30,
    "l7": 30,
    "l8": 10,  # L8 must be upsampled to 10 m in AnySat
    "modis": 250,
}

# Modalities and expected band names
MODALITY_BANDS: dict[str, list[str]] = {
    "aerial": ["R", "G", "B", "NiR"],
    "aerial-flair": ["R", "G", "B", "NiR", "Elevation"],
    "spot": ["R", "G", "B"],
    "naip": ["R", "G", "B", "NiR"],
    "s2": ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8a", "B11", "B12"],
    "s1-asc": ["VV", "VH"],
    "s1": ["VV", "VH", "Ratio"],
    "alos": ["HH", "HV", "Ratio"],
    "l7": ["B1", "B2", "B3", "B4", "B5", "B7"],
    "l8": ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"],
    "modis": ["B1", "B2", "B3", "B4", "B5", "B6", "B7"],
}

# Modalities that require *_dates* input
TIME_SERIES_MODALITIES = {"s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"}


class AnySat(FeatureExtractor):
    """AnySat backbone (outputs one feature map)."""

    def __init__(
        self,
        modalities: list[str],
        patch_size_meters: int,
        output: str = "patch",
        output_modality: str | None = None,
        hub_repo: str = "gastruc/anysat",
        pretrained: bool = True,
        force_reload: bool = False,
        flash_attn: bool = False,
    ) -> None:
        """Initialize an AnySat model.

        Args:
            modalities: list of modalities to use as input (1 or more).
            patch_size_meters: patch size in meters (must be multiple of 10). Avoid having more than 1024 patches per tile
                ie, the height/width in meters should be <= 32 * patch_size_meters.
            dates: dict mapping time-series modalities to list of dates (day number in a year, 0-255).
            output: 'patch' (default) or 'dense'. Use 'patch' for classification tasks,
                'dense' for segmentation tasks.
            output_modality: required if output='dense', specifies which modality to use
                for the dense output (one of the input modalities).
            hub_repo: torch.hub repository to load AnySat from.
            pretrained: whether to load pretrained weights.
            force_reload: whether to force re-download of the model.
            flash_attn: whether to use flash attention (if available).
        """
        super().__init__()

        if not modalities:
            raise ValueError("At least one modality must be specified.")
        for m in modalities:
            if m not in MODALITY_RESOLUTIONS:
                raise ValueError(f"Invalid modality: {m}")

        if patch_size_meters % 10 != 0:
            raise ValueError(
                "In AnySat, `patch_size` is in meters and must be a multiple of 10."
            )

        output = output.lower()
        if output not in {"patch", "dense"}:
            raise ValueError("`output` must be 'patch' or 'dense'.")
        if output == "dense" and output_modality is None:
            raise ValueError("`output_modality` is required when output='dense'.")

        self.modalities = modalities
        self.patch_size_meters = int(patch_size_meters)
        self.output = output
        self.output_modality = output_modality

        self.model = torch.hub.load(  # nosec B614
            hub_repo,
            "anysat",
            pretrained=pretrained,
            force_reload=force_reload,
            flash_attn=flash_attn,
        )
        self._embed_dim = 768  # base width, 'dense' returns 2x

    @staticmethod
    def time_ranges_to_doy(
        time_ranges: list[tuple[datetime, datetime]],
        device: torch.device,
    ) -> torch.Tensor:
        """Turn the time ranges stored in a RasterImage to timestamps accepted by AnySat.

        AnySat uses the doy with each timestamp, so we take the midpoint
        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
        time so that start_time == end_time == mid_time.
        """
        doys = [(t[0] + ((t[1] - t[0]) / 2)).timetuple().tm_yday for t in time_ranges]
        return torch.tensor(doys, dtype=torch.int32, device=device)

    def forward(self, context: ModelContext) -> FeatureMaps:
        """Forward pass for the AnySat model.

        Args:
            context: the model context. Input dicts must include modalities as keys
                which are defined in the self.modalities list

        Returns:
            a FeatureMaps with one feature map at the configured patch size.
        """
        inputs = context.inputs

        batch: dict[str, torch.Tensor] = {}
        spatial_extent: tuple[float, float] | None = None

        for modality in self.modalities:
            if modality not in inputs[0]:
                raise ValueError(f"Modality '{modality}' not present in inputs.")

            cur = torch.stack(
                [inp[modality].image for inp in inputs], dim=0
            )  # (B, C, T, H, W)

            if modality in TIME_SERIES_MODALITIES:
                num_bands = cur.shape[1]
                cur = rearrange(cur, "b c t h w -> b t c h w")
                H, W = cur.shape[-2], cur.shape[-1]

                if inputs[0][modality].timestamps is None:
                    raise ValueError(
                        f"Require timestamps for time series modality {modality}"
                    )
                timestamps = torch.stack(
                    [
                        self.time_ranges_to_doy(inp[modality].timestamps, cur.device)  # type: ignore
                        for inp in inputs
                    ],
                    dim=0,
                )
                batch[f"{modality}_dates"] = timestamps
            else:
                # take the first (assumed only) timestep
                cur = cur[:, :, 0]
                num_bands = cur.shape[1]
                H, W = cur.shape[-2], cur.shape[-1]

            if num_bands != len(MODALITY_BANDS[modality]):
                raise ValueError(
                    f"Modality '{modality}' expected {len(MODALITY_BANDS[modality])} bands, "
                    f"got {num_bands} (shape {tuple(cur.shape)})"
                )

            batch[modality] = cur

            # Ensure same spatial extent across all modalities (H*res, W*res)
            extent = (
                H * MODALITY_RESOLUTIONS[modality],
                W * MODALITY_RESOLUTIONS[modality],
            )
            if spatial_extent is None:
                spatial_extent = extent
            elif spatial_extent != extent:
                raise ValueError(
                    "All modalities must share the same spatial extent (H*res, W*res)."
                )

        kwargs = {"patch_size": self.patch_size_meters, "output": self.output}
        if self.output == "dense":
            kwargs["output_modality"] = self.output_modality

        features = self.model(batch, **kwargs)
        return FeatureMaps([rearrange(features, "b h w d -> b d h w")])

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (patch_size, depth) that corresponds
        to the feature maps that the backbone returns.

        Returns:
            the output channels of the backbone as a list of (patch_size, depth) tuples.
        """
        if self.output == "patch":
            return [(self.patch_size_meters // 10, 768)]
        elif self.output == "dense":
            return [(1, 1536)]
        else:
            raise ValueError(f"invalid output type: {self.output}")
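
For orientation, a minimal usage sketch (not part of the package diff) based only on the constructor and get_backbone_channels() shown above; building the ModelContext passed to forward() is dataset-specific and omitted here.

# Hypothetical sketch: load the AnySat backbone for Sentinel-2 + ascending
# Sentinel-1 time series and report its output feature map channels.
from rslearn.models.anysat import AnySat

backbone = AnySat(
    modalities=["s2", "s1-asc"],  # keys expected in each input dict
    patch_size_meters=40,         # must be a multiple of 10
    output="dense",               # dense features, e.g. for segmentation
    output_modality="s2",         # required whenever output="dense"
)
print(backbone.get_backbone_channels())  # [(1, 1536)] for dense output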

rslearn/models/attention_pooling.py
ADDED
@@ -0,0 +1,177 @@

"""An attention pooling layer."""

import math
from typing import Any

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from rslearn.models.component import (
    FeatureMaps,
    IntermediateComponent,
    TokenFeatureMaps,
)
from rslearn.train.model_context import ModelContext


class SimpleAttentionPool(IntermediateComponent):
    """Simple Attention Pooling.

    Given a token feature map of shape BCHWN,
    learn an attention layer which aggregates over
    the N dimension.

    This is done simply by learning a mapping D->1 which is the weight
    which should be assigned to each token during averaging:

    output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
    """

    def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
        """Initialize the simple attention pooling layer.

        Args:
            in_dim: the encoding dimension D
            hidden_linear: whether to apply an additional linear transformation D -> D
                to the feat tokens. If this is True, a ReLU activation is applied
                after the first linear transformation.
        """
        super().__init__()
        if hidden_linear:
            self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
        else:
            self.hidden_linear = None
        self.linear = nn.Linear(in_features=in_dim, out_features=1)

    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
        """Attention pooling for a single feature map (BCHWN tensor)."""
        B, D, H, W, N = feat_tokens.shape
        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
        if self.hidden_linear is not None:
            feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
        attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
        feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
        return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        """Forward pass for attention pooling linear probe.

        Args:
            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
                We pool over the final dimension in the TokenFeatureMaps. If multiple maps
                are passed, we apply the same linear layers to all of them.
            context: the model context.
            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).

        Returns:
            torch.Tensor:
                - output, attentioned pool over the last dimension (B, C, H, W)
        """
        if not isinstance(intermediates, TokenFeatureMaps):
            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")

        features = []
        for feat in intermediates.feature_maps:
            features.append(self.forward_for_map(feat))
        return FeatureMaps(features)


class AttentionPool(IntermediateComponent):
    """Attention Pooling.

    Given a feature map of shape BCHWN,
    learn an attention layer which aggregates over
    the N dimension.

    We do this by learning a query token, and applying a standard
    attention mechanism against this learned query token.
    """

    def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
        """Initialize the attention pooling layer.

        Args:
            in_dim: the encoding dimension D
            num_heads: the number of heads to use
            linear_on_kv: Whether to apply a linear layer on the input tokens
                to create the key and value tokens.
        """
        super().__init__()
        self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
        if linear_on_kv:
            self.k_linear = nn.Linear(in_dim, in_dim)
            self.v_linear = nn.Linear(in_dim, in_dim)
        else:
            self.k_linear = None
            self.v_linear = None
        if in_dim % num_heads != 0:
            raise ValueError(
                f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
            )
        self.num_heads = num_heads
        self.init_weights()

    def init_weights(self) -> None:
        """Initialize weights for the probe."""
        nn.init.trunc_normal_(self.query_token, std=0.02)

    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
        """Attention pooling for a single feature map (BCHWN tensor)."""
        B, D, H, W, N = feat_tokens.shape
        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
        collapsed_dim = B * H * W
        q = self.query_token.expand(collapsed_dim, 1, -1)
        q = q.reshape(
            collapsed_dim, 1, self.num_heads, D // self.num_heads
        )  # [B, 1, head, D_head]
        q = rearrange(q, "b h n d -> b n h d")
        if self.k_linear is not None:
            assert self.v_linear is not None
            k = self.k_linear(feat_tokens).reshape(
                collapsed_dim, N, self.num_heads, D // self.num_heads
            )
            v = self.v_linear(feat_tokens).reshape(
                collapsed_dim, N, self.num_heads, D // self.num_heads
            )
        else:
            k = feat_tokens.reshape(
                collapsed_dim, N, self.num_heads, D // self.num_heads
            )
            v = feat_tokens.reshape(
                collapsed_dim, N, self.num_heads, D // self.num_heads
            )
        k = rearrange(k, "b n h d -> b h n d")
        v = rearrange(v, "b n h d -> b h n d")

        # Compute attention scores
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
            D // self.num_heads
        )
        attn_weights = F.softmax(attn_scores, dim=-1)
        x = torch.matmul(attn_weights, v)  # [B, head, 1, D_head]
        return x.reshape(B, D, H, W)

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        """Forward pass for attention pooling linear probe.

        Args:
            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
                We pool over the final dimension in the TokenFeatureMaps. If multiple feature
                maps are passed, we apply the same attention weights (query token and linear k, v layers)
                to all the maps.
            context: the model context.
            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).

        Returns:
            torch.Tensor:
                - output, attentioned pool over the last dimension (B, C, H, W)
        """
        if not isinstance(intermediates, TokenFeatureMaps):
            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")

        features = []
        for feat in intermediates.feature_maps:
            features.append(self.forward_for_map(feat))
        return FeatureMaps(features)
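
As a quick shape check (not part of the package diff), forward_for_map() can be exercised on a random BCHWN tensor; the pooling reduces the token dimension N while keeping the spatial grid.

# Hypothetical sketch: pool N=12 tokens per spatial location into one vector.
import torch

from rslearn.models.attention_pooling import AttentionPool

pool = AttentionPool(in_dim=768, num_heads=8)
tokens = torch.randn(2, 768, 4, 4, 12)  # (B, D, H, W, N)
pooled = pool.forward_for_map(tokens)
print(pooled.shape)  # torch.Size([2, 768, 4, 4])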

rslearn/models/clay/clay.py
ADDED
@@ -0,0 +1,231 @@

"""Clay models."""

from __future__ import annotations

import math
from enum import Enum
from importlib.resources import files
from typing import Any

import torch
import torch.nn.functional as F
import yaml
from einops import rearrange
from huggingface_hub import hf_hub_download

# from claymodel.module import ClayMAEModule
from terratorch.models.backbones.clay_v15.module import ClayMAEModule

from rslearn.models.component import FeatureExtractor, FeatureMaps
from rslearn.train.model_context import ModelContext
from rslearn.train.transforms.normalize import Normalize
from rslearn.train.transforms.transform import Transform


class ClaySize(str, Enum):
    """Size of the Clay model."""

    BASE = "base"
    LARGE = "large"


PATCH_SIZE = 8
CLAY_MODALITIES = ["sentinel-2-l2a", "sentinel-1-rtc", "landsat-c2l1", "naip"]
CONFIG_DIR = files("rslearn.models.clay.configs")
CLAY_METADATA_PATH = str(CONFIG_DIR / "metadata.yaml")
DEFAULT_IMAGE_RESOLUTION = 128  # image resolution during pretraining


def get_clay_checkpoint_path(
    filename: str = "v1.5/clay-v1.5.ckpt",
    repo_id: str = "made-with-clay/Clay",
) -> str:
    """Return a cached local path to the Clay ckpt from the Hugging Face Hub."""
    return hf_hub_download(repo_id=repo_id, filename=filename)  # nosec B615


class Clay(FeatureExtractor):
    """Clay backbones."""

    def __init__(
        self,
        model_size: ClaySize,
        modality: str = "sentinel-2-l2a",
        checkpoint_path: str | None = None,
        metadata_path: str = CLAY_METADATA_PATH,
        do_resizing: bool = False,
    ) -> None:
        """Initialize the Clay model.

        Args:
            model_size: The size of the Clay model.
            modality: The modality to use (subset of CLAY_MODALITIES).
            checkpoint_path: Path to clay-v1.5.ckpt, if None, fetch from HF Hub.
            metadata_path: Path to metadata.yaml.
            do_resizing: Whether to resize the image to the input resolution.
        """
        super().__init__()

        # Clay only supports single modality input
        if modality not in CLAY_MODALITIES:
            raise ValueError(f"Invalid modality: {modality}")

        ckpt = checkpoint_path or get_clay_checkpoint_path()
        if model_size == ClaySize.LARGE:
            self.model = ClayMAEModule.load_from_checkpoint(
                checkpoint_path=ckpt,
                model_size="large",
                metadata_path=metadata_path,
                dolls=[16, 32, 64, 128, 256, 768, 1024],
                doll_weights=[1, 1, 1, 1, 1, 1, 1],
                mask_ratio=0.0,
                shuffle=False,
            )
        elif model_size == ClaySize.BASE:
            # Failed to load Base model in Clay v1.5
            raise ValueError("Clay BASE model currently not supported in v1.5.")
            self.model = ClayMAEModule.load_from_checkpoint(
                checkpoint_path=ckpt,
                model_size="base",
                metadata_path=metadata_path,
                dolls=[16, 32, 64, 128, 256, 768],
                doll_weights=[1, 1, 1, 1, 1, 1],
                mask_ratio=0.0,
                shuffle=False,
            )
        else:
            raise ValueError(f"Invalid model size: {model_size}")

        with open(metadata_path) as f:
            self.metadata = yaml.safe_load(f)

        self.model_size = model_size
        self.modality = modality
        self.do_resizing = do_resizing

    def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
        """Resize the image to the input resolution."""
        new_hw = self.patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
        return F.interpolate(
            image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
        )

    def forward(self, context: ModelContext) -> FeatureMaps:
        """Forward pass for the Clay model.

        Args:
            context: the model context. Input dicts must include `self.modality` as a key

        Returns:
            a FeatureMaps consisting of one feature map, computed by Clay.
        """
        param = next(self.model.parameters())
        device = param.device

        chips = torch.stack(
            [inp[self.modality] for inp in context.inputs], dim=0
        )  # (B, C, H, W)
        if self.do_resizing:
            chips = self._resize_image(chips, chips.shape[2])
        order = self.metadata[self.modality]["band_order"]
        wavelengths = []
        for band in self.metadata[self.modality]["band_order"]:
            wavelengths.append(
                self.metadata[self.modality]["bands"]["wavelength"][band] * 1000
            )  # Convert to nm
        # Check channel count matches Clay expectation
        if chips.shape[1] != len(order):
            raise ValueError(
                f"Channel count {chips.shape[1]} does not match expected {len(order)} for {self.modality}"
            )

        # Time & latlon zeros are valid per Clay doc
        # https://clay-foundation.github.io/model/getting-started/basic_use.html
        datacube = {
            "platform": self.modality,
            "time": torch.zeros(chips.shape[0], 4).to(device),
            "latlon": torch.zeros(chips.shape[0], 4).to(device),
            "pixels": chips.to(device),
            "gsd": torch.tensor(self.metadata[self.modality]["gsd"]).to(device),
            "waves": torch.tensor(wavelengths).to(device),
        }

        tokens, *_ = self.model.model.encoder(datacube)  # (B, 1 + N, D)

        # Remove CLS token
        spatial = tokens[:, 1:, :]  # (B, N, D)
        n_tokens = spatial.shape[1]
        side = int(math.isqrt(n_tokens))
        if chips.shape[2] != side * PATCH_SIZE or chips.shape[3] != side * PATCH_SIZE:
            raise ValueError(
                f"Input spatial size {(chips.shape[2], chips.shape[3])} is not compatible with patch size {PATCH_SIZE}"
            )

        features = rearrange(spatial, "b (h w) d -> b d h w", h=side, w=side)
        return FeatureMaps([features])

    def get_backbone_channels(self) -> list:
        """Return output channels of this model when used as a backbone."""
        if self.model_size == ClaySize.LARGE:
            depth = 1024
        elif self.model_size == ClaySize.BASE:
            depth = 768
        else:
            raise ValueError(f"Invalid model size: {self.model_size}")
        return [(PATCH_SIZE, depth)]


class ClayNormalize(Transform):
    """Normalize inputs using Clay metadata.

    For Sentinel-1, the intensities should be converted to decibels.
    """

    def __init__(self, metadata_path: str = CLAY_METADATA_PATH) -> None:
        """Initialize ClayNormalize."""
        super().__init__()
        with open(metadata_path) as f:
            metadata = yaml.safe_load(f)
        normalizers = {}
        for modality in CLAY_MODALITIES:
            if modality not in metadata:
                continue
            modality_metadata = metadata[modality]
            means = [
                modality_metadata["bands"]["mean"][b]
                for b in modality_metadata["band_order"]
            ]
            stds = [
                modality_metadata["bands"]["std"][b]
                for b in modality_metadata["band_order"]
            ]
            normalizers[modality] = Normalize(
                mean=means,
                std=stds,
                selectors=[modality],
                num_bands=len(means),
            )
        self.normalizers = torch.nn.ModuleDict(normalizers)

    def apply_image(
        self, image: torch.Tensor, means: list[float], stds: list[float]
    ) -> torch.Tensor:
        """Normalize the specified image with Clay normalization."""
        x = image.float()
        if x.shape[0] != len(means):
            raise ValueError(
                f"channel count {x.shape[0]} does not match provided band stats {len(means)}"
            )
        for c in range(x.shape[0]):
            x[c] = (x[c] - means[c]) / stds[c]
        return x

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Normalize the specified image with Clay normalization."""
        for modality, normalizer in self.normalizers.items():
            if modality not in input_dict:
                continue
            input_dict, target_dict = normalizer(input_dict, target_dict)
        return input_dict, target_dict
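
A minimal instantiation sketch for the Clay backbone (not part of the package diff); it assumes terratorch is installed and, unless checkpoint_path is given, downloads clay-v1.5.ckpt from the Hugging Face Hub.

# Hypothetical sketch: load the large Clay v1.5 backbone for Sentinel-2 L2A.
from rslearn.models.clay.clay import Clay, ClaySize

backbone = Clay(model_size=ClaySize.LARGE, modality="sentinel-2-l2a")
print(backbone.get_backbone_channels())  # [(8, 1024)] for the large model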