rslearn 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/models/anysat.py +207 -0
- rslearn/models/clay/clay.py +204 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +517 -0
- rslearn/models/galileo/single_file_galileo.py +1672 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/presto/presto.py +10 -7
- rslearn/models/prithvi.py +1046 -0
- rslearn/models/unet.py +17 -11
- rslearn/utils/geometry.py +61 -1
- rslearn/utils/vector_format.py +13 -10
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/METADATA +145 -15
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/RECORD +29 -10
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/WHEEL +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/top_level.txt +0 -0
rslearn/models/prithvi.py (added)

@@ -0,0 +1,1046 @@

"""Prithvi V2."""

import logging
import tempfile
import warnings
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import yaml
from einops import rearrange
from huggingface_hub import hf_hub_download
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block
from torch.nn import functional as F
from upath import UPath

logger = logging.getLogger(__name__)


# for Prithvi, true values are ["B02", "B03", "B04", "B05", "B06", "B07"]
PRITHVI_MEAN = [
    1087.0,
    1342.0,
    1433.0,
    2734.0,
    1958.0,
    1363.0,
]
PRITHVI_STD = [
    2248.0,
    2179.0,
    2178.0,
    1850.0,
    1242.0,
    1049.0,
]


HF_HUB_ID = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M"


class PrithviV2(nn.Module):
    """An Rslearn wrapper for Prithvi 2.0."""

    input_keys = ["sentinel2"]

    def __init__(self, pretrained_path: str | UPath | None = None, num_frames: int = 1):
        """Init.

        Inputs:
            pretrained_path: The folder in which to download the prithvi config
                and weights. If None, it downloads to a temporary folder.
            num_frames: The number of input frames (timesteps). The model was trained
                on 3, but upstream examples use 1 when only a single timestamp is
                available (e.g.
                https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/
                example_landslide4sense.ipynb)

        """
        super().__init__()
        if pretrained_path is None:
            pretrained_path = UPath(
                tempfile.gettempdir(), "rslearn_cache", "prithvi_v2"
            )

        if not (UPath(pretrained_path) / "config.json").exists():
            _ = hf_hub_download(
                local_dir=pretrained_path,
                repo_id=HF_HUB_ID,
                filename="config.json",
                revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
            )
        with (UPath(pretrained_path) / "config.json").open("r") as f:
            config = yaml.safe_load(f)["pretrained_cfg"]

        config["num_frames"] = num_frames

        self.model = PrithviMAE(**config)

        if not (UPath(pretrained_path) / "Prithvi_EO_V2_300M.pt").exists():
            _ = hf_hub_download(
                local_dir=pretrained_path,
                repo_id=HF_HUB_ID,
                filename="Prithvi_EO_V2_300M.pt",
                revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
            )

        state_dict = torch.load(
            UPath(pretrained_path) / "Prithvi_EO_V2_300M.pt",
            map_location="cpu",
            weights_only=True,
        )
        # discard fixed pos_embedding weight, following
        # https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M/blob/e4aabdc440c8ee703a749def8af5bf4700dee35b/inference.py#L362
        for k in list(state_dict.keys()):
            if "pos_embed" in k:
                del state_dict[k]
        self.model.load_state_dict(state_dict, strict=False)
        self.image_resolution = config["img_size"]
        self.bands = config["bands"]
        # patch size is a list [t, h, w], where h == w
        self.patch_size = config["patch_size"][-1]

    def _resize_data(self, data: torch.Tensor) -> torch.Tensor:
        """Process individual modality data.

        Args:
            data: Input tensor of shape [B, C, H, W]

        Returns:
            The resized tensor of shape [B, C, H', W'], where H' == W'.
        """
        # Get original dimensions
        original_height = data.shape[2]
        new_height = self.patch_size if original_height == 1 else self.image_resolution
        data = F.interpolate(
            data,
            size=(new_height, new_height),
            mode="bilinear",
            align_corners=False,
        )
        return data

    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
        """Compute feature maps from the Prithvi V2 backbone.

        Inputs:
            inputs: input dicts that must include the "sentinel2" key. Prithvi is
                designed for HLS (Harmonized Landsat and Sentinel-2) imagery;
                this naming keeps the model consistent with other rslearn models.

        Returns:
            One feature map per transformer block in the Prithvi encoder, each of
            shape [B, D=1024, H/p_s, W/p_s] where p_s=16 is the patch size.
        """
        x = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
        x = self._resize_data(x)
        num_timesteps = x.shape[1] // len(self.bands)
        x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
        features = self.model.encoder.forward_features(x)
        # prepare_features_for_image_model was slightly modified since we already
        # know the number of timesteps and don't need to recompute it.
        # in addition we average along the time dimension (instead of concatenating)
        # to keep the embeddings reasonably sized.
        return self.model.encoder.prepare_features_for_image_model(
            features, num_timesteps
        )

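For orientation, here is a minimal usage sketch of the PrithviV2 wrapper above (an editor's illustration, not part of prithvi.py or the wheel). It assumes network access so the constructor can fetch the config and checkpoint from Hugging Face; the random input and the use of PRITHVI_MEAN/PRITHVI_STD for normalization are illustrative assumptions, not a prescribed rslearn preprocessing pipeline.

# --- usage sketch (editor's illustration, not part of prithvi.py) ---
import torch

model = PrithviV2(num_frames=1)
model.eval()

# One sample: 6 HLS bands for a single timestep; the normalization here is
# illustrative (rslearn normally applies transforms in its data pipeline).
image = torch.rand(6, 224, 224) * 3000.0
mean = torch.tensor(PRITHVI_MEAN).view(6, 1, 1)
std = torch.tensor(PRITHVI_STD).view(6, 1, 1)
inputs = [{"sentinel2": (image - mean) / std}]

with torch.no_grad():
    feature_maps = model(inputs)

# One feature map per transformer block; for the 300M config each map is
# roughly [1, 1024, 14, 14] (224 / 16 = 14 patches per side).
print(len(feature_maps), feature_maps[-1].shape)
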
# Copyright (c) IBM Corp. 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# transformers: https://github.com/huggingface/transformers
# --------------------------------------------------------


def get_3d_sincos_pos_embed(
    embed_dim: int,
    grid_size: tuple[int, int, int] | list[int],
    add_cls_token: bool = False,
) -> torch.Tensor:
    """Create 3D sin/cos positional embeddings.

    Args:
        embed_dim (int):
            Embedding dimension.
        grid_size (tuple[int, int, int] | list[int]):
            The grid depth, height and width.
        add_cls_token (bool, *optional*, defaults to False):
            Whether or not to add a classification (CLS) token.

    Returns:
        (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
        (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
    """
    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def _get_1d_sincos_embed_from_grid_torch(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) - must be float dtype!
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]

    omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)

    return emb


def _init_weights(module: nn.Module) -> None:
    """Initialize the weights."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def _interpolate_pos_encoding(
    pos_embed: torch.Tensor,
    grid_size: tuple[int, int, int] | list[int],
    patch_size: tuple[int, int, int] | list[int],
    shape: tuple[int, int, int] | list[int],
    embed_dim: int,
) -> torch.Tensor:
    """_interpolate_pos_encoding.

    Adapted from:
    - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
    - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
    """
    t, h, w = shape
    t_patches = t // patch_size[0]
    h_patches = h // patch_size[1]
    w_patches = w // patch_size[2]

    if [t_patches, h_patches, w_patches] == grid_size:
        # No interpolation needed
        return pos_embed
    if t_patches != grid_size[0]:
        # Re-compute pos embedding to handle changed num_frames
        new_grid_size = (t_patches, *grid_size[1:])
        new_pos_embed = get_3d_sincos_pos_embed(
            pos_embed.shape[-1], new_grid_size, add_cls_token=True
        )
        new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
    else:
        new_grid_size = grid_size  # type: ignore
        new_pos_embed = pos_embed

    class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]

    patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(
        0, 3, 1, 2
    )

    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed,
        size=(h_patches, w_patches),
        mode="bicubic",
        align_corners=True,
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)

    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

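A quick shape check for the sin/cos helpers above (editor's illustration with hypothetical sizes). Note that get_3d_sincos_pos_embed returns a NumPy array despite the torch.Tensor annotation; callers in this file convert it with torch.from_numpy.

# --- shape check (editor's illustration, not part of prithvi.py) ---
grid = (1, 14, 14)  # (t, h, w) patch grid, e.g. 224 / 16 = 14 per side
pos = get_3d_sincos_pos_embed(1024, grid, add_cls_token=True)
assert pos.shape == (1 + 1 * 14 * 14, 1024)  # CLS row plus one row per patch

one_d = get_1d_sincos_pos_embed_from_grid(16, np.arange(10))
assert one_d.shape == (10, 16)  # sin half concatenated with cos half
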
class PatchEmbed(nn.Module):
    """3D version of timm.models.vision_transformer.PatchEmbed."""

    def __init__(
        self,
        input_size: tuple[int, int, int] = (1, 224, 224),
        patch_size: tuple[int, int, int] = (1, 16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module | None = None,
        flatten: bool = True,
        bias: bool = True,
    ) -> None:
        """Init."""
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
        assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward."""
        B, C, T, H, W = x.shape

        if (
            T / self.patch_size[0] % 1
            or H / self.patch_size[1] % 1
            or W / self.patch_size[2] % 1
        ):
            warnings.warn(
                f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
                f"The border will be ignored, add backbone_padding for pixel-wise tasks."
            )

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x

class TemporalEncoder(nn.Module):
    """TemporalEncoder."""

    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        """Init."""
        super().__init__()
        self.embed_dim = embed_dim
        self.year_embed_dim = embed_dim // 2
        self.julian_day_embed_dim = embed_dim - self.year_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer("scale", torch.ones(1))

    def forward(
        self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None
    ) -> torch.Tensor:
        """Forward.

        temporal_coords: year and day-of-year info with shape (B, T, 2).
        tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
            repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
        """
        shape = temporal_coords.shape[:2] + (-1,)  # B, T, -1

        year = _get_1d_sincos_embed_from_grid_torch(
            self.year_embed_dim, temporal_coords[:, :, 0].flatten()
        ).reshape(shape)
        julian_day = _get_1d_sincos_embed_from_grid_torch(
            self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
        ).reshape(shape)

        embedding = self.scale * torch.cat([year, julian_day], dim=-1)

        if tokens_per_frame is not None:
            embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)

        return embedding  # B, T*tokens_per_frame, embed_dim


class LocationEncoder(nn.Module):
    """LocationEncoder."""

    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        """Init."""
        super().__init__()
        self.embed_dim = embed_dim
        self.lat_embed_dim = embed_dim // 2
        self.lon_embed_dim = embed_dim - self.lat_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer("scale", torch.ones(1))

    def forward(self, location_coords: torch.Tensor) -> torch.Tensor:
        """location_coords: lat and lon info with shape (B, 2)."""
        shape = location_coords.shape[:1] + (1, -1)  # B, 1, -1

        lat = _get_1d_sincos_embed_from_grid_torch(
            self.lat_embed_dim, location_coords[:, 0].flatten()
        ).reshape(shape)
        lon = _get_1d_sincos_embed_from_grid_torch(
            self.lon_embed_dim, location_coords[:, 1].flatten()
        ).reshape(shape)

        embedding = self.scale * torch.cat([lat, lon], dim=-1)

        return embedding  # B, 1, embed_dim

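The two metadata encoders above map (year, day-of-year) and (lat, lon) pairs into the token embedding space. A small sketch with made-up coordinates (editor's illustration):

# --- metadata encoder sketch (editor's illustration, not part of prithvi.py) ---
temp_enc = TemporalEncoder(embed_dim=1024)
coords = torch.tensor([[[2020.0, 150.0]]])  # (B=1, T=1, 2): year, day of year
print(temp_enc(coords, tokens_per_frame=196).shape)  # torch.Size([1, 196, 1024])

loc_enc = LocationEncoder(embed_dim=1024)
latlon = torch.tensor([[47.6, -122.3]])  # (B=1, 2): lat, lon
print(loc_enc(latlon).shape)  # torch.Size([1, 1, 1024])
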
class PrithviViT(nn.Module):
    """Prithvi ViT Encoder."""

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
        drop_path: float = 0.0,
        **kwargs: Any,
    ) -> None:
        """Init."""
        super().__init__()

        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        # 3D patch embedding
        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding
        if self.temporal_encoding:
            assert patch_size[0] == 1, (
                f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
            )
            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_buffer(
            "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )

        # Transformer layers
        self.blocks = []
        for i in range(depth):
            self.blocks.append(
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                )
            )
        self.blocks = nn.ModuleList(self.blocks)

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """initialize_weights."""
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)

    def random_masking(
        self,
        sequence: torch.Tensor,
        mask_ratio: float,
        noise: None | torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence: (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            mask_ratio: (float): mask ratio to use.
            noise: (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - mask_ratio))

        if noise is None:
            noise = torch.rand(
                batch_size, seq_length, device=sequence.device
            )  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1).to(
            sequence.device
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(
            sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)
        )

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore

    def interpolate_pos_encoding(
        self, sample_shape: tuple[int, int, int] | list[int]
    ) -> torch.Tensor:
        """interpolate_pos_encoding."""
        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.pos_embed,
            grid_size=self.patch_embed.grid_size,
            patch_size=self.patch_embed.patch_size,
            shape=sample_shape,
            embed_dim=self.embed_dim,
        )
        return pos_embed

    def forward(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float = 0.75,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward."""
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(
                temporal_coords, num_tokens_per_frame
            )
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        """forward_features."""
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)
        sample_shape = x.shape[-3:]

        # embed patches
        x = self.patch_embed(x)

        pos_embed = self.interpolate_pos_encoding(sample_shape)
        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(
                temporal_coords, num_tokens_per_frame
            )
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        out = []
        for block in self.blocks:
            x = block(x)
            out.append(x.clone())

        x = self.norm(x)
        out[-1] = x
        return out

    def prepare_features_for_image_model(
        self, features: list[torch.Tensor], t: int
    ) -> list[torch.Tensor]:
        """prepare_features_for_image_model."""
        out = []
        for x in features:
            x_no_token = x[:, 1:, :]
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // t
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch t e h w",
                e=self.embed_dim,
                t=t,
                h=h,
            )
            # mean along the time dimension
            out.append(encoded.mean(dim=1))
        return out

class MAEDecoder(nn.Module):
    """Transformer Decoder used in the Prithvi MAE."""

    def __init__(
        self,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
        in_chans: int = 3,
        encoder_embed_dim: int = 1024,
        decoder_embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
    ) -> None:
        """Init."""
        super().__init__()

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_dim = decoder_embed_dim
        self.grid_size = grid_size
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)
        self.patch_size = patch_size
        self.num_frames = self.grid_size[0] * patch_size[0]
        num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding
        if self.temporal_encoding:
            self.temporal_embed_dec = TemporalEncoder(
                decoder_embed_dim, coords_scale_learn
            )
        if self.location_encoding:
            self.location_embed_dec = LocationEncoder(
                decoder_embed_dim, coords_scale_learn
            )

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.register_buffer(
            "decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)
        )

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim,
            patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
            bias=True,
        )

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """initialize_weights."""
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
        )

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(_init_weights)

    def interpolate_pos_encoding(
        self, sample_shape: tuple[int, int, int]
    ) -> torch.Tensor:
        """interpolate_pos_encoding."""
        pos_embed = _interpolate_pos_encoding(
            pos_embed=self.decoder_pos_embed,
            grid_size=self.grid_size,
            patch_size=self.patch_size,
            shape=sample_shape,
            embed_dim=self.decoder_embed_dim,
        )

        return pos_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        ids_restore: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        input_size: list[int] | None = None,
    ) -> torch.Tensor:
        """Forward."""
        # embed tokens
        x = self.decoder_embed(hidden_states)
        cls_token = x[:, :1, :]

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        x = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # unshuffle
        x = torch.gather(
            x,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device),
        )

        # add pos embed
        decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])  # type: ignore
        cls_token = cls_token + decoder_pos_embed[:, :1, :]
        x = x + decoder_pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_dec(
                temporal_coords, num_tokens_per_frame
            )
            # Add temporal encoding w/o cls token
            x = x + temporal_encoding
        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_dec(location_coords)
            # Add location encoding w/o cls token
            x = x + location_encoding

        # append cls token
        x = torch.cat([cls_token, x], dim=1)

        # apply Transformer layers (blocks)
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        # remove cls token
        pred = pred[:, 1:, :]

        return pred

class PrithviMAE(nn.Module):
    """Prithvi Masked Autoencoder."""

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 4,
        in_chans: int = 6,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        norm_pix_loss: bool = False,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
        drop_path: float = 0.0,
        mask_ratio: float = 0.75,
        **kwargs: Any,
    ):
        """Init."""
        super().__init__()

        self.encoder = PrithviViT(
            img_size=img_size,
            num_frames=num_frames,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
            drop_path=drop_path,
        )

        self.decoder = MAEDecoder(
            patch_size=patch_size,
            grid_size=self.encoder.patch_embed.grid_size,
            in_chans=in_chans,
            encoder_embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
        )

        self.mask_ratio = mask_ratio
        self.norm_pix_loss = norm_pix_loss
        self.out_channels = self.encoder.out_channels

    def patchify(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Patchify.

        Args:
            pixel_values: (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.

        Returns:
            torch.FloatTensor of shape
            `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Patchified pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        num_channels = self.encoder.in_chans

        # patchify
        patchified_pixel_values = rearrange(
            pixel_values,
            "b c (t s) (h p) (w q) -> b (t h w) (s p q c)",
            c=num_channels,
            s=patch_size_t,
            p=patch_size_h,
            q=patch_size_w,
        )

        return patchified_pixel_values

    def unpatchify(
        self,
        patchified_pixel_values: torch.Tensor,
        image_size: tuple[int, int] | None = None,
    ) -> torch.Tensor:
        """Unpatchify.

        Args:
            patchified_pixel_values: (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
                Patchified pixel values.
            image_size: (`tuple[int, int]`, *optional*):
                Original image size.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
                Pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        image_size = (
            to_2tuple(image_size) if image_size is not None else self.encoder.img_size
        )
        original_height, original_width = image_size
        num_patches_h = original_height // patch_size_h
        num_patches_w = original_width // patch_size_w
        num_channels = self.encoder.in_chans

        pixel_values = rearrange(
            patchified_pixel_values,
            "b (t h w) (s p q c) -> b c (t s) (h p) (w q)",
            c=num_channels,
            h=num_patches_h,
            w=num_patches_w,
            s=patch_size_t,
            p=patch_size_h,
            q=patch_size_w,
        )
        return pixel_values

    def forward_loss(
        self, pixel_values: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """forward_loss.

        Args:
            pixel_values: (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.
            pred: (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Predicted pixel values.
            mask: (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(
        self,
        pixel_values: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward."""
        if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
            # add time dim
            pixel_values = pixel_values.unsqueeze(2)

        mask_ratio = mask_ratio or self.mask_ratio
        latent, mask, ids_restore = self.encoder(
            pixel_values, temporal_coords, location_coords, mask_ratio
        )
        pred = self.decoder(
            latent,
            ids_restore,
            temporal_coords,
            location_coords,
            input_size=pixel_values.shape,
        )
        loss = self.forward_loss(pixel_values, pred, mask)
        return loss, pred, mask

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        """forward_features."""
        return self.encoder.forward_features(x, temporal_coords, location_coords)