rslearn 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/models/anysat.py +207 -0
- rslearn/models/clay/clay.py +204 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +517 -0
- rslearn/models/galileo/single_file_galileo.py +1672 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/presto/presto.py +10 -7
- rslearn/models/prithvi.py +1046 -0
- rslearn/models/unet.py +17 -11
- rslearn/utils/geometry.py +61 -1
- rslearn/utils/vector_format.py +13 -10
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/METADATA +145 -15
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/RECORD +29 -10
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/WHEEL +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.6.dist-info → rslearn-0.0.7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,517 @@
|
|
|
1
|
+
"""Galileo models."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import tempfile
|
|
5
|
+
from enum import StrEnum
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
from einops import rearrange, repeat
|
|
12
|
+
from huggingface_hub import hf_hub_download
|
|
13
|
+
from upath import UPath
|
|
14
|
+
|
|
15
|
+
from rslearn.log_utils import get_logger
|
|
16
|
+
from rslearn.models.galileo.single_file_galileo import (
|
|
17
|
+
CONFIG_FILENAME,
|
|
18
|
+
DW_BANDS,
|
|
19
|
+
ENCODER_FILENAME,
|
|
20
|
+
ERA5_BANDS,
|
|
21
|
+
LANDSCAN_BANDS,
|
|
22
|
+
LOCATION_BANDS,
|
|
23
|
+
S1_BANDS,
|
|
24
|
+
S2_BANDS,
|
|
25
|
+
SPACE_BAND_GROUPS_IDX,
|
|
26
|
+
SPACE_BANDS,
|
|
27
|
+
SPACE_TIME_BANDS,
|
|
28
|
+
SPACE_TIME_BANDS_GROUPS_IDX,
|
|
29
|
+
SRTM_BANDS,
|
|
30
|
+
STATIC_BAND_GROUPS_IDX,
|
|
31
|
+
STATIC_BANDS,
|
|
32
|
+
TC_BANDS,
|
|
33
|
+
TIME_BAND_GROUPS_IDX,
|
|
34
|
+
TIME_BANDS,
|
|
35
|
+
VIIRS_BANDS,
|
|
36
|
+
WC_BANDS,
|
|
37
|
+
Encoder,
|
|
38
|
+
MaskedOutput,
|
|
39
|
+
Normalizer,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Hugging Face Hub repository that hosts the pretrained Galileo checkpoints.
HF_HUB_ID = "nasaharvest/galileo"
# Month value filled in for every timestep when GalileoModel.construct_galileo_input
# is called without explicit months.
# NOTE(review): whether months are 0- or 1-indexed is defined by the Galileo
# encoder, not by this file — confirm before changing this value.
DEFAULT_MONTH = 5

logger = get_logger(__name__)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Galileo provides three sizes: nano, tiny, base
|
|
50
|
+
class GalileoSize(StrEnum):
|
|
51
|
+
"""Size of the Galileo model."""
|
|
52
|
+
|
|
53
|
+
NANO = "nano"
|
|
54
|
+
TINY = "tiny"
|
|
55
|
+
BASE = "base"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Sub-folder (relative to the weights root directory and to the Hugging Face repo)
# holding the config and encoder files for each model size, e.g. "models/nano".
pretrained_weights: dict[GalileoSize, str] = {
    size: f"models/{size.value}" for size in GalileoSize
}

# Normalizer applied by GalileoModel.construct_galileo_input when normalize=True.
DEFAULT_NORMALIZER = Normalizer()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class GalileoModel(nn.Module):
    """Galileo backbones.

    Loads a pretrained Galileo ``Encoder`` and adapts per-sample input dicts into
    the masked multi-modal ``MaskedOutput`` structure that the encoder consumes.
    """

    # Keys read from each input dict by forward(); any other keys are ignored.
    input_keys = [
        "s1",
        "s2",
        "era5",
        "tc",
        "viirs",
        "srtm",
        "dw",
        "wc",
        "landscan",
        "latlon",
    ]

    def __init__(
        self,
        size: GalileoSize,
        patch_size: int = 4,
        pretrained_path: str | UPath | None = None,
    ) -> None:
        """Initialize the Galileo model.

        Args:
            size: The size of the Galileo model.
            patch_size: The patch size to use. It is reduced at forward time if
                the input height is smaller than it.
            pretrained_path: root directory holding the pretrained weights; the
                files for a given size are expected under
                ``<pretrained_path>/models/<size>/``. If None, a cache directory
                under the system temp directory is used. Any missing file is
                downloaded from the Hugging Face Hub into this directory.
        """
        super().__init__()
        if pretrained_path is None:
            pretrained_path = UPath(tempfile.gettempdir(), "rslearn_cache", "galileo")

        pretrained_path_for_size = UPath(pretrained_path) / pretrained_weights[size]
        # Download whichever of the config / encoder files is missing (config
        # first), pinned to a fixed Hub revision so the weights are reproducible.
        revision = "f039dd5dde966a931baeda47eb680fa89b253e4e"
        for filename in (CONFIG_FILENAME, ENCODER_FILENAME):
            if not (pretrained_path_for_size / filename).exists():
                _ = hf_hub_download(
                    local_dir=pretrained_path,
                    repo_id=HF_HUB_ID,
                    filename=f"{pretrained_weights[size]}/{filename}",
                    revision=revision,
                )

        assert (pretrained_path_for_size / ENCODER_FILENAME).exists()
        assert (pretrained_path_for_size / CONFIG_FILENAME).exists()

        self.model = Encoder.load_from_folder(
            pretrained_path_for_size, device=torch.device("cpu")
        )

        # Indices of the space-time channel groups belonging to each sensor; used
        # by forward() to select the s1/s2 tokens when producing spatial outputs.
        self.s_t_channels_s2 = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
        ]
        self.s_t_channels_s1 = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
        ]

        self.patch_size = patch_size

    @staticmethod
    def to_cartesian(
        lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor
    ) -> np.ndarray | torch.Tensor:
        """Transform latitudes and longitudes to cartesian coordinates.

        Coordinates are in degrees (EPSG:4326) and are mapped onto the unit sphere
        as (cos(lat) * cos(lon), cos(lat) * sin(lon), sin(lat)).

        Args:
            lat: latitude(s) in degrees, expected in [-90, 90].
            lon: longitude(s) in degrees, expected in [-180, 180]. Must be of the
                same kind (float / numpy array / torch tensor) as ``lat``.

        Returns:
            For float inputs, a numpy array of shape [3]. For numpy or torch
            inputs, an array/tensor of the same kind with a trailing dimension of
            size 3.

        Raises:
            AssertionError: if values are out of range, the two inputs have
                mismatched types, or the input type is unsupported.
        """
        if isinstance(lat, float):
            # Check lon's type before comparing it, so a non-float lon fails with
            # a clear message instead of inside the range comparison.
            assert isinstance(lon, float), f"Expected float got {type(lon)}"
            assert -90 <= lat <= 90, (
                f"lat out of range ({lat}). Make sure you are in EPSG:4326"
            )
            assert -180 <= lon <= 180, (
                f"lon out of range ({lon}). Make sure you are in EPSG:4326"
            )
            # transform to radians
            lat = lat * math.pi / 180
            lon = lon * math.pi / 180
            x = math.cos(lat) * math.cos(lon)
            y = math.cos(lat) * math.sin(lon)
            z = math.sin(lat)
            return np.array([x, y, z])
        elif isinstance(lon, np.ndarray):
            assert -90 <= lat.min(), (
                f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
            )
            assert 90 >= lat.max(), (
                f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
            )
            assert -180 <= lon.min(), (
                f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
            )
            assert 180 >= lon.max(), (
                f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
            )
            assert isinstance(lat, np.ndarray), f"Expected np.ndarray got {type(lat)}"
            # transform to radians
            lat = lat * math.pi / 180
            lon = lon * math.pi / 180
            x_np = np.cos(lat) * np.cos(lon)
            y_np = np.cos(lat) * np.sin(lon)
            z_np = np.sin(lat)
            return np.stack([x_np, y_np, z_np], axis=-1)
        elif isinstance(lon, torch.Tensor):
            assert -90 <= lat.min(), (
                f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
            )
            assert 90 >= lat.max(), (
                f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
            )
            assert -180 <= lon.min(), (
                f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
            )
            assert 180 >= lon.max(), (
                f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
            )
            assert isinstance(lat, torch.Tensor), (
                f"Expected torch.Tensor got {type(lat)}"
            )
            # transform to radians
            lat = lat * math.pi / 180
            lon = lon * math.pi / 180
            x_t = torch.cos(lat) * torch.cos(lon)
            y_t = torch.cos(lat) * torch.sin(lon)
            z_t = torch.sin(lat)
            return torch.stack([x_t, y_t, z_t], dim=-1)
        else:
            raise AssertionError(f"Unexpected input type {type(lon)}")

    @classmethod
    def construct_galileo_input(
        cls,
        s1: torch.Tensor | None = None,  # [B, H, W, T, D]
        s2: torch.Tensor | None = None,  # [B, H, W, T, D]
        era5: torch.Tensor | None = None,  # [B, T, D]
        tc: torch.Tensor | None = None,  # [B, T, D]
        viirs: torch.Tensor | None = None,  # [B, T, D]
        srtm: torch.Tensor | None = None,  # [B, H, W, D]
        dw: torch.Tensor | None = None,  # [B, H, W, D]
        wc: torch.Tensor | None = None,  # [B, H, W, D]
        landscan: torch.Tensor | None = None,  # [B, D]
        latlon: torch.Tensor | None = None,  # [B, 2] as (lat, lon) in degrees
        months: torch.Tensor | None = None,  # [B, T]
        normalize: bool = False,
    ) -> MaskedOutput:
        """Construct a Galileo input.

        Modalities that are not provided are zero-filled and fully masked (mask
        value 1). Provided modalities are written into their band slots, and their
        channel groups are unmasked (mask value 0). In each shape comment above,
        D is the number of bands of that modality.

        Args:
            s1, s2: space-time inputs of shape [B, H, W, T, D].
            era5, tc, viirs: time-only inputs of shape [B, T, D].
            srtm, dw, wc: space-only inputs of shape [B, H, W, D].
            landscan: static input of shape [B, D].
            latlon: static input whose columns 0 and 1 are latitude and longitude
                in degrees; converted to cartesian coordinates via to_cartesian.
            months: [B, T] month values; defaults to DEFAULT_MONTH everywhere.
            normalize: if True, apply DEFAULT_NORMALIZER to the four value tensors
                (masks and months are left unchanged).

        Returns:
            A MaskedOutput holding the value/mask pairs for the space-time, space,
            time and static groups, plus the months tensor.

        Raises:
            ValueError: if no input is given, inputs are on different devices,
                batch sizes / timesteps / heights / widths disagree across inputs,
                or months has the wrong number of timesteps.
        """
        space_time_inputs = [s1, s2]
        time_inputs = [era5, tc, viirs]
        space_inputs = [srtm, dw, wc]
        static_inputs = [landscan, latlon]
        devices = [
            x.device
            for x in space_time_inputs + time_inputs + space_inputs + static_inputs
            if x is not None
        ]

        if len(devices) == 0:
            raise ValueError("At least one input must be not None")
        if not all(devices[0] == device for device in devices):
            raise ValueError("Received tensors on multiple devices")
        device = devices[0]

        # first, check all the input shapes are consistent
        batch_list = (
            [x.shape[0] for x in space_time_inputs if x is not None]
            + [x.shape[0] for x in time_inputs if x is not None]
            + [x.shape[0] for x in space_inputs if x is not None]
            + [x.shape[0] for x in static_inputs if x is not None]
        )
        timesteps_list = [x.shape[3] for x in space_time_inputs if x is not None] + [
            x.shape[1] for x in time_inputs if x is not None
        ]
        height_list = [x.shape[1] for x in space_time_inputs if x is not None] + [
            x.shape[1] for x in space_inputs if x is not None
        ]
        width_list = [x.shape[2] for x in space_time_inputs if x is not None] + [
            x.shape[2] for x in space_inputs if x is not None
        ]
        # batch_list is non-empty here: at least one input passed the device check
        if len(set(batch_list)) > 1:
            raise ValueError("Inconsistent number of batch sizes per input")
        b = batch_list[0]

        if len(timesteps_list) > 0:
            if not all(timesteps_list[0] == timestep for timestep in timesteps_list):
                raise ValueError("Inconsistent number of timesteps per input")
            t = timesteps_list[0]
        else:
            t = 1
        if len(height_list) > 0:
            if not all(height_list[0] == height for height in height_list):
                raise ValueError("Inconsistent heights per input")
            if not all(width_list[0] == width for width in width_list):
                raise ValueError("Inconsistent widths per input")
            h = height_list[0]
            w = width_list[0]
        else:
            h, w = 1, 1

        # now, we can construct our empty input tensors. By default, everything is masked
        s_t_x = torch.zeros(
            (b, h, w, t, len(SPACE_TIME_BANDS)), dtype=torch.float, device=device
        )
        s_t_m = torch.ones(
            (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)),
            dtype=torch.float,
            device=device,
        )
        sp_x = torch.zeros(
            (b, h, w, len(SPACE_BANDS)), dtype=torch.float, device=device
        )
        sp_m = torch.ones(
            (b, h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=torch.float, device=device
        )
        t_x = torch.zeros((b, t, len(TIME_BANDS)), dtype=torch.float, device=device)
        t_m = torch.ones(
            (b, t, len(TIME_BAND_GROUPS_IDX)), dtype=torch.float, device=device
        )
        st_x = torch.zeros((b, len(STATIC_BANDS)), dtype=torch.float, device=device)
        st_m = torch.ones(
            (b, len(STATIC_BAND_GROUPS_IDX)), dtype=torch.float, device=device
        )

        for x, bands_list, group_key in zip(
            [s1, s2], [S1_BANDS, S2_BANDS], ["S1", "S2"]
        ):
            if x is not None:
                indices = [
                    idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in bands_list
                ]
                groups_idx = [
                    idx
                    for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX)
                    if group_key in key
                ]
                s_t_x[:, :, :, :, indices] = x
                s_t_m[:, :, :, :, groups_idx] = 0

        for x, bands_list, group_key in zip(
            [srtm, dw, wc], [SRTM_BANDS, DW_BANDS, WC_BANDS], ["SRTM", "DW", "WC"]
        ):
            if x is not None:
                indices = [
                    idx for idx, val in enumerate(SPACE_BANDS) if val in bands_list
                ]
                groups_idx = [
                    idx
                    for idx, key in enumerate(SPACE_BAND_GROUPS_IDX)
                    if group_key in key
                ]
                sp_x[:, :, :, indices] = x
                sp_m[:, :, :, groups_idx] = 0

        for x, bands_list, group_key in zip(
            [era5, tc, viirs],
            [ERA5_BANDS, TC_BANDS, VIIRS_BANDS],
            ["ERA5", "TC", "VIIRS"],
        ):
            if x is not None:
                indices = [
                    idx for idx, val in enumerate(TIME_BANDS) if val in bands_list
                ]
                groups_idx = [
                    idx
                    for idx, key in enumerate(TIME_BAND_GROUPS_IDX)
                    if group_key in key
                ]
                t_x[:, :, indices] = x
                t_m[:, :, groups_idx] = 0

        for x, bands_list, group_key in zip(
            [landscan, latlon], [LANDSCAN_BANDS, LOCATION_BANDS], ["LS", "location"]
        ):
            if x is not None:
                if group_key == "location":
                    # transform latlon to cartesian
                    x = cast(torch.Tensor, cls.to_cartesian(x[:, 0], x[:, 1]))
                indices = [
                    idx for idx, val in enumerate(STATIC_BANDS) if val in bands_list
                ]
                groups_idx = [
                    idx
                    for idx, key in enumerate(STATIC_BAND_GROUPS_IDX)
                    if group_key in key
                ]
                st_x[:, indices] = x
                st_m[:, groups_idx] = 0

        if months is None:
            months = torch.ones((b, t), dtype=torch.long, device=device) * DEFAULT_MONTH
        else:
            if months.shape[1] != t:
                raise ValueError("Incorrect number of input months")

        if normalize:
            # DEFAULT_NORMALIZER operates on numpy arrays, so round-trip through
            # the CPU and move the result back to the inputs' device as float.
            def _apply_normalizer(values: torch.Tensor) -> torch.Tensor:
                return (
                    torch.from_numpy(DEFAULT_NORMALIZER(values.cpu().numpy()))
                    .to(device)
                    .float()
                )

            s_t_x = _apply_normalizer(s_t_x)
            sp_x = _apply_normalizer(sp_x)
            t_x = _apply_normalizer(t_x)
            st_x = _apply_normalizer(st_x)

        return MaskedOutput(
            s_t_x=s_t_x,
            s_t_m=s_t_m,
            sp_x=sp_x,
            sp_m=sp_m,
            t_x=t_x,
            t_m=t_m,
            st_x=st_x,
            st_m=st_m,
            months=months,
        )

    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
        """Compute feature maps from the Galileo backbone.

        Args:
            inputs: a list with one dict per sample. Keys found in
                ``GalileoModel.input_keys`` are used and other keys are ignored;
                all samples are assumed to share the keys of the first one. Values
                are stacked along a new leading batch dimension B, after which the
                expected shapes per input key are:
                "s1": B (T * C) H W
                "s2": B (T * C) H W
                "era5": B (T * C) H W (we will average over the H, W dimensions)
                "tc": B (T * C) H W (we will average over the H, W dimensions)
                "viirs": B (T * C) H W (we will average over the H, W dimensions)
                "srtm": B C H W (SRTM has no temporal dimension)
                "dw": B C H W (Dynamic World should be averaged over time)
                "wc": B C H W (WorldCereal has no temporal dimension)
                "landscan": B C H W (we will average over the H, W dimensions)
                "latlon": B C H W (we will average over the H, W dimensions)

        Returns:
            A single-element list holding one feature map. If the input height H
            equals the effective patch size (only a single token per h/w
            dimension), the output pools all tokens via ``Encoder.average_tokens``
            and has shape B D 1 1. Otherwise (patch_size < H, many spatial tokens)
            the s1/s2 space-time tokens are averaged over time and channel groups,
            giving shape B D h w over the encoder's spatial token grid.

        Raises:
            ValueError: if a spatial output is requested but neither "s1" nor "s2"
                is provided, or if construct_galileo_input rejects the inputs.
        """
        stacked_inputs = {}
        for key in inputs[0].keys():
            # assume all the keys in an input are consistent
            if key in self.input_keys:
                stacked_inputs[key] = torch.stack([inp[key] for inp in inputs], dim=0)
        s_t_channels = []
        for space_time_modality in ["s1", "s2"]:
            if space_time_modality not in stacked_inputs:
                continue
            if space_time_modality == "s1":
                s_t_channels += self.s_t_channels_s1
            else:
                s_t_channels += self.s_t_channels_s2
            cur = stacked_inputs[space_time_modality]
            # Check if it's single or multitemporal, and reshape accordingly
            num_bands = len(S2_BANDS) if space_time_modality == "s2" else len(S1_BANDS)
            num_timesteps = cur.shape[1] // num_bands
            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
            stacked_inputs[space_time_modality] = cur

        for space_modality in ["srtm", "dw", "wc"]:
            if space_modality not in stacked_inputs:
                continue
            stacked_inputs[space_modality] = rearrange(
                stacked_inputs[space_modality], "b c h w -> b h w c"
            )

        for time_modality in ["era5", "tc", "viirs"]:
            if time_modality not in stacked_inputs:
                continue
            cur = stacked_inputs[time_modality]
            # Check if it's single or multitemporal, and reshape accordingly
            num_bands = {
                "era5": len(ERA5_BANDS),
                "tc": len(TC_BANDS),
                "viirs": len(VIIRS_BANDS),
            }[time_modality]
            num_timesteps = cur.shape[1] // num_bands
            # take the average over the h, w bands since Galileo
            # treats it as a pixel-timeseries
            cur = rearrange(
                torch.nanmean(torch.nanmean(cur, dim=-1), dim=-1),
                "b (t c) -> b t c",
                t=num_timesteps,
            )
            stacked_inputs[time_modality] = cur

        for static_modality in ["landscan", "latlon"]:
            if static_modality not in stacked_inputs:
                continue
            cur = stacked_inputs[static_modality]
            stacked_inputs[static_modality] = torch.nanmean(
                torch.nanmean(cur, dim=-1), dim=-1
            )
        galileo_input = self.construct_galileo_input(**stacked_inputs, normalize=True)
        h = galileo_input.s_t_x.shape[1]
        if h < self.patch_size:
            logger.warning(
                f"Given patch size {self.patch_size} > h {h}. Reducing patch size to {h}"
            )
            patch_size = h
        else:
            patch_size = self.patch_size
        # The spatial output is built only from s1/s2 tokens; without them the
        # channel-group selection below would be empty and the mean would be NaN.
        if h != patch_size and not s_t_channels:
            raise ValueError(
                "Spatial Galileo outputs are computed from s1/s2 tokens, but "
                "neither 's1' nor 's2' was provided"
            )
        outputs = self.model(
            s_t_x=galileo_input.s_t_x,
            s_t_m=galileo_input.s_t_m,
            sp_x=galileo_input.sp_x,
            sp_m=galileo_input.sp_m,
            t_x=galileo_input.t_x,
            t_m=galileo_input.t_m,
            st_x=galileo_input.st_x,
            st_m=galileo_input.st_m,
            months=galileo_input.months,
            patch_size=patch_size,
        )
        if h == patch_size:
            # only one spatial patch, so we can just take an average
            # of all the tokens to output b c_g 1 1
            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, _ = outputs
            averaged = self.model.average_tokens(
                s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
            )
            return [repeat(averaged, "b d -> b d 1 1")]
        else:
            s_t_x = outputs[0]
            # we will be assuming we only want s_t_x, and (for now) that we want s1 or s2 bands
            # s_t_x has shape [b, h, w, t, c_g, d]
            # and we want [b, d, h, w]
            return [
                rearrange(
                    s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
                    "b h w c_g d -> b c_g d h w",
                ).mean(dim=1)
            ]
|