rslearn 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. rslearn/models/anysat.py +207 -0
  2. rslearn/models/clay/clay.py +204 -0
  3. rslearn/models/clay/configs/metadata.yaml +295 -0
  4. rslearn/models/galileo/__init__.py +5 -0
  5. rslearn/models/galileo/galileo.py +517 -0
  6. rslearn/models/galileo/single_file_galileo.py +1672 -0
  7. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  8. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  9. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  10. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  11. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  12. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  13. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  14. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  15. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  16. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  17. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  18. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  19. rslearn/models/presto/presto.py +10 -7
  20. rslearn/models/prithvi.py +1046 -0
  21. rslearn/models/unet.py +17 -11
  22. rslearn/utils/geometry.py +61 -1
  23. rslearn/utils/vector_format.py +13 -10
  24. {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/METADATA +145 -15
  25. {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/RECORD +29 -10
  26. {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/WHEEL +0 -0
  27. {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/entry_points.txt +0 -0
  28. {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/licenses/LICENSE +0 -0
  29. {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/top_level.txt +0 -0
rslearn/models/prithvi.py
@@ -0,0 +1,1046 @@
+"""Prithvi V2."""
+
+import logging
+import tempfile
+import warnings
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+import yaml
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from timm.layers import to_2tuple
+from timm.models.vision_transformer import Block
+from torch.nn import functional as F
+from upath import UPath
+
+logger = logging.getLogger(__name__)
+
+
+# for Prithvi, true values are ["B02", "B03", "B04", "B05", "B06", "B07"]
+PRITHVI_MEAN = [
+    1087.0,
+    1342.0,
+    1433.0,
+    2734.0,
+    1958.0,
+    1363.0,
+]
+PRITHVI_STD = [
+    2248.0,
+    2179.0,
+    2178.0,
+    1850.0,
+    1242.0,
+    1049.0,
+]
+
+
+HF_HUB_ID = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M"
+
+
+class PrithviV2(nn.Module):
+    """An Rslearn wrapper for Prithvi 2.0."""
+
+    input_keys = ["sentinel2"]
+
+    def __init__(self, pretrained_path: str | UPath | None = None, num_frames: int = 1):
+        """Init.
+
+        Inputs:
+            pretrained_path: The folder in which to download the Prithvi config
+                and weights. If None, it downloads to a temporary folder.
+            num_frames: The number of input frames (timesteps). The model was trained
+                on 3, but examples use 1 when there is only a single timestamp (e.g.
+                https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/
+                example_landslide4sense.ipynb)
+
+        """
+        super().__init__()
+        if pretrained_path is None:
+            pretrained_path = UPath(
+                tempfile.gettempdir(), "rslearn_cache", "prithvi_v2"
+            )
+
+        if not (UPath(pretrained_path) / "config.json").exists():
+            _ = hf_hub_download(
+                local_dir=pretrained_path,
+                repo_id=HF_HUB_ID,
+                filename="config.json",
+                revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
+            )
+        with (UPath(pretrained_path) / "config.json").open("r") as f:
+            config = yaml.safe_load(f)["pretrained_cfg"]
+
+        config["num_frames"] = num_frames
+
+        self.model = PrithviMAE(**config)
+
+        if not (UPath(pretrained_path) / "Prithvi_EO_V2_300M.pt").exists():
+            _ = hf_hub_download(
+                local_dir=pretrained_path,
+                repo_id=HF_HUB_ID,
+                filename="Prithvi_EO_V2_300M.pt",
+                revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
+            )
+
+        state_dict = torch.load(
+            UPath(pretrained_path) / "Prithvi_EO_V2_300M.pt",
+            map_location="cpu",
+            weights_only=True,
+        )
+        # discard fixed pos_embedding weight, following
+        # https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M/blob/e4aabdc440c8ee703a749def8af5bf4700dee35b/inference.py#L362
+        for k in list(state_dict.keys()):
+            if "pos_embed" in k:
+                del state_dict[k]
+        self.model.load_state_dict(state_dict, strict=False)
+        self.image_resolution = config["img_size"]
+        self.bands = config["bands"]
+        # patch size is a list [t, h, w], where h == w
+        self.patch_size = config["patch_size"][-1]
+
+    def _resize_data(self, data: torch.Tensor) -> torch.Tensor:
+        """Process individual modality data.
+
+        Args:
+            data: Input tensor of shape [B, C, H, W]
+
+        Returns:
+            The input resized to [B, C, H', W'] (patch size or image resolution).
+        """
+        # Get original dimensions
+        original_height = data.shape[2]
+        new_height = self.patch_size if original_height == 1 else self.image_resolution
+        data = F.interpolate(
+            data,
+            size=(new_height, new_height),
+            mode="bilinear",
+            align_corners=False,
+        )
+        return data
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Compute feature maps from the Prithvi V2 backbone.
+
+        Inputs:
+            inputs: input dicts that must include the "sentinel2" key. Prithvi is
+                designed for HLS (Harmonized Landsat-Sentinel) data; this naming
+                keeps the model consistent with other rslearn models.
+
+        Returns:
+            One feature map per transformer block in the Prithvi model, each of
+            shape [B, D=1024, H/p_s, W/p_s] where p_s=16 is the patch size.
+        """
+        x = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+        x = self._resize_data(x)
+        num_timesteps = x.shape[1] // len(self.bands)
+        x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
+        features = self.model.encoder.forward_features(x)
+        # prepare_features_for_image_model was slightly modified since we already
+        # know the number of timesteps and don't need to recompute it.
+        # in addition we average along the time dimension (instead of concatenating)
+        # to keep the embeddings reasonably sized.
+        return self.model.encoder.prepare_features_for_image_model(
+            features, num_timesteps
+        )
+
+
+# Copyright (c) IBM Corp. 2024. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# transformers: https://github.com/huggingface/transformers
+# --------------------------------------------------------
+
+
+def get_3d_sincos_pos_embed(
+    embed_dim: int,
+    grid_size: tuple[int, int, int] | list[int],
+    add_cls_token: bool = False,
+) -> torch.Tensor:
+    """Create 3D sin/cos positional embeddings.
+
+    Args:
+        embed_dim (int):
+            Embedding dimension.
+        grid_size (tuple[int, int, int] | list[int]):
+            The grid depth, height and width.
+        add_cls_token (bool, *optional*, defaults to False):
+            Whether or not to add a classification (CLS) token.
+
+    Returns:
+        (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
+        (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
+    """
+    assert embed_dim % 16 == 0
+
+    t_size, h_size, w_size = grid_size
+
+    w_embed_dim = embed_dim // 16 * 6
+    h_embed_dim = embed_dim // 16 * 6
+    t_embed_dim = embed_dim // 16 * 4
+
+    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
+    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
+    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
+
+    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
+    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
+    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
+
+    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
+
+    if add_cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_1d_sincos_pos_embed_from_grid(
+    embed_dim: int, pos: torch.Tensor
+) -> torch.Tensor:
+    """embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)."""
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be even")
+
+    omega = np.arange(embed_dim // 2, dtype=float)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+def _get_1d_sincos_embed_from_grid_torch(
+    embed_dim: int, pos: torch.Tensor
+) -> torch.Tensor:
+    """Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.
+
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,) - must be float dtype!
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
+
+    omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+
+    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
+
+    return emb
+
+
+def _init_weights(module: nn.Module) -> None:
+    """Initialize the weights."""
+    if isinstance(module, nn.Linear):
+        nn.init.xavier_uniform_(module.weight)
+        if module.bias is not None:
+            module.bias.data.zero_()
+    elif isinstance(module, nn.LayerNorm):
+        module.bias.data.zero_()
+        module.weight.data.fill_(1.0)
+
+
+def _interpolate_pos_encoding(
+    pos_embed: torch.Tensor,
+    grid_size: tuple[int, int, int] | list[int],
+    patch_size: tuple[int, int, int] | list[int],
+    shape: tuple[int, int, int] | list[int],
+    embed_dim: int,
+) -> torch.Tensor:
+    """_interpolate_pos_encoding.
+
+    Adapted from:
+    - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
+    - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
+    """
+    t, h, w = shape
+    t_patches = t // patch_size[0]
+    h_patches = h // patch_size[1]
+    w_patches = w // patch_size[2]
+
+    if [t_patches, h_patches, w_patches] == grid_size:
+        # No interpolation needed
+        return pos_embed
+    if t_patches != grid_size[0]:
+        # Re-compute pos embedding to handle changed num_frames
+        new_grid_size = (t_patches, *grid_size[1:])
+        new_pos_embed = get_3d_sincos_pos_embed(
+            pos_embed.shape[-1], new_grid_size, add_cls_token=True
+        )
+        new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
+    else:
+        new_grid_size = grid_size  # type: ignore
+        new_pos_embed = pos_embed
+
+    class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]
+
+    patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(
+        0, 3, 1, 2
+    )
+
+    patch_pos_embed = nn.functional.interpolate(
+        patch_pos_embed,
+        size=(h_patches, w_patches),
+        mode="bicubic",
+        align_corners=True,
+    )
+    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
+
+    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+
+class PatchEmbed(nn.Module):
+    """3D version of timm.models.vision_transformer.PatchEmbed."""
+
+    def __init__(
+        self,
+        input_size: tuple[int, int, int] = (1, 224, 224),
+        patch_size: tuple[int, int, int] = (1, 16, 16),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        norm_layer: nn.Module | None = None,
+        flatten: bool = True,
+        bias: bool = True,
+    ) -> None:
+        """Init."""
+        super().__init__()
+        self.input_size = input_size
+        self.patch_size = patch_size
+        self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
+        assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
+        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
+        self.flatten = flatten
+
+        self.proj = nn.Conv3d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
+        )
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward."""
+        B, C, T, H, W = x.shape
+
+        if (
+            T / self.patch_size[0] % 1
+            or H / self.patch_size[1] % 1
+            or W / self.patch_size[2] % 1
+        ):
+            warnings.warn(
+                f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
+                f"The border will be ignored, add backbone_padding for pixel-wise tasks."
+            )
+
+        x = self.proj(x)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
+        x = self.norm(x)
+        return x
+
+
+class TemporalEncoder(nn.Module):
+    """TemporalEncoder."""
+
+    def __init__(self, embed_dim: int, trainable_scale: bool = False):
+        """Init."""
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.year_embed_dim = embed_dim // 2
+        self.julian_day_embed_dim = embed_dim - self.year_embed_dim
+
+        # If trainable, initialize scale with small number
+        if trainable_scale:
+            self.scale = nn.Parameter(torch.full((1,), 0.1))
+        else:
+            self.register_buffer("scale", torch.ones(1))
+
+    def forward(
+        self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None
+    ) -> torch.Tensor:
+        """Forward.
+
+        temporal_coords: year and day-of-year info with shape (B, T, 2).
+        tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
+            repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
+        """
+        shape = temporal_coords.shape[:2] + (-1,)  # B, T, -1
+
+        year = _get_1d_sincos_embed_from_grid_torch(
+            self.year_embed_dim, temporal_coords[:, :, 0].flatten()
+        ).reshape(shape)
+        julian_day = _get_1d_sincos_embed_from_grid_torch(
+            self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
+        ).reshape(shape)
+
+        embedding = self.scale * torch.cat([year, julian_day], dim=-1)
+
+        if tokens_per_frame is not None:
+            embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
+
+        return embedding  # B, T*tokens_per_frame, embed_dim
+
+
+class LocationEncoder(nn.Module):
+    """LocationEncoder."""
+
+    def __init__(self, embed_dim: int, trainable_scale: bool = False):
+        """Init."""
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.lat_embed_dim = embed_dim // 2
+        self.lon_embed_dim = embed_dim - self.lat_embed_dim
+
+        # If trainable, initialize scale with small number
+        if trainable_scale:
+            self.scale = nn.Parameter(torch.full((1,), 0.1))
+        else:
+            self.register_buffer("scale", torch.ones(1))
+
+    def forward(self, location_coords: torch.Tensor) -> torch.Tensor:
+        """location_coords: lat and lon info with shape (B, 2)."""
+        shape = location_coords.shape[:1] + (1, -1)  # B, 1, -1
+
+        lat = _get_1d_sincos_embed_from_grid_torch(
+            self.lat_embed_dim, location_coords[:, 0].flatten()
+        ).reshape(shape)
+        lon = _get_1d_sincos_embed_from_grid_torch(
+            self.lon_embed_dim, location_coords[:, 1].flatten()
+        ).reshape(shape)
+
+        embedding = self.scale * torch.cat([lat, lon], dim=-1)
+
+        return embedding  # B, 1, embed_dim
+
+
+class PrithviViT(nn.Module):
+    """Prithvi ViT Encoder."""
+
+    def __init__(
+        self,
+        img_size: int | tuple[int, int] = 224,
+        patch_size: int | tuple[int, int, int] = (1, 16, 16),
+        num_frames: int = 1,
+        in_chans: int = 3,
+        embed_dim: int = 1024,
+        depth: int = 24,
+        num_heads: int = 16,
+        mlp_ratio: float = 4.0,
+        norm_layer: nn.Module = nn.LayerNorm,
+        coords_encoding: list[str] | None = None,
+        coords_scale_learn: bool = False,
+        drop_path: float = 0.0,
+        **kwargs: Any,
+    ) -> None:
+        """Init."""
+        super().__init__()
+
+        self.in_chans = in_chans
+        self.num_frames = num_frames
+        self.embed_dim = embed_dim
+        self.img_size = to_2tuple(img_size)
+        if isinstance(patch_size, int):
+            patch_size = (1, patch_size, patch_size)
+
+        # 3D patch embedding
+        self.patch_embed = PatchEmbed(
+            input_size=(num_frames,) + self.img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+        self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth
+
+        # Optional temporal and location embedding
+        coords_encoding = coords_encoding or []
+        self.temporal_encoding = "time" in coords_encoding
+        self.location_encoding = "location" in coords_encoding
+        if self.temporal_encoding:
+            assert patch_size[0] == 1, (
+                f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
+            )
+            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
+        if self.location_encoding:
+            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.register_buffer(
+            "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
+        )
+
+        # Transformer layers
+        self.blocks = []
+        for i in range(depth):
+            self.blocks.append(
+                Block(
+                    embed_dim,
+                    num_heads,
+                    mlp_ratio,
+                    qkv_bias=True,
+                    norm_layer=norm_layer,
+                    drop_path=drop_path,
+                )
+            )
+        self.blocks = nn.ModuleList(self.blocks)
+
+        self.norm = norm_layer(embed_dim)
+
+        self.initialize_weights()
+
+    def initialize_weights(self) -> None:
+        """initialize_weights."""
+        # initialize (and freeze) position embeddings by sin-cos embedding
+        pos_embed = get_3d_sincos_pos_embed(
+            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
+        )
+        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
+        w = self.patch_embed.proj.weight.data
+        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+        torch.nn.init.normal_(self.cls_token, std=0.02)
+        self.apply(_init_weights)
+
+    def random_masking(
+        self,
+        sequence: torch.Tensor,
+        mask_ratio: float,
+        noise: None | torch.Tensor = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Perform per-sample random masking by per-sample shuffling.
+
+        Per-sample shuffling is done by argsort random
+        noise.
+
+        Args:
+            sequence: (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
+            mask_ratio: (float): mask ratio to use.
+            noise: (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
+                mainly used for testing purposes to control randomness and maintain the reproducibility
+        """
+        batch_size, seq_length, dim = sequence.shape
+        len_keep = int(seq_length * (1 - mask_ratio))
+
+        if noise is None:
+            noise = torch.rand(
+                batch_size, seq_length, device=sequence.device
+            )  # noise in [0, 1]
+
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(noise, dim=1).to(
+            sequence.device
+        )  # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
+
+        # keep the first subset
+        ids_keep = ids_shuffle[:, :len_keep]
+        sequence_unmasked = torch.gather(
+            sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)
+        )
+
+        # generate the binary mask: 0 is keep, 1 is remove
+        mask = torch.ones([batch_size, seq_length], device=sequence.device)
+        mask[:, :len_keep] = 0
+        # unshuffle to get the binary mask
+        mask = torch.gather(mask, dim=1, index=ids_restore)
+
+        return sequence_unmasked, mask, ids_restore
+
+    def interpolate_pos_encoding(
+        self, sample_shape: tuple[int, int, int] | list[int]
+    ) -> torch.Tensor:
+        """interpolate_pos_encoding."""
+        pos_embed = _interpolate_pos_encoding(
+            pos_embed=self.pos_embed,
+            grid_size=self.patch_embed.grid_size,
+            patch_size=self.patch_embed.patch_size,
+            shape=sample_shape,
+            embed_dim=self.embed_dim,
+        )
+        return pos_embed
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        temporal_coords: None | torch.Tensor = None,
+        location_coords: None | torch.Tensor = None,
+        mask_ratio: float = 0.75,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward."""
+        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
+            # add time dim
+            x = x.unsqueeze(2)
+        sample_shape = x.shape[-3:]
+
+        # embed patches
+        x = self.patch_embed(x)
+
+        pos_embed = self.interpolate_pos_encoding(sample_shape)
+        # add pos embed w/o cls token
+        x = x + pos_embed[:, 1:, :]
+
+        if self.temporal_encoding and temporal_coords is not None:
+            num_tokens_per_frame = x.shape[1] // self.num_frames
+            temporal_encoding = self.temporal_embed_enc(
+                temporal_coords, num_tokens_per_frame
+            )
+            x = x + temporal_encoding
+        if self.location_encoding and location_coords is not None:
+            location_encoding = self.location_embed_enc(location_coords)
+            x = x + location_encoding
+
+        # masking: length -> length * mask_ratio
+        x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+        # append cls token
+        cls_token = self.cls_token + pos_embed[:, :1, :]
+        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # apply Transformer blocks
+        for block in self.blocks:
+            x = block(x)
+        x = self.norm(x)
+
+        return x, mask, ids_restore
+
+    def forward_features(
+        self,
+        x: torch.Tensor,
+        temporal_coords: None | torch.Tensor = None,
+        location_coords: None | torch.Tensor = None,
+    ) -> list[torch.Tensor]:
+        """forward_features."""
+        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
+            # add time dim
+            x = x.unsqueeze(2)
+        sample_shape = x.shape[-3:]
+
+        # embed patches
+        x = self.patch_embed(x)
+
+        pos_embed = self.interpolate_pos_encoding(sample_shape)
+        # add pos embed w/o cls token
+        x = x + pos_embed[:, 1:, :]
+
+        if self.temporal_encoding and temporal_coords is not None:
+            num_tokens_per_frame = x.shape[1] // self.num_frames
+            temporal_encoding = self.temporal_embed_enc(
+                temporal_coords, num_tokens_per_frame
+            )
+            x = x + temporal_encoding
+        if self.location_encoding and location_coords is not None:
+            location_encoding = self.location_embed_enc(location_coords)
+            x = x + location_encoding
+
+        # append cls token
+        cls_token = self.cls_token + pos_embed[:, :1, :]
+        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # apply Transformer blocks
+        out = []
+        for block in self.blocks:
+            x = block(x)
+            out.append(x.clone())
+
+        x = self.norm(x)
+        out[-1] = x
+        return out
+
+    def prepare_features_for_image_model(
+        self, features: list[torch.Tensor], t: int
+    ) -> list[torch.Tensor]:
+        """prepare_features_for_image_model."""
+        out = []
+        for x in features:
+            x_no_token = x[:, 1:, :]
+            number_of_tokens = x_no_token.shape[1]
+            tokens_per_timestep = number_of_tokens // t
+            h = int(np.sqrt(tokens_per_timestep))
+            encoded = rearrange(
+                x_no_token,
+                "batch (t h w) e -> batch t e h w",
+                e=self.embed_dim,
+                t=t,
+                h=h,
+            )
+            # mean along the time dimension
+            out.append(encoded.mean(dim=1))
+        return out
+
+
+class MAEDecoder(nn.Module):
+    """Transformer Decoder used in the Prithvi MAE."""
+
+    def __init__(
+        self,
+        patch_size: int | tuple[int, int, int] = (1, 16, 16),
+        grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
+        in_chans: int = 3,
+        encoder_embed_dim: int = 1024,
+        decoder_embed_dim: int = 512,
+        depth: int = 8,
+        num_heads: int = 16,
+        mlp_ratio: float = 4.0,
+        norm_layer: nn.Module = nn.LayerNorm,
+        coords_encoding: list[str] | None = None,
+        coords_scale_learn: bool = False,
+    ) -> None:
+        """Init."""
+        super().__init__()
+
+        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
+        self.decoder_embed_dim = decoder_embed_dim
+        self.grid_size = grid_size
+        if isinstance(patch_size, int):
+            patch_size = (1, patch_size, patch_size)
+        self.patch_size = patch_size
+        self.num_frames = self.grid_size[0] * patch_size[0]
+        num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
+
+        # Optional temporal and location embedding
+        coords_encoding = coords_encoding or []
+        self.temporal_encoding = "time" in coords_encoding
+        self.location_encoding = "location" in coords_encoding
+        if self.temporal_encoding:
+            self.temporal_embed_dec = TemporalEncoder(
+                decoder_embed_dim, coords_scale_learn
+            )
+        if self.location_encoding:
+            self.location_embed_dec = LocationEncoder(
+                decoder_embed_dim, coords_scale_learn
+            )
+
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+
+        self.register_buffer(
+            "decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)
+        )
+
+        self.decoder_blocks = nn.ModuleList(
+            [
+                Block(
+                    decoder_embed_dim,
+                    num_heads,
+                    mlp_ratio,
+                    qkv_bias=True,
+                    norm_layer=norm_layer,
+                )
+                for _ in range(depth)
+            ]
+        )
+
+        self.decoder_norm = norm_layer(decoder_embed_dim)
+        self.decoder_pred = nn.Linear(
+            decoder_embed_dim,
+            patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
+            bias=True,
+        )
+
+        self.initialize_weights()
+
+    def initialize_weights(self) -> None:
+        """initialize_weights."""
+        # initialize (and freeze) position embeddings by sin-cos embedding
+        decoder_pos_embed = get_3d_sincos_pos_embed(
+            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
+        )
+        self.decoder_pos_embed.data.copy_(
+            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
+        )
+
+        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+        torch.nn.init.normal_(self.mask_token, std=0.02)
+        self.apply(_init_weights)
+
+    def interpolate_pos_encoding(
+        self, sample_shape: tuple[int, int, int]
+    ) -> torch.Tensor:
+        """interpolate_pos_encoding."""
+        pos_embed = _interpolate_pos_encoding(
+            pos_embed=self.decoder_pos_embed,
+            grid_size=self.grid_size,
+            patch_size=self.patch_size,
+            shape=sample_shape,
+            embed_dim=self.decoder_embed_dim,
+        )
+
+        return pos_embed
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        ids_restore: torch.Tensor,
+        temporal_coords: None | torch.Tensor = None,
+        location_coords: None | torch.Tensor = None,
+        input_size: list[int] | None = None,
+    ) -> torch.Tensor:
+        """Forward."""
+        # embed tokens
+        x = self.decoder_embed(hidden_states)
+        cls_token = x[:, :1, :]
+
+        # append mask tokens to sequence
+        mask_tokens = self.mask_token.repeat(
+            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
+        )
+        x = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
+        # unshuffle
+        x = torch.gather(
+            x,
+            dim=1,
+            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device),
+        )
+
+        # add pos embed
+        decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])  # type: ignore
+        cls_token = cls_token + decoder_pos_embed[:, :1, :]
+        x = x + decoder_pos_embed[:, 1:, :]
+
+        if self.temporal_encoding and temporal_coords is not None:
+            num_tokens_per_frame = x.shape[1] // self.num_frames
+            temporal_encoding = self.temporal_embed_dec(
+                temporal_coords, num_tokens_per_frame
+            )
+            # Add temporal encoding w/o cls token
+            x = x + temporal_encoding
+        if self.location_encoding and location_coords is not None:
+            location_encoding = self.location_embed_dec(location_coords)
+            # Add location encoding w/o cls token
+            x = x + location_encoding
+
+        # append cls token
+        x = torch.cat([cls_token, x], dim=1)
+
+        # apply Transformer layers (blocks)
+        for block in self.decoder_blocks:
+            x = block(x)
+        x = self.decoder_norm(x)
+
+        # predictor projection
+        pred = self.decoder_pred(x)
+
+        # remove cls token
+        pred = pred[:, 1:, :]
+
+        return pred
+
+
+class PrithviMAE(nn.Module):
+    """Prithvi Masked Autoencoder."""
+
+    def __init__(
+        self,
+        img_size: int | tuple[int, int] = 224,
+        patch_size: int | tuple[int, int, int] = (1, 16, 16),
+        num_frames: int = 4,
+        in_chans: int = 6,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        decoder_embed_dim: int = 512,
+        decoder_depth: int = 8,
+        decoder_num_heads: int = 16,
+        mlp_ratio: float = 4.0,
+        norm_layer: nn.Module = nn.LayerNorm,
+        norm_pix_loss: bool = False,
+        coords_encoding: list[str] | None = None,
+        coords_scale_learn: bool = False,
+        drop_path: float = 0.0,
+        mask_ratio: float = 0.75,
+        **kwargs: Any,
+    ):
+        """Init."""
+        super().__init__()
+
+        self.encoder = PrithviViT(
+            img_size=img_size,
+            num_frames=num_frames,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            depth=depth,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            norm_layer=norm_layer,
+            coords_encoding=coords_encoding,
+            coords_scale_learn=coords_scale_learn,
+            drop_path=drop_path,
+        )
+
+        self.decoder = MAEDecoder(
+            patch_size=patch_size,
+            grid_size=self.encoder.patch_embed.grid_size,
+            in_chans=in_chans,
+            encoder_embed_dim=embed_dim,
+            decoder_embed_dim=decoder_embed_dim,
+            depth=decoder_depth,
+            num_heads=decoder_num_heads,
+            mlp_ratio=mlp_ratio,
+            norm_layer=norm_layer,
+            coords_encoding=coords_encoding,
+            coords_scale_learn=coords_scale_learn,
+        )
+
+        self.mask_ratio = mask_ratio
+        self.norm_pix_loss = norm_pix_loss
+        self.out_channels = self.encoder.out_channels
+
+    def patchify(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        """Patchify.
+
+        Args:
+            pixel_values: (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
+                Pixel values.
+
+        Returns:
+            torch.FloatTensor of shape
+            `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
+            Patchified pixel values.
+        """
+        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
+        num_channels = self.encoder.in_chans
+
+        # patchify
+        patchified_pixel_values = rearrange(
+            pixel_values,
+            "b c (t s) (h p) (w q) -> b (t h w) (s p q c)",
+            c=num_channels,
+            s=patch_size_t,
+            p=patch_size_h,
+            q=patch_size_w,
+        )
+
+        return patchified_pixel_values
+
+    def unpatchify(
+        self,
+        patchified_pixel_values: torch.Tensor,
+        image_size: tuple[int, int] | None = None,
+    ) -> torch.Tensor:
+        """Unpatchify.
+
+        Args:
+            patchified_pixel_values: (`torch.FloatTensor` of shape
+                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
+                Patchified pixel values.
+            image_size: (`tuple[int, int]`, *optional*):
+                Original image size.
+
+        Returns:
+            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
+            Pixel values.
+        """
+        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
+        image_size = (
+            to_2tuple(image_size) if image_size is not None else self.encoder.img_size
+        )
+        original_height, original_width = image_size
+        num_patches_h = original_height // patch_size_h
+        num_patches_w = original_width // patch_size_w
+        num_channels = self.encoder.in_chans
+
+        pixel_values = rearrange(
+            patchified_pixel_values,
+            "b (t h w) (s p q c) -> b c (t s) (h p) (w q)",
+            c=num_channels,
+            h=num_patches_h,
+            w=num_patches_w,
+            s=patch_size_t,
+            p=patch_size_h,
+            q=patch_size_w,
+        )
+        return pixel_values
+
+    def forward_loss(
+        self, pixel_values: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """forward_loss.
+
+        Args:
+            pixel_values: (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
+                Pixel values.
+            pred: (`torch.FloatTensor` of shape
+                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
+                Predicted pixel values.
+            mask: (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+                Tensor indicating which patches are masked (1) and which are not (0).
+
+        Returns:
+            `torch.FloatTensor`: Pixel reconstruction loss.
+        """
+        target = self.patchify(pixel_values)
+        if self.norm_pix_loss:
+            mean = target.mean(dim=-1, keepdim=True)
+            var = target.var(dim=-1, keepdim=True)
+            target = (target - mean) / (var + 1.0e-6) ** 0.5
+
+        loss = (pred - target) ** 2
+        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
+        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+        return loss
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        temporal_coords: None | torch.Tensor = None,
+        location_coords: None | torch.Tensor = None,
+        mask_ratio: float | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward."""
+        if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
+            # add time dim
+            pixel_values = pixel_values.unsqueeze(2)
+
+        mask_ratio = mask_ratio or self.mask_ratio
+        latent, mask, ids_restore = self.encoder(
+            pixel_values, temporal_coords, location_coords, mask_ratio
+        )
+        pred = self.decoder(
+            latent,
+            ids_restore,
+            temporal_coords,
+            location_coords,
+            input_size=pixel_values.shape,
+        )
+        loss = self.forward_loss(pixel_values, pred, mask)
+        return loss, pred, mask
+
+    def forward_features(
+        self,
+        x: torch.Tensor,
+        temporal_coords: None | torch.Tensor = None,
+        location_coords: None | torch.Tensor = None,
+    ) -> list[torch.Tensor]:
+        """forward_features."""
+        return self.encoder.forward_features(x, temporal_coords, location_coords)
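
For orientation, below is a minimal usage sketch of the new PrithviV2 wrapper; it is not part of the packaged diff above. It assumes network access to the Hugging Face Hub for the one-time config and weight download, and it substitutes a random 6-band tensor for real HLS imagery; the 224-pixel size, variable names, and the expected output shape are illustrative assumptions based on the code shown here, not guarantees from the package.

import torch

from rslearn.models.prithvi import PrithviV2

# Single-timestep wrapper; config and weights download on first use
# (assumption: the Hugging Face Hub is reachable).
model = PrithviV2(num_frames=1)
model.eval()

# One sample whose "sentinel2" tensor stacks t * 6 HLS-style bands along the
# channel axis; here t = 1, so 6 channels.
inputs = [{"sentinel2": torch.randn(6, 224, 224)}]

with torch.no_grad():
    feature_maps = model(inputs)

# Expect one feature map per transformer block, each roughly [1, 1024, 14, 14]
# assuming the default 224-pixel image resolution and 16-pixel patches.
print(len(feature_maps), feature_maps[0].shape)

Note that PRITHVI_MEAN and PRITHVI_STD are defined in the module but not applied inside forward, so per-band normalization presumably happens elsewhere in the caller's data pipeline.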