rslearn 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,216 @@
1
+ # type: ignore
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ # --------------------------------------------------------
8
+ # Position embedding utils
9
+ # --------------------------------------------------------
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
+ # --------------------------------------------------------
16
+ # 2D sine-cosine position embedding
17
+ # References:
18
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
19
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
20
+ # --------------------------------------------------------
21
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
22
+ """grid_size: int of the grid height and width
23
+ return:
24
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25
+ """
26
+ grid_h = np.arange(grid_size, dtype=np.float32)
27
+ grid_w = np.arange(grid_size, dtype=np.float32)
28
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
29
+ grid = np.stack(grid, axis=0)
30
+
31
+ grid = grid.reshape([2, 1, grid_size, grid_size])
32
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33
+ if cls_token:
34
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35
+ return pos_embed
36
+
37
+
38
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39
+ assert embed_dim % 2 == 0
40
+
41
+ # use half of dimensions to encode grid_h
42
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44
+
45
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46
+ return emb
47
+
48
+
49
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50
+ """embed_dim: output dimension for each position
51
+ pos: a list of positions to be encoded: size (M,)
52
+ out: (M, D)
53
+ """
54
+ assert embed_dim % 2 == 0
55
+ # omega = np.arange(embed_dim // 2, dtype=np.float) # numpy deprecated in 1.20
56
+ omega = np.arange(embed_dim // 2, dtype=float)
57
+
58
+ omega /= embed_dim / 2.0
59
+ omega = 1.0 / 10000**omega # (D/2,)
60
+
61
+ pos = pos.reshape(-1) # (M,)
62
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
63
+
64
+ emb_sin = np.sin(out) # (M, D/2)
65
+ emb_cos = np.cos(out) # (M, D/2)
66
+
67
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
68
+ return emb
69
+
70
+
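For reference, `get_1d_sincos_pos_embed_from_grid` computes the standard Transformer sinusoidal encoding: for position m and frequency index i, the first D/2 columns are sin(m / 10000^(2i/D)) and the last D/2 columns are cos(m / 10000^(2i/D)). A minimal sanity-check sketch, not part of the package (the commented import path is an assumption about where this module lives):

```python
import numpy as np

# Assumed import path for the module added in this diff; adjust to the actual layout.
# from rslearn.models.pos_embed import get_1d_sincos_pos_embed_from_grid


def reference_1d_sincos(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """Direct transcription of the sinusoidal formula, for comparison."""
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2))
    angles = pos.reshape(-1)[:, None] * omega[None, :]  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


emb = reference_1d_sincos(8, np.arange(4, dtype=float))
assert emb.shape == (4, 8)
# np.allclose(emb, get_1d_sincos_pos_embed_from_grid(8, np.arange(4, dtype=float)))
# should hold if the import above is enabled.
```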
71
+ # --------------------------------------------------------
72
+ # Interpolate position embeddings for high-resolution
73
+ # References:
74
+ # DeiT: https://github.com/facebookresearch/deit
75
+ # --------------------------------------------------------
76
+ def interpolate_pos_embed(model, checkpoint_model):
77
+ if "pos_embed" in checkpoint_model:
78
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
79
+ embedding_size = pos_embed_checkpoint.shape[-1]
80
+ num_patches = model.patch_embed.num_patches
81
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
82
+ # height (== width) for the checkpoint position embedding
83
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
84
+ # height (== width) for the new position embedding
85
+ new_size = int(num_patches**0.5)
86
+ # class_token and dist_token are kept unchanged
87
+ if orig_size != new_size:
88
+ print(
89
+ "Position interpolate from %dx%d to %dx%d"
90
+ % (orig_size, orig_size, new_size, new_size)
91
+ )
92
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
93
+ # only the position tokens are interpolated
94
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
95
+ pos_tokens = pos_tokens.reshape(
96
+ -1, orig_size, orig_size, embedding_size
97
+ ).permute(0, 3, 1, 2)
98
+ pos_tokens = torch.nn.functional.interpolate(
99
+ pos_tokens,
100
+ size=(new_size, new_size),
101
+ mode="bicubic",
102
+ align_corners=False,
103
+ )
104
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
105
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
106
+ checkpoint_model["pos_embed"] = new_pos_embed
107
+
108
+
109
+ def interpolate_pos_embed_ofa(model, checkpoint_model):
110
+ if "pos_embed" in checkpoint_model:
111
+ pos_embed_dict = checkpoint_model["pos_embed"]
112
+
113
+ for key, pos_embed in pos_embed_dict.items():
114
+ pos_embed_checkpoint = pos_embed
115
+ embedding_size = pos_embed_checkpoint.shape[-1]
116
+ num_patches = model.patch_embed[key].num_patches
117
+ num_extra_tokens = model.pos_embed[key].shape[-2] - num_patches
118
+ # height (== width) for the checkpoint position embedding
119
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
120
+ # height (== width) for the new position embedding
121
+ new_size = int(num_patches**0.5)
122
+ # class_token and dist_token are kept unchanged
123
+ if orig_size != new_size:
124
+ print(
125
+ "Position interpolate from %dx%d to %dx%d"
126
+ % (orig_size, orig_size, new_size, new_size)
127
+ )
128
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
129
+ # only the position tokens are interpolated
130
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
131
+ pos_tokens = pos_tokens.reshape(
132
+ -1, orig_size, orig_size, embedding_size
133
+ ).permute(0, 3, 1, 2)
134
+ pos_tokens = torch.nn.functional.interpolate(
135
+ pos_tokens,
136
+ size=(new_size, new_size),
137
+ mode="bicubic",
138
+ align_corners=False,
139
+ )
140
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
141
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
142
+ checkpoint_model["pos_embed"][key] = new_pos_embed
143
+
144
+
145
+ def get_2d_sincos_pos_embed_with_resolution(
146
+ embed_dim, grid_size, res, cls_token=False, device="cpu"
147
+ ):
148
+ """grid_size: int of the grid height and width
149
+ res: tensor of size n, giving the resolution of a pixel (e.g., in meters)
150
+
151
+ Return:
152
+ pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
153
+ """
154
+ # res = torch.FloatTensor(res).to(device)
155
+ res = res.to(device)
156
+ grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
157
+ grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
158
+ grid = torch.meshgrid(
159
+ grid_w, grid_h, indexing="xy"
160
+ ) # w goes first, as in the numpy version above; indexing="xy" matches np.meshgrid's default
161
+ grid = torch.stack(grid, dim=0) # 2 x h x w
162
+
163
+ # grid = grid.reshape([2, 1, grid_size, grid_size])
164
+ grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w
165
+ _, n, h, w = grid.shape
166
+ pos_embed = get_2d_sincos_pos_embed_from_grid_torch(
167
+ embed_dim, grid
168
+ ) # (n*H*W, D)
169
+ pos_embed = pos_embed.reshape(n, h * w, embed_dim)
170
+ if cls_token:
171
+ pos_embed = torch.cat(
172
+ [
173
+ torch.zeros(
174
+ [n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device
175
+ ),
176
+ pos_embed,
177
+ ],
178
+ dim=1,
179
+ )
180
+ return pos_embed
181
+
182
+
183
+ def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
184
+ assert embed_dim % 2 == 0
185
+
186
+ # use half of dimensions to encode grid_h
187
+ emb_h = get_1d_sincos_pos_embed_from_grid_torch(
188
+ embed_dim // 2, grid[0]
189
+ ) # (H*W, D/2)
190
+ emb_w = get_1d_sincos_pos_embed_from_grid_torch(
191
+ embed_dim // 2, grid[1]
192
+ ) # (H*W, D/2)
193
+
194
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D)
195
+ return emb
196
+
197
+
198
+ def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
199
+ """embed_dim: output dimension for each position
200
+ pos: a list of positions to be encoded: size (M,)
201
+ out: (M, D)
202
+ """
203
+ assert embed_dim % 2 == 0
204
+ old_shape = pos  # NOTE: unused
205
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
206
+ omega /= embed_dim / 2.0
207
+ omega = 1.0 / 10000**omega # (D/2,)
208
+
209
+ pos = pos.reshape(-1) # (M,)
210
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
211
+
212
+ emb_sin = torch.sin(out) # (M, D/2)
213
+ emb_cos = torch.cos(out) # (M, D/2)
214
+
215
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
216
+ return emb
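The resolution-aware variant above scales the grid coordinates by a per-sample ground resolution before applying the same sine-cosine encoding, producing one embedding table per sample. A hedged usage sketch (the module path is an assumption; adjust it to where this file is installed in the package):

```python
import torch

# Assumed module path for the file added in this diff.
from rslearn.models.pos_embed import get_2d_sincos_pos_embed_with_resolution

embed_dim, grid_size = 768, 14
# One ground-resolution value per sample, e.g. meters per pixel.
res = torch.tensor([10.0, 20.0])

pos = get_2d_sincos_pos_embed_with_resolution(
    embed_dim, grid_size, res, cls_token=True
)
# One table per sample: (n, 1 + grid_size * grid_size, embed_dim).
assert pos.shape == (2, 1 + grid_size * grid_size, embed_dim)
```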
@@ -0,0 +1,167 @@
1
+ """Wrapper for the Panopticon model."""
2
+
3
+ import math
4
+ from enum import StrEnum
5
+ from importlib import resources
6
+ from typing import Any
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import yaml
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+ from rslearn.log_utils import get_logger
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
+ class PanopticonModalities(StrEnum):
20
+ """Modalities supported by Panopticon.
21
+
22
+ These are the keys used to load the YAML files from panopticon_data/sensors.
23
+ """
24
+
25
+ SENTINEL2 = "sentinel2"
26
+ LANDSAT8 = "landsat8"
27
+ SENTINEL1 = "sentinel1"
28
+ # Add more modalities as needed
29
+
30
+
31
+ class Panopticon(nn.Module):
32
+ """Class containing the Panopticon model that can ingest MaskedHeliosSample objects."""
33
+
34
+ patch_size: int = 14
35
+ base_image_size: int = 224
36
+
37
+ def __init__(
38
+ self,
39
+ band_order: dict[str, list[str]],
40
+ torchhub_id: str = "panopticon_vitb14",
41
+ ):
42
+ """Initialize the Panopticon wrapper.
43
+
44
+ Args:
45
+ band_order: The per-modality band order for the Panopticon model; it must match the order specified in the data config.
46
+ torchhub_id: The torch hub model ID for panopticon
47
+ """
48
+ super().__init__()
49
+ # Load the panopticon model
50
+ self._load_model(torchhub_id)
51
+ self.output_dim = self.model.embed_dim
52
+ self.band_order = band_order
53
+ self.supported_modalities = list(band_order.keys())
54
+
55
+ def _load_model(self, torchhub_id: str) -> None:
56
+ """Load the panopticon model from torch hub."""
57
+ import time
58
+
59
+ # Hack to get around https://discuss.pytorch.org/t/torch-hub-load-gives-httperror-rate-limit-exceeded/124769
60
+ torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
61
+ for attempt in range(2):
62
+ try:
63
+ self.model = torch.hub.load( # nosec B614
64
+ "panopticon-FM/panopticon",
65
+ torchhub_id,
66
+ )
67
+ break
68
+ except Exception as e:
69
+ logger.warning(
70
+ f"Error loading panopticon model: {e}. Retrying in 5 seconds..."
71
+ )
72
+ time.sleep(5)
73
+ else:
74
+ raise RuntimeError(
75
+ f"Failed to load panopticon model {torchhub_id} after retrying."
76
+ )
77
+
78
+ def _process_modality_data(self, data: torch.Tensor) -> torch.Tensor:
79
+ """Process individual modality data.
80
+
81
+ Args:
82
+ data: Input tensor of shape [B, C, H, W]
83
+
84
+ Returns:
85
+ Tensor resized to the model input size: [B, C, patch_size, patch_size] for 1x1 inputs, otherwise [B, C, base_image_size, base_image_size].
86
+ """
87
+ original_height = data.shape[2]
88
+ new_height = self.patch_size if original_height == 1 else self.base_image_size
89
+
90
+ data = F.interpolate(
91
+ data,
92
+ size=(new_height, new_height),
93
+ mode="bilinear",
94
+ align_corners=False,
95
+ )
96
+ return data
97
+
98
+ def _create_channel_ids(
99
+ self, modality: str, batch_size: int, device: torch.device
100
+ ) -> torch.Tensor:
101
+ """Create channel IDs for the panopticon model."""
102
+ with resources.open_text(
103
+ "rslearn.models.panopticon_data.sensors", f"{modality}.yaml"
104
+ ) as f:
105
+ sensor_config = yaml.safe_load(f)
106
+
107
+ band_order = self.band_order[modality]
108
+ chn_ids = [
109
+ sensor_config["bands"][band.upper()]["gaussian"]["mu"]
110
+ for band in band_order
111
+ ]
112
+ chn_ids = torch.tensor(chn_ids, dtype=torch.float32, device=device)
113
+ chn_ids = repeat(chn_ids, "c -> b c", b=batch_size)
114
+ return chn_ids
115
+
116
+ def prepare_input(
117
+ self, input_data: dict[str, torch.Tensor]
118
+ ) -> dict[str, torch.Tensor]:
119
+ """Prepare input for the Panopticon model from a dict of per-modality [B, C, H, W] tensors."""
120
+ channel_ids_list: list[torch.Tensor] = []
121
+ processed_data_list: list[torch.Tensor] = []
122
+ for modality in self.supported_modalities:
123
+ if modality not in input_data.keys():
124
+ logger.debug(f"Modality {modality} not found in input data")
125
+ continue
126
+ data = input_data[modality]
127
+ device = data.device
128
+ processed_data = self._process_modality_data(data)
129
+ processed_data_list.append(processed_data)
130
+ batch_size = processed_data.shape[0]
131
+ chn_ids = self._create_channel_ids(modality, batch_size, device)
132
+ channel_ids_list.append(chn_ids)
133
+
134
+ processed_data = torch.cat(processed_data_list, dim=1)
135
+ chn_ids = torch.cat(channel_ids_list, dim=1)
136
+ return {
137
+ "imgs": processed_data,
138
+ "chn_ids": chn_ids,
139
+ }
140
+
141
+ def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
142
+ """Forward pass through the panopticon model."""
143
+ batch_inputs = {
144
+ key: torch.stack([inp[key] for inp in inputs], dim=0)
145
+ for key in inputs[0].keys()
146
+ }
147
+ panopticon_inputs = self.prepare_input(batch_inputs)
148
+ output_features = self.model.forward_features(panopticon_inputs)[
149
+ "x_norm_patchtokens"
150
+ ]
151
+
152
+ num_tokens = output_features.shape[1]
153
+ height = int(math.sqrt(num_tokens))
154
+ output_features = rearrange(
155
+ output_features, "b (h w) d -> b d h w", h=height, w=height
156
+ )
157
+ return [output_features]
158
+
159
+ def get_backbone_channels(self) -> list:
160
+ """Returns the output channels of this model when used as a backbone.
161
+
162
+ The output channels are a list of (downsample_factor, depth) tuples corresponding
163
+ to the feature maps that the backbone returns. For example, an element [2, 32]
164
+ indicates that the corresponding feature map is 1/2 the input resolution and
165
+ has 32 channels.
166
+ """
167
+ return [(self.patch_size, self.output_dim)]
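A hedged usage sketch for the wrapper (the import path and band names below are assumptions; band names must match the keys under `bands` in the bundled sensor YAML, and the weights are downloaded from torch hub on first construction):

```python
import torch

# Assumed import path for the module added in this diff.
from rslearn.models.panopticon import Panopticon

# Band names are hypothetical and must match panopticon_data/sensors/sentinel2.yaml.
model = Panopticon(band_order={"sentinel2": ["B04", "B03", "B02"]})
model.eval()

# forward takes a list of per-sample dicts; each value is a [C, H, W] tensor,
# and samples are stacked into a batch internally.
inputs = [{"sentinel2": torch.randn(3, 224, 224)}]
with torch.no_grad():
    feats = model(inputs)[0]

# Patch features rearranged to [B, embed_dim, H/14, W/14].
print(feats.shape)
```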
@@ -0,0 +1,5 @@
1
+ """Presto."""
2
+
3
+ from .presto import Presto
4
+
5
+ __all__ = ["Presto"]
@@ -0,0 +1,247 @@
1
+ """Presto wrapper to ingest Masked Helios Samples."""
2
+
3
+ import logging
4
+ import tempfile
5
+ from typing import Any
6
+
7
+ import torch
8
+ from einops import rearrange, repeat
9
+ from huggingface_hub import hf_hub_download
10
+ from torch import nn
11
+ from upath import UPath
12
+
13
+ from rslearn.models.presto.single_file_presto import (
14
+ ERA5_BANDS,
15
+ NUM_DYNAMIC_WORLD_CLASSES,
16
+ PRESTO_ADD_BY,
17
+ PRESTO_BANDS,
18
+ PRESTO_DIV_BY,
19
+ PRESTO_S1_BANDS,
20
+ PRESTO_S2_BANDS,
21
+ SRTM_BANDS,
22
+ )
23
+ from rslearn.models.presto.single_file_presto import Presto as SFPresto
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ INPUT_PRESTO_BANDS = [b for b in PRESTO_BANDS if b != "B09"]
28
+ INPUT_PRESTO_S2_BANDS = [b for b in PRESTO_S2_BANDS if b != "B09"]
29
+
30
+ PRESTO_S1_SUBTRACT_VALUE = -25.0
31
+ PRESTO_S1_DIV_VALUE = 25.0
32
+ PRESTO_S2_SUBTRACT_VALUE = 0.0
33
+ PRESTO_S2_DIV_VALUE = 1e4
34
+
35
+ HF_HUB_ID = "nasaharvest/presto"
36
+ MODEL_FILENAME = "default_model.pt"
37
+
38
+
39
+ class Presto(nn.Module):
40
+ """Presto."""
41
+
42
+ input_keys = [
43
+ "s1",
44
+ "s2",
45
+ "era5",
46
+ "srtm",
47
+ "dynamic_world",
48
+ "latlon",
49
+ ]
50
+
51
+ def __init__(
52
+ self,
53
+ pretrained_path: str | UPath | None = None,
54
+ pixel_batch_size: int = 128,
55
+ ):
56
+ """Initialize the Presto wrapper.
57
+
58
+ Args:
59
+ pretrained_path: The directory to load the pretrained weights from; if None, they are downloaded from Hugging Face into a temporary cache directory.
60
+ pixel_batch_size: The number of pixel time series processed per forward chunk. If the input has h, w dimensions > 1, it is
61
+ flattened into a batch dimension (b h w) before being passed
62
+ to the model (since Presto is designed for pixel timeseries).
63
+ """
64
+ super().__init__()
65
+
66
+ if pretrained_path is None:
67
+ pretrained_path = UPath(tempfile.gettempdir(), "rslearn_cache", "presto")
68
+ if not (UPath(pretrained_path) / MODEL_FILENAME).exists():
69
+ _ = hf_hub_download(
70
+ local_dir=UPath(pretrained_path),
71
+ repo_id=HF_HUB_ID,
72
+ filename=MODEL_FILENAME,
73
+ # pin the model to a specific hugging face commit
74
+ revision="1b97f885969da4e2d5834ca8c92707c737911464",
75
+ )
76
+
77
+ model = SFPresto.construct()
78
+ model.load_state_dict(
79
+ torch.load(
80
+ UPath(pretrained_path) / MODEL_FILENAME,
81
+ map_location="cpu",
82
+ weights_only=True,
83
+ )
84
+ )
85
+ self.pixel_batch_size = pixel_batch_size
86
+ self.model = model.encoder
87
+ self.month = 6 # default month
88
+
89
+ def construct_presto_input(
90
+ self,
91
+ s1: torch.Tensor | None = None,
92
+ s1_bands: list[str] | None = None,
93
+ s2: torch.Tensor | None = None,
94
+ s2_bands: list[str] | None = None,
95
+ era5: torch.Tensor | None = None,
96
+ era5_bands: list[str] | None = None,
97
+ srtm: torch.Tensor | None = None,
98
+ srtm_bands: list[str] | None = None,
99
+ dynamic_world: torch.Tensor | None = None,
100
+ months: torch.Tensor | None = None,
101
+ normalize: bool = True,
102
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
103
+ """Inputs are paired into a tensor input <X> and a list <X>_bands, which describes <X>.
104
+
105
+ <X> should have shape (b, num_timesteps * len(<X>_bands), h, w), with the following bands for
106
+ each input:
107
+
108
+ s1: ["VV", "VH"]
109
+ s2: ["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]
110
+ era5: ["temperature_2m", "total_precipitation"]
111
+ "temperature_2m": Temperature of air at 2m above the surface of land,
112
+ sea or in-land waters in Kelvin (K)
113
+ "total_precipitation": Accumulated liquid and frozen water, including rain and snow,
114
+ that falls to the Earth's surface. Measured in metres (m)
115
+ srtm: ["elevation", "slope"]
116
+
117
+ dynamic_world is an input of shape (b, num_timesteps, h, w) representing the Dynamic World class
118
+ of each pixel at each timestep
119
+ """
120
+ bs = [x.shape[0] for x in [s1, s2, era5, srtm] if x is not None]
121
+ hs = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
122
+ ws = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
123
+
124
+ assert len(bs) > 0
125
+ assert len(set(bs)) == 1
126
+ assert len(set(hs)) == 1
127
+ assert len(set(ws)) == 1
128
+ b, h, w = bs[0], hs[0], ws[0]
129
+
130
+ # these values will be initialized as
131
+ # we iterate through the data
132
+ x: torch.Tensor | None = None
133
+ mask: torch.Tensor | None = None
134
+ t: int | None = None
135
+
136
+ for band_group in [
137
+ (s1, s1_bands),
138
+ (s2, s2_bands),
139
+ (era5, era5_bands),
140
+ (srtm, srtm_bands),
141
+ ]:
142
+ data, input_bands = band_group
143
+ if data is not None:
144
+ assert input_bands is not None
145
+ else:
146
+ continue
147
+
148
+ m_t = data.shape[1] // len(input_bands)
149
+ if t is None:
150
+ t = m_t
151
+ else:
152
+ if t != m_t:
153
+ raise ValueError("inconsistent values for t")
154
+
155
+ data = rearrange(data, "b (t c) h w -> b t h w c", t=m_t)
156
+ if x is None:
157
+ x = torch.zeros(b, t, h, w, len(INPUT_PRESTO_BANDS), device=data.device)
158
+ if mask is None:
159
+ mask = torch.ones(b, t, h, w, len(INPUT_PRESTO_BANDS), device=data.device)
160
+
161
+ # construct a mapping from the input bands to the presto input bands
162
+ input_to_output_mapping = [
163
+ INPUT_PRESTO_BANDS.index(val) for val in input_bands
164
+ ]
165
+ x[:, :, :, :, input_to_output_mapping] = data
166
+ mask[:, :, :, :, input_to_output_mapping] = 0
167
+
168
+ assert x is not None
169
+ assert mask is not None
170
+ assert t is not None
171
+
172
+ if dynamic_world is None:
173
+ dynamic_world = torch.ones(b, t, h, w, device=x.device) * NUM_DYNAMIC_WORLD_CLASSES
174
+
175
+ if months is None:
176
+ months = torch.ones((b, t), device=x.device) * self.month
177
+ else:
178
+ assert months.shape[-1] == t
179
+
180
+ if normalize:
181
+ x = (x + PRESTO_ADD_BY) / PRESTO_DIV_BY
182
+ return x, mask, dynamic_world.long(), months.long()
183
+
184
+ def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
185
+ """Compute feature maps from the Presto backbone.
186
+
187
+ Args:
188
+ inputs: A list of per-sample input dicts; keys listed in input_keys are stacked into a batch and converted to Presto inputs.
189
+ """
190
+ stacked_inputs = {}
191
+ latlons: torch.Tensor | None = None
192
+ for key in inputs[0].keys():
193
+ # assume all the keys in an input are consistent
194
+ if key in self.input_keys:
195
+ if key == "latlon":
196
+ latlons = torch.stack([inp[key] for inp in inputs], dim=0)
197
+ else:
198
+ stacked_inputs[key] = torch.stack(
199
+ [inp[key] for inp in inputs], dim=0
200
+ )
201
+
202
+ (
203
+ x,
204
+ mask,
205
+ dynamic_world,
206
+ months,
207
+ ) = self.construct_presto_input(
208
+ **stacked_inputs,
209
+ s1_bands=PRESTO_S1_BANDS,
210
+ s2_bands=INPUT_PRESTO_S2_BANDS,
211
+ era5_bands=ERA5_BANDS,
212
+ srtm_bands=SRTM_BANDS,
213
+ normalize=True,
214
+ )
215
+ b, _, h, w, _ = x.shape
216
+
217
+ output_features = torch.zeros(
218
+ b * h * w, self.model.embedding_size, device=x.device
219
+ )
220
+
221
+ x = rearrange(x, "b t h w d -> (b h w) t d")
222
+ mask = rearrange(mask, "b t h w d -> (b h w) t d")
223
+ dynamic_world = rearrange(dynamic_world, "b t h w -> (b h w) t")
224
+ months = repeat(months, "b t -> (b h w) t", h=h, w=w)
225
+ if latlons is not None:
226
+ latlons = rearrange(latlons, "b c h w -> (b h w) c")
227
+
228
+ for batch_idx in range(0, b * h * w, self.pixel_batch_size):
229
+ x_b = x[batch_idx : batch_idx + self.pixel_batch_size]
230
+ mask_b = mask[batch_idx : batch_idx + self.pixel_batch_size]
231
+ dw = dynamic_world[batch_idx : batch_idx + self.pixel_batch_size]
232
+ months_b = months[batch_idx : batch_idx + self.pixel_batch_size]
233
+ if latlons is not None:
234
+ l_b = latlons[batch_idx : batch_idx + self.pixel_batch_size]
235
+ else:
236
+ l_b = None
237
+ output_b = self.model(
238
+ x=x_b,
239
+ dynamic_world=dw,
240
+ mask=mask_b,
241
+ month=months_b,
242
+ latlons=l_b,
243
+ eval_task=True,
244
+ )
245
+ output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
246
+
247
+ return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
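A hedged usage sketch for the wrapper (the import path is an assumption; pretrained weights are fetched from Hugging Face on first construction, and the random data below only demonstrates the expected shapes, not realistic values):

```python
import torch

# Assumed import path, following the package __init__ added in this diff.
from rslearn.models.presto import Presto

model = Presto(pixel_batch_size=64)
model.eval()

# One sample: Sentinel-1 (VV, VH) over 12 timesteps on a 4x4 patch. Per-sample
# tensors are laid out as (num_timesteps * num_bands, h, w); forward stacks the
# list of samples into a batch and flattens pixels into pixel time series.
t, h, w = 12, 4, 4
sample = {
    "s1": torch.randn(t * 2, h, w),
    "latlon": torch.zeros(2, h, w),  # per-pixel (lat, lon)
}
with torch.no_grad():
    feats = model([sample])[0]

# Features are returned as (b, embedding_size, h, w).
print(feats.shape)
```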