rslearn 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +1 -22
- rslearn/data_sources/copernicus.py +6 -4
- rslearn/data_sources/eurocrops.py +246 -0
- rslearn/data_sources/local_files.py +11 -0
- rslearn/data_sources/openstreetmap.py +2 -4
- rslearn/dataset/dataset.py +4 -1
- rslearn/models/copernicusfm.py +216 -0
- rslearn/models/copernicusfm_src/__init__.py +1 -0
- rslearn/models/copernicusfm_src/aurora/area.py +50 -0
- rslearn/models/copernicusfm_src/aurora/fourier.py +134 -0
- rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +523 -0
- rslearn/models/copernicusfm_src/flexivit/patch_embed.py +260 -0
- rslearn/models/copernicusfm_src/flexivit/utils.py +69 -0
- rslearn/models/copernicusfm_src/model_vit.py +348 -0
- rslearn/models/copernicusfm_src/util/pos_embed.py +216 -0
- rslearn/models/panopticon.py +167 -0
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +247 -0
- rslearn/models/presto/single_file_presto.py +932 -0
- rslearn/models/unet.py +15 -0
- rslearn/template_params.py +26 -0
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/METADATA +4 -1
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/RECORD +27 -12
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/WHEEL +0 -0
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/top_level.txt +0 -0

rslearn/models/copernicusfm_src/util/pos_embed.py
@@ -0,0 +1,216 @@
+# type: ignore
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    # omega = np.arange(embed_dim // 2, dtype=np.float)  # numpy deprecated in 1.20
+    omega = np.arange(embed_dim // 2, dtype=float)
+
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if "pos_embed" in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model["pos_embed"]
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches**0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print(
+                "Position interpolate from %dx%d to %dx%d"
+                % (orig_size, orig_size, new_size, new_size)
+            )
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(
+                -1, orig_size, orig_size, embedding_size
+            ).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens,
+                size=(new_size, new_size),
+                mode="bicubic",
+                align_corners=False,
+            )
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model["pos_embed"] = new_pos_embed
+
+
+def interpolate_pos_embed_ofa(model, checkpoint_model):
+    if "pos_embed" in checkpoint_model:
+        pos_embed_dict = checkpoint_model["pos_embed"]
+
+        for key, pos_embed in pos_embed_dict.items():
+            pos_embed_checkpoint = pos_embed
+            embedding_size = pos_embed_checkpoint.shape[-1]
+            num_patches = model.patch_embed[key].num_patches
+            num_extra_tokens = model.pos_embed[key].shape[-2] - num_patches
+            # height (== width) for the checkpoint position embedding
+            orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+            # height (== width) for the new position embedding
+            new_size = int(num_patches**0.5)
+            # class_token and dist_token are kept unchanged
+            if orig_size != new_size:
+                print(
+                    "Position interpolate from %dx%d to %dx%d"
+                    % (orig_size, orig_size, new_size, new_size)
+                )
+                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+                # only the position tokens are interpolated
+                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+                pos_tokens = pos_tokens.reshape(
+                    -1, orig_size, orig_size, embedding_size
+                ).permute(0, 3, 1, 2)
+                pos_tokens = torch.nn.functional.interpolate(
+                    pos_tokens,
+                    size=(new_size, new_size),
+                    mode="bicubic",
+                    align_corners=False,
+                )
+                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+                checkpoint_model["pos_embed"][key] = new_pos_embed
+
+
+def get_2d_sincos_pos_embed_with_resolution(
+    embed_dim, grid_size, res, cls_token=False, device="cpu"
+):
+    """grid_size: int of the grid height and width
+    res: array of size n, representing the resolution of a pixel (say, in meters),
+
+    Return:
+    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    # res = torch.FloatTensor(res).to(device)
+    res = res.to(device)
+    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
+    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
+    grid = torch.meshgrid(
+        grid_w, grid_h, indexing="xy"
+    )  # here h goes first,direction reversed for numpy
+    grid = torch.stack(grid, dim=0)  # 2 x h x w
+
+    # grid = grid.reshape([2, 1, grid_size, grid_size])
+    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
+    _, n, h, w = grid.shape
+    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(
+        embed_dim, grid
+    )  # # (nxH*W, D/2)
+    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
+    if cls_token:
+        pos_embed = torch.cat(
+            [
+                torch.zeros(
+                    [n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device
+                ),
+                pos_embed,
+            ],
+            dim=1,
+        )
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid_torch(
+        embed_dim // 2, grid[0]
+    )  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid_torch(
+        embed_dim // 2, grid[1]
+    )  # (H*W, D/2)
+
+    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
+    """embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    old_shape = pos
+    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+
+    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
+    return emb
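
This file supplies the 2D sine-cosine position embeddings (plus a resolution-aware variant) used by the CopernicusFM backbone. A minimal sketch of how the two entry points compose, assuming the module path shown in the file list above:

```python
# Sketch only: exercises the helpers added above via the module path from the file list.
import torch
from rslearn.models.copernicusfm_src.util.pos_embed import (
    get_2d_sincos_pos_embed,
    get_2d_sincos_pos_embed_with_resolution,
)

# Fixed-resolution embedding for a 14x14 patch grid and 768-dim tokens (numpy output).
pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)  # (1 + 14*14, 768)

# Resolution-aware variant: one embedding per sample, scaled by ground resolution.
res = torch.tensor([10.0, 20.0])  # metres per pixel for two samples
pos_res = get_2d_sincos_pos_embed_with_resolution(768, 14, res, cls_token=True)
print(pos_res.shape)  # torch.Size([2, 1 + 14*14, 768])
```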
rslearn/models/panopticon.py
@@ -0,0 +1,167 @@
+"""Wrapper for the Panopticon model."""
+
+import math
+from enum import StrEnum
+from importlib import resources
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+import yaml
+from einops import rearrange, repeat
+from torch import nn
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+class PanopticonModalities(StrEnum):
+    """Modalities supported by Panopticon.
+
+    These are the keys needed to load the yaml file from panopticon_data/sensors
+    """
+
+    SENTINEL2 = "sentinel2"
+    LANDSAT8 = "landsat8"
+    SENTINEL1 = "sentinel1"
+    # Add more modalities as needed
+
+
+class Panopticon(nn.Module):
+    """Class containing the Panopticon model that can ingest MaskedHeliosSample objects."""
+
+    patch_size: int = 14
+    base_image_size: int = 224
+
+    def __init__(
+        self,
+        band_order: dict[str, list[str]],
+        torchhub_id: str = "panopticon_vitb14",
+    ):
+        """Initialize the Panopticon wrapper.
+
+        Args:
+            band_order: The band order for the panopticon model, must match the specified order in the data config
+            torchhub_id: The torch hub model ID for panopticon
+        """
+        super().__init__()
+        # Load the panopticon model
+        self._load_model(torchhub_id)
+        self.output_dim = self.model.embed_dim
+        self.band_order = band_order
+        self.supported_modalities = list(band_order.keys())
+
+    def _load_model(self, torchhub_id: str) -> None:
+        """Load the panopticon model from torch hub."""
+        import time
+
+        # Hack to get around https://discuss.pytorch.org/t/torch-hub-load-gives-httperror-rate-limit-exceeded/124769
+        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+        for attempt in range(2):
+            try:
+                self.model = torch.hub.load(  # nosec B614
+                    "panopticon-FM/panopticon",
+                    torchhub_id,
+                )
+                break
+            except Exception as e:
+                logger.warning(
+                    f"Error loading panopticon model: {e}. Retrying in 5 seconds..."
+                )
+                time.sleep(5)
+        else:
+            raise RuntimeError(
+                f"Failed to load panopticon model {torchhub_id} after retrying."
+            )
+
+    def _process_modality_data(self, data: torch.Tensor) -> torch.Tensor:
+        """Process individual modality data.
+
+        Args:
+            data: Input tensor of shape [B, C, H, W]
+
+        Returns:
+            Processed tensor of shape [B, C, H, W]
+        """
+        original_height = data.shape[2]
+        new_height = self.patch_size if original_height == 1 else self.base_image_size
+
+        data = F.interpolate(
+            data,
+            size=(new_height, new_height),
+            mode="bilinear",
+            align_corners=False,
+        )
+        return data
+
+    def _create_channel_ids(
+        self, modality: str, batch_size: int, device: torch.device
+    ) -> torch.Tensor:
+        """Create channel IDs for the panopticon model."""
+        with resources.open_text(
+            "rslearn.models.panopticon_data.sensors", f"{modality}.yaml"
+        ) as f:
+            sensor_config = yaml.safe_load(f)
+
+        band_order = self.band_order[modality]
+        chn_ids = [
+            sensor_config["bands"][band.upper()]["gaussian"]["mu"]
+            for band in band_order
+        ]
+        chn_ids = torch.tensor(chn_ids, dtype=torch.float32, device=device)
+        chn_ids = repeat(chn_ids, "c -> b c", b=batch_size)
+        return chn_ids
+
+    def prepare_input(
+        self, input_data: dict[str, torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
+        """Prepare input for the panopticon model from MaskedHeliosSample."""
+        channel_ids_list: list[torch.Tensor] = []
+        processed_data_list: list[torch.Tensor] = []
+        for modality in self.supported_modalities:
+            if modality not in input_data.keys():
+                logger.debug(f"Modality {modality} not found in input data")
+                continue
+            data = input_data[modality]
+            device = data.device
+            processed_data = self._process_modality_data(data)
+            processed_data_list.append(processed_data)
+            batch_size = processed_data.shape[0]
+            chn_ids = self._create_channel_ids(modality, batch_size, device)
+            channel_ids_list.append(chn_ids)
+
+        processed_data = torch.cat(processed_data_list, dim=1)
+        chn_ids = torch.cat(channel_ids_list, dim=1)
+        return {
+            "imgs": processed_data,
+            "chn_ids": chn_ids,
+        }
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Forward pass through the panopticon model."""
+        batch_inputs = {
+            key: torch.stack([inp[key] for inp in inputs], dim=0)
+            for key in inputs[0].keys()
+        }
+        panopticon_inputs = self.prepare_input(batch_inputs)
+        output_features = self.model.forward_features(panopticon_inputs)[
+            "x_norm_patchtokens"
+        ]
+
+        num_tokens = output_features.shape[1]
+        height = int(math.sqrt(num_tokens))
+        output_features = rearrange(
+            output_features, "b (h w) d -> b d h w", h=height, w=height
+        )
+        return [output_features]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (downsample_factor, depth) that corresponds
+        to the feature maps that the backbone returns. For example, an element [2, 32]
+        indicates that the corresponding feature map is 1/2 the input resolution and
+        has 32 channels.
+        """
+        return [(self.patch_size, self.output_dim)]
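
The Panopticon wrapper resizes each modality to the model's base image size, looks up per-band channel IDs from the bundled sensor YAML files, and returns a single patch-token feature map. A rough usage sketch; the band names and tensor shapes below are illustrative assumptions (the valid band keys come from the YAML files under rslearn/models/panopticon_data/sensors), and construction downloads the panopticon_vitb14 weights from torch hub:

```python
# Sketch only: illustrative band names/shapes, not taken from the package docs.
import torch
from rslearn.models.panopticon import Panopticon

model = Panopticon(band_order={"sentinel2": ["B04", "B03", "B02"]})  # hypothetical bands
# Each input dict holds per-modality [C, H, W] tensors; forward() stacks them into a batch.
inputs = [{"sentinel2": torch.randn(3, 128, 128)} for _ in range(2)]
features = model(inputs)  # list with one [B, embed_dim, 224/14, 224/14] feature map
print(features[0].shape)
```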
rslearn/models/presto/presto.py
@@ -0,0 +1,247 @@
+"""Presto wrapper to ingest Masked Helios Samples."""
+
+import logging
+import tempfile
+from typing import Any
+
+import torch
+from einops import rearrange, repeat
+from huggingface_hub import hf_hub_download
+from torch import nn
+from upath import UPath
+
+from rslearn.models.presto.single_file_presto import (
+    ERA5_BANDS,
+    NUM_DYNAMIC_WORLD_CLASSES,
+    PRESTO_ADD_BY,
+    PRESTO_BANDS,
+    PRESTO_DIV_BY,
+    PRESTO_S1_BANDS,
+    PRESTO_S2_BANDS,
+    SRTM_BANDS,
+)
+from rslearn.models.presto.single_file_presto import Presto as SFPresto
+
+logger = logging.getLogger(__name__)
+
+INPUT_PRESTO_BANDS = [b for b in PRESTO_BANDS if b != "B09"]
+INPUT_PRESTO_S2_BANDS = [b for b in PRESTO_S2_BANDS if b != "B09"]
+
+PRESTO_S1_SUBTRACT_VALUE = -25.0
+PRESTO_S1_DIV_VALUE = 25.0
+PRESTO_S2_SUBTRACT_VALUE = 0.0
+PRESTO_S2_DIV_VALUE = 1e4
+
+HF_HUB_ID = "nasaharvest/presto"
+MODEL_FILENAME = "default_model.pt"
+
+
+class Presto(nn.Module):
+    """Presto."""
+
+    input_keys = [
+        "s1",
+        "s2",
+        "era5",
+        "srtm",
+        "dynamic_world",
+        "latlon",
+    ]
+
+    def __init__(
+        self,
+        pretrained_path: str | UPath | None = None,
+        pixel_batch_size: int = 128,
+    ):
+        """Initialize the Presto wrapper.
+
+        Args:
+            pretrained_path: The directory to load from
+            pixel_batch_size: If the input has a h,w dimension >1, this is
+                flattened into a batch dimension (b h w) before being passed
+                to the model (since Presto is designed for pixel timeseries).
+        """
+        super().__init__()
+
+        if pretrained_path is None:
+            pretrained_path = UPath(tempfile.gettempdir(), "rslearn_cache", "presto")
+        if not (UPath(pretrained_path) / MODEL_FILENAME).exists():
+            _ = hf_hub_download(
+                local_dir=UPath(pretrained_path),
+                repo_id=HF_HUB_ID,
+                filename=MODEL_FILENAME,
+                # pin the model to a specific hugging face commit
+                revision="1b97f885969da4e2d5834ca8c92707c737911464",
+            )
+
+        model = SFPresto.construct()
+        model.load_state_dict(
+            torch.load(
+                UPath(pretrained_path) / MODEL_FILENAME,
+                map_location="cpu",
+                weights_only=True,
+            )
+        )
+        self.pixel_batch_size = pixel_batch_size
+        self.model = model.encoder
+        self.month = 6  # default month
+
+    def construct_presto_input(
+        self,
+        s1: torch.Tensor | None = None,
+        s1_bands: torch.Tensor | None = None,
+        s2: torch.Tensor | None = None,
+        s2_bands: torch.Tensor | None = None,
+        era5: torch.Tensor | None = None,
+        era5_bands: torch.Tensor | None = None,
+        srtm: torch.Tensor | None = None,
+        srtm_bands: torch.Tensor | None = None,
+        dynamic_world: torch.Tensor | None = None,
+        months: torch.Tensor | None = None,
+        normalize: bool = True,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Inputs are paired into a tensor input <X> and a list <X>_bands, which describes <X>.
+
+        <X> should have shape (b, num_timesteps, h, w len(<X>_bands)), with the following bands for
+        each input:
+
+        s1: ["VV", "VH"]
+        s2: ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]
+        era5: ["temperature_2m", "total_precipitation"]
+            "temperature_2m": Temperature of air at 2m above the surface of land,
+                sea or in-land waters in Kelvin (K)
+            "total_precipitation": Accumulated liquid and frozen water, including rain and snow,
+                that falls to the Earth's surface. Measured in metres (m)
+        srtm: ["elevation", "slope"]
+
+        dynamic_world is a 1d input of shape (num_timesteps,) representing the dynamic world classes
+        of each timestep for that pixel
+        """
+        bs = [x.shape[0] for x in [s1, s2, era5, srtm] if x is not None]
+        hs = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
+        ws = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
+
+        assert len(bs) > 0
+        assert len(set(bs)) == 1
+        assert len(set(hs)) == 1
+        assert len(set(ws)) == 1
+        b, h, w = bs[0], hs[0], ws[0]
+
+        # these values will be initialized as
+        # we iterate through the data
+        x: torch.Tensor | None = None
+        mask: torch.Tensor | None = None
+        t: int | None = None
+
+        for band_group in [
+            (s1, s1_bands),
+            (s2, s2_bands),
+            (era5, era5_bands),
+            (srtm, srtm_bands),
+        ]:
+            data, input_bands = band_group
+            if data is not None:
+                assert input_bands is not None
+            else:
+                continue
+
+            m_t = data.shape[1] // len(input_bands)
+            if t is None:
+                t = m_t
+            else:
+                if t != m_t:
+                    raise ValueError("inconsistent values for t")
+
+            data = rearrange(data, "b (t c) h w -> b t h w c", t=m_t)
+            if x is None:
+                x = torch.zeros(b, t, h, w, len(INPUT_PRESTO_BANDS))
+            if mask is None:
+                mask = torch.ones(b, t, h, w, len(INPUT_PRESTO_BANDS))
+
+            # construct a mapping from the input bands to the presto input bands
+            input_to_output_mapping = [
+                INPUT_PRESTO_BANDS.index(val) for val in input_bands
+            ]
+            x[:, :, :, :, input_to_output_mapping] = data
+            mask[:, :, :, :, input_to_output_mapping] = 0
+
+        assert x is not None
+        assert mask is not None
+        assert t is not None
+
+        if dynamic_world is None:
+            dynamic_world = torch.ones(b, t, h, w) * NUM_DYNAMIC_WORLD_CLASSES
+
+        if months is None:
+            months = torch.ones((b, t), device=x.device) * self.month
+        else:
+            assert months.shape[-1] == t
+
+        if normalize:
+            x = (x + PRESTO_ADD_BY) / PRESTO_DIV_BY
+        return x, mask, dynamic_world.long(), months.long()
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Compute feature maps from the Presto backbone.
+
+        Inputs:
+            inputs
+        """
+        stacked_inputs = {}
+        latlons: torch.Tensor | None = None
+        for key in inputs[0].keys():
+            # assume all the keys in an input are consistent
+            if key in self.input_keys:
+                if key == "latlon":
+                    latlons = torch.stack([inp[key] for inp in inputs], dim=0)
+                else:
+                    stacked_inputs[key] = torch.stack(
+                        [inp[key] for inp in inputs], dim=0
+                    )
+
+        (
+            x,
+            mask,
+            dynamic_world,
+            months,
+        ) = self.construct_presto_input(
+            **stacked_inputs,
+            s1_bands=PRESTO_S1_BANDS,
+            s2_bands=INPUT_PRESTO_S2_BANDS,
+            era5_bands=ERA5_BANDS,
+            srtm_bands=SRTM_BANDS,
+            normalize=True,
+        )
+        b, _, h, w, _ = x.shape
+
+        output_features = torch.zeros(
+            b * h * w, self.model.embedding_size, device=x.device
+        )
+
+        x = rearrange(x, "b t h w d -> (b h w) t d")
+        mask = rearrange(mask, "b t h w d -> (b h w) t d")
+        dynamic_world = rearrange(dynamic_world, "b t h w -> (b h w) t")
+        months = repeat(months, "b t -> (b h w) t", h=h, w=w)
+        if latlons is not None:
+            latlons = rearrange(latlons, "b c h w -> (b h w) c")
+
+        for batch_idx in range(0, b * h * w, self.pixel_batch_size):
+            x_b = x[batch_idx : batch_idx + self.pixel_batch_size]
+            mask_b = mask[batch_idx : batch_idx + self.pixel_batch_size]
+            dw = dynamic_world[batch_idx : batch_idx + self.pixel_batch_size]
+            months_b = months[batch_idx : batch_idx + self.pixel_batch_size]
+            if latlons is not None:
+                l_b = latlons[batch_idx : batch_idx + self.pixel_batch_size]
+            else:
+                l_b = None
+            output_b = self.model(
+                x=x_b,
+                dynamic_world=dw,
+                mask=mask_b,
+                month=months_b,
+                latlons=l_b,
+                eval_task=True,
+            )
+            output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
+
+        return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
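
The Presto wrapper flattens the spatial dimensions into pixel time series, so each modality tensor carries (timesteps × bands) channels stacked along the channel axis. A rough usage sketch under those shape assumptions; the first construction downloads default_model.pt from the nasaharvest/presto Hugging Face repo:

```python
# Sketch only: shapes follow the construct_presto_input docstring above; band counts
# are taken from the module-level constants so the channel layout stays consistent.
import torch
from rslearn.models.presto.presto import INPUT_PRESTO_S2_BANDS, Presto

model = Presto(pixel_batch_size=64)  # downloads pretrained weights on first use

# One timestep over a 4x4 patch: Sentinel-1 (VV, VH) and Sentinel-2 (B09 excluded).
inputs = [
    {
        "s1": torch.randn(2, 4, 4),
        "s2": torch.randn(len(INPUT_PRESTO_S2_BANDS), 4, 4),
    }
    for _ in range(2)
]
features = model(inputs)  # list with one [B, embedding_size, H, W] feature map
print(features[0].shape)
```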