rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,595 @@
|
|
|
1
|
+
"""Galileo models."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import tempfile
|
|
5
|
+
from contextlib import nullcontext
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from enum import StrEnum
|
|
8
|
+
from typing import cast
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
from einops import rearrange, repeat
|
|
13
|
+
from huggingface_hub import hf_hub_download
|
|
14
|
+
from upath import UPath
|
|
15
|
+
|
|
16
|
+
from rslearn.log_utils import get_logger
|
|
17
|
+
from rslearn.models.component import FeatureExtractor, FeatureMaps
|
|
18
|
+
from rslearn.models.galileo.single_file_galileo import (
|
|
19
|
+
CONFIG_FILENAME,
|
|
20
|
+
DW_BANDS,
|
|
21
|
+
ENCODER_FILENAME,
|
|
22
|
+
ERA5_BANDS,
|
|
23
|
+
LANDSCAN_BANDS,
|
|
24
|
+
LOCATION_BANDS,
|
|
25
|
+
S1_BANDS,
|
|
26
|
+
S2_BANDS,
|
|
27
|
+
SPACE_BAND_GROUPS_IDX,
|
|
28
|
+
SPACE_BANDS,
|
|
29
|
+
SPACE_TIME_BANDS,
|
|
30
|
+
SPACE_TIME_BANDS_GROUPS_IDX,
|
|
31
|
+
SRTM_BANDS,
|
|
32
|
+
STATIC_BAND_GROUPS_IDX,
|
|
33
|
+
STATIC_BANDS,
|
|
34
|
+
TC_BANDS,
|
|
35
|
+
TIME_BAND_GROUPS_IDX,
|
|
36
|
+
TIME_BANDS,
|
|
37
|
+
VIIRS_BANDS,
|
|
38
|
+
WC_BANDS,
|
|
39
|
+
Encoder,
|
|
40
|
+
MaskedOutput,
|
|
41
|
+
Normalizer,
|
|
42
|
+
)
|
|
43
|
+
from rslearn.train.model_context import ModelContext
|
|
44
|
+
|
|
45
|
+
logger = get_logger(__name__)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Hugging Face Hub repository hosting the pretrained Galileo checkpoints.
HF_HUB_ID = "nasaharvest/galileo"
# Fallback month used when no timestamps are available; months are 0-indexed
# elsewhere in this file (see time_ranges_to_timestamps), so 5 presumably
# denotes June — confirm against the upstream Galileo convention.
DEFAULT_MONTH = 5
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Galileo provides three sizes: nano, tiny, base
|
|
53
|
+
class GalileoSize(StrEnum):
|
|
54
|
+
"""Size of the Galileo model."""
|
|
55
|
+
|
|
56
|
+
NANO = "nano"
|
|
57
|
+
TINY = "tiny"
|
|
58
|
+
BASE = "base"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
pretrained_weights: dict[GalileoSize, str] = {
|
|
62
|
+
GalileoSize.NANO: "models/nano",
|
|
63
|
+
GalileoSize.TINY: "models/tiny",
|
|
64
|
+
GalileoSize.BASE: "models/base",
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
# Shared Normalizer applied when construct_galileo_input(normalize=True).
DEFAULT_NORMALIZER = Normalizer()

# Maps the string-valued `autocast_dtype` constructor argument of GalileoModel
# to the corresponding torch dtype.
AUTOCAST_DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class GalileoModel(FeatureExtractor):
    """Galileo backbones.

    Wraps the single-file Galileo encoder and exposes it as a FeatureExtractor:
    inputs for the modalities listed in `input_keys` are stacked, reshaped into
    the layout Galileo expects, normalized, and encoded into feature maps.
    """

    # Modality keys that forward() reads from the model context inputs; any
    # other keys present in the input dicts are ignored.
    input_keys = [
        "s1",
        "s2",
        "era5",
        "tc",
        "viirs",
        "srtm",
        "dw",
        "wc",
        "landscan",
        "latlon",
    ]
|
|
90
|
+
|
|
91
|
+
def __init__(
|
|
92
|
+
self,
|
|
93
|
+
size: GalileoSize,
|
|
94
|
+
patch_size: int = 4,
|
|
95
|
+
pretrained_path: str | UPath | None = None,
|
|
96
|
+
autocast_dtype: str | None = "bfloat16",
|
|
97
|
+
) -> None:
|
|
98
|
+
"""Initialize the Galileo model.
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
size: The size of the Galileo model.
|
|
102
|
+
patch_size: The patch size to use.
|
|
103
|
+
pretrained_path: the local path to the pretrained weights. Otherwise it is
|
|
104
|
+
downloaded and cached in temp directory.
|
|
105
|
+
autocast_dtype: which dtype to use for autocasting, or set None to disable.
|
|
106
|
+
"""
|
|
107
|
+
super().__init__()
|
|
108
|
+
if pretrained_path is None:
|
|
109
|
+
pretrained_path = UPath(tempfile.gettempdir(), "rslearn_cache", "galileo")
|
|
110
|
+
|
|
111
|
+
pretrained_path_for_size = UPath(pretrained_path) / pretrained_weights[size]
|
|
112
|
+
if not (pretrained_path_for_size / CONFIG_FILENAME).exists():
|
|
113
|
+
_ = hf_hub_download(
|
|
114
|
+
local_dir=pretrained_path,
|
|
115
|
+
repo_id=HF_HUB_ID,
|
|
116
|
+
filename=f"{pretrained_weights[size]}/{CONFIG_FILENAME}",
|
|
117
|
+
revision="f039dd5dde966a931baeda47eb680fa89b253e4e",
|
|
118
|
+
)
|
|
119
|
+
if not (pretrained_path_for_size / ENCODER_FILENAME).exists():
|
|
120
|
+
_ = hf_hub_download(
|
|
121
|
+
local_dir=pretrained_path,
|
|
122
|
+
repo_id=HF_HUB_ID,
|
|
123
|
+
filename=f"{pretrained_weights[size]}/{ENCODER_FILENAME}",
|
|
124
|
+
revision="f039dd5dde966a931baeda47eb680fa89b253e4e",
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
assert (pretrained_path_for_size / ENCODER_FILENAME).exists()
|
|
128
|
+
assert (pretrained_path_for_size / CONFIG_FILENAME).exists()
|
|
129
|
+
|
|
130
|
+
self.model = Encoder.load_from_folder(
|
|
131
|
+
pretrained_path_for_size, device=torch.device("cpu")
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
self.s_t_channels_s2 = [
|
|
135
|
+
idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
|
|
136
|
+
]
|
|
137
|
+
self.s_t_channels_s1 = [
|
|
138
|
+
idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
|
|
139
|
+
]
|
|
140
|
+
|
|
141
|
+
self.size = size
|
|
142
|
+
self.patch_size = patch_size
|
|
143
|
+
|
|
144
|
+
if autocast_dtype is not None:
|
|
145
|
+
self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
|
|
146
|
+
else:
|
|
147
|
+
self.autocast_dtype = None
|
|
148
|
+
|
|
149
|
+
@staticmethod
|
|
150
|
+
def to_cartesian(
|
|
151
|
+
lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor
|
|
152
|
+
) -> np.ndarray | torch.Tensor:
|
|
153
|
+
"""Transform latitudes and longitudes to cartesian coordinates."""
|
|
154
|
+
if isinstance(lat, float):
|
|
155
|
+
assert -90 <= lat <= 90, (
|
|
156
|
+
f"lat out of range ({lat}). Make sure you are in EPSG:4326"
|
|
157
|
+
)
|
|
158
|
+
assert -180 <= lon <= 180, (
|
|
159
|
+
f"lon out of range ({lon}). Make sure you are in EPSG:4326"
|
|
160
|
+
)
|
|
161
|
+
assert isinstance(lon, float), f"Expected float got {type(lon)}"
|
|
162
|
+
# transform to radians
|
|
163
|
+
lat = lat * math.pi / 180
|
|
164
|
+
lon = lon * math.pi / 180
|
|
165
|
+
x = math.cos(lat) * math.cos(lon)
|
|
166
|
+
y = math.cos(lat) * math.sin(lon)
|
|
167
|
+
z = math.sin(lat)
|
|
168
|
+
return np.array([x, y, z])
|
|
169
|
+
elif isinstance(lon, np.ndarray):
|
|
170
|
+
assert -90 <= lat.min(), (
|
|
171
|
+
f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
|
|
172
|
+
)
|
|
173
|
+
assert 90 >= lat.max(), (
|
|
174
|
+
f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
|
|
175
|
+
)
|
|
176
|
+
assert -180 <= lon.min(), (
|
|
177
|
+
f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
|
|
178
|
+
)
|
|
179
|
+
assert 180 >= lon.max(), (
|
|
180
|
+
f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
|
|
181
|
+
)
|
|
182
|
+
assert isinstance(lat, np.ndarray), f"Expected np.ndarray got {type(lat)}"
|
|
183
|
+
# transform to radians
|
|
184
|
+
lat = lat * math.pi / 180
|
|
185
|
+
lon = lon * math.pi / 180
|
|
186
|
+
x_np = np.cos(lat) * np.cos(lon)
|
|
187
|
+
y_np = np.cos(lat) * np.sin(lon)
|
|
188
|
+
z_np = np.sin(lat)
|
|
189
|
+
return np.stack([x_np, y_np, z_np], axis=-1)
|
|
190
|
+
elif isinstance(lon, torch.Tensor):
|
|
191
|
+
assert -90 <= lat.min(), (
|
|
192
|
+
f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
|
|
193
|
+
)
|
|
194
|
+
assert 90 >= lat.max(), (
|
|
195
|
+
f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
|
|
196
|
+
)
|
|
197
|
+
assert -180 <= lon.min(), (
|
|
198
|
+
f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
|
|
199
|
+
)
|
|
200
|
+
assert 180 >= lon.max(), (
|
|
201
|
+
f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
|
|
202
|
+
)
|
|
203
|
+
assert isinstance(lat, torch.Tensor), (
|
|
204
|
+
f"Expected torch.Tensor got {type(lat)}"
|
|
205
|
+
)
|
|
206
|
+
# transform to radians
|
|
207
|
+
lat = lat * math.pi / 180
|
|
208
|
+
lon = lon * math.pi / 180
|
|
209
|
+
x_t = torch.cos(lat) * torch.cos(lon)
|
|
210
|
+
y_t = torch.cos(lat) * torch.sin(lon)
|
|
211
|
+
z_t = torch.sin(lat)
|
|
212
|
+
return torch.stack([x_t, y_t, z_t], dim=-1)
|
|
213
|
+
else:
|
|
214
|
+
raise AssertionError(f"Unexpected input type {type(lon)}")
|
|
215
|
+
|
|
216
|
+
    @classmethod
    def construct_galileo_input(
        cls,
        s1: torch.Tensor | None = None,  # [B, H, W, T, D]
        s2: torch.Tensor | None = None,  # [B, H, W, T, D]
        era5: torch.Tensor | None = None,  # [B, T, D]
        tc: torch.Tensor | None = None,  # [B, T, D]
        viirs: torch.Tensor | None = None,  # [B, T, D]
        srtm: torch.Tensor | None = None,  # [B, H, W, D]
        dw: torch.Tensor | None = None,  # [B, H, W, D]
        wc: torch.Tensor | None = None,  # [B, H, W, D]
        landscan: torch.Tensor | None = None,  # [B, D]
        latlon: torch.Tensor | None = None,  # [B, D]
        months: torch.Tensor | None = None,  # [B, T]
        normalize: bool = False,
    ) -> MaskedOutput:
        """Construct a Galileo input.

        Builds the dense per-category tensors and their group masks from
        whichever modalities are provided. Absent modalities stay zero-filled
        with their mask set to 1 (masked); provided modalities are written into
        the matching band indices with their group mask set to 0 (unmasked).

        Raises:
            ValueError: if no input is provided, inputs live on different
                devices, or batch/timestep/height/width sizes are inconsistent
                across the provided inputs.
        """
        space_time_inputs = [s1, s2]
        time_inputs = [era5, tc, viirs]
        space_inputs = [srtm, dw, wc]
        static_inputs = [landscan, latlon]
        devices = [
            x.device
            for x in space_time_inputs + time_inputs + space_inputs + static_inputs
            if x is not None
        ]

        if len(devices) == 0:
            raise ValueError("At least one input must be not None")
        if not all(devices[0] == device for device in devices):
            raise ValueError("Received tensors on multiple devices")
        device = devices[0]

        # first, check all the input shapes are consistent
        batch_list = (
            [x.shape[0] for x in space_time_inputs if x is not None]
            + [x.shape[0] for x in time_inputs if x is not None]
            + [x.shape[0] for x in space_inputs if x is not None]
            + [x.shape[0] for x in static_inputs if x is not None]
        )
        timesteps_list = [x.shape[3] for x in space_time_inputs if x is not None] + [
            x.shape[1] for x in time_inputs if x is not None
        ]
        height_list = [x.shape[1] for x in space_time_inputs if x is not None] + [
            x.shape[1] for x in space_inputs if x is not None
        ]
        width_list = [x.shape[2] for x in space_time_inputs if x is not None] + [
            x.shape[2] for x in space_inputs if x is not None
        ]
        # batch_list is nonempty here (every category contributes shape[0] and
        # at least one input is not None), so b is always assigned.
        if len(batch_list) > 0:
            if len(set(batch_list)) > 1:
                raise ValueError("Inconsistent number of batch sizes per input")
            b = batch_list[0]

        # Fall back to a single timestep / 1x1 spatial extent when no input
        # carries that dimension.
        if len(timesteps_list) > 0:
            if not all(timesteps_list[0] == timestep for timestep in timesteps_list):
                raise ValueError("Inconsistent number of timesteps per input")
            t = timesteps_list[0]
        else:
            t = 1
        if len(height_list) > 0:
            if not all(height_list[0] == height for height in height_list):
                raise ValueError("Inconsistent heights per input")
            if not all(width_list[0] == width for width in width_list):
                raise ValueError("Inconsistent widths per input")
            h = height_list[0]
            w = width_list[0]
        else:
            h, w = 1, 1

        # now, we can construct our empty input tensors. By default, everything is masked
        s_t_x = torch.zeros(
            (b, h, w, t, len(SPACE_TIME_BANDS)), dtype=torch.float, device=device
        )
        s_t_m = torch.ones(
            (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)),
            dtype=torch.float,
            device=device,
        )
        sp_x = torch.zeros(
            (b, h, w, len(SPACE_BANDS)), dtype=torch.float, device=device
        )
        sp_m = torch.ones(
            (b, h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=torch.float, device=device
        )
        t_x = torch.zeros((b, t, len(TIME_BANDS)), dtype=torch.float, device=device)
        t_m = torch.ones(
            (b, t, len(TIME_BAND_GROUPS_IDX)), dtype=torch.float, device=device
        )
        st_x = torch.zeros((b, len(STATIC_BANDS)), dtype=torch.float, device=device)
        st_m = torch.ones(
            (b, len(STATIC_BAND_GROUPS_IDX)), dtype=torch.float, device=device
        )

        # Scatter each provided space-time modality into its band slots and
        # unmask the corresponding band groups.
        for x, bands_list, group_key in zip(
            [s1, s2], [S1_BANDS, S2_BANDS], ["S1", "S2"]
        ):
            if x is not None:
                indices = [
                    idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in bands_list
                ]
                groups_idx = [
                    idx
                    for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX)
                    if group_key in key
                ]
                s_t_x[:, :, :, :, indices] = x
                s_t_m[:, :, :, :, groups_idx] = 0

        # Space-only modalities.
        for x, bands_list, group_key in zip(
            [srtm, dw, wc], [SRTM_BANDS, DW_BANDS, WC_BANDS], ["SRTM", "DW", "WC"]
        ):
            if x is not None:
                indices = [
                    idx for idx, val in enumerate(SPACE_BANDS) if val in bands_list
                ]
                groups_idx = [
                    idx
                    for idx, key in enumerate(SPACE_BAND_GROUPS_IDX)
                    if group_key in key
                ]
                sp_x[:, :, :, indices] = x
                sp_m[:, :, :, groups_idx] = 0

        # Time-only modalities.
        for x, bands_list, group_key in zip(
            [era5, tc, viirs],
            [ERA5_BANDS, TC_BANDS, VIIRS_BANDS],
            ["ERA5", "TC", "VIIRS"],
        ):
            if x is not None:
                indices = [
                    idx for idx, val in enumerate(TIME_BANDS) if val in bands_list
                ]
                groups_idx = [
                    idx
                    for idx, key in enumerate(TIME_BAND_GROUPS_IDX)
                    if group_key in key
                ]
                t_x[:, :, indices] = x
                t_m[:, :, groups_idx] = 0

        # Static modalities.
        for x, bands_list, group_key in zip(
            [landscan, latlon], [LANDSCAN_BANDS, LOCATION_BANDS], ["LS", "location"]
        ):
            if x is not None:
                if group_key == "location":
                    # transform latlon to cartesian
                    x = cast(torch.Tensor, cls.to_cartesian(x[:, 0], x[:, 1]))
                indices = [
                    idx for idx, val in enumerate(STATIC_BANDS) if val in bands_list
                ]
                groups_idx = [
                    idx
                    for idx, key in enumerate(STATIC_BAND_GROUPS_IDX)
                    if group_key in key
                ]
                st_x[:, indices] = x
                st_m[:, groups_idx] = 0

        # Default every timestep to DEFAULT_MONTH when no months were given.
        if months is None:
            months = torch.ones((b, t), dtype=torch.long, device=device) * DEFAULT_MONTH
        else:
            if months.shape[1] != t:
                raise ValueError("Incorrect number of input months")

        # Normalization runs on CPU/numpy (the Normalizer is numpy-based) and
        # the results are moved back to the input device.
        if normalize:
            s_t_x = (
                torch.from_numpy(DEFAULT_NORMALIZER(s_t_x.cpu().numpy()))
                .to(device)
                .float()
            )
            sp_x = (
                torch.from_numpy(DEFAULT_NORMALIZER(sp_x.cpu().numpy()))
                .to(device)
                .float()
            )
            t_x = (
                torch.from_numpy(DEFAULT_NORMALIZER(t_x.cpu().numpy()))
                .to(device)
                .float()
            )
            st_x = (
                torch.from_numpy(DEFAULT_NORMALIZER(st_x.cpu().numpy()))
                .to(device)
                .float()
            )

        return MaskedOutput(
            s_t_x=s_t_x,
            s_t_m=s_t_m,
            sp_x=sp_x,
            sp_m=sp_m,
            t_x=t_x,
            t_m=t_m,
            st_x=st_x,
            st_m=st_m,
            months=months,
        )
|
|
414
|
+
|
|
415
|
+
@staticmethod
|
|
416
|
+
def time_ranges_to_timestamps(
|
|
417
|
+
time_ranges: list[tuple[datetime, datetime]],
|
|
418
|
+
device: torch.device,
|
|
419
|
+
) -> torch.Tensor:
|
|
420
|
+
"""Turn the time ranges stored in a RasterImage to timestamps accepted by Galileo.
|
|
421
|
+
|
|
422
|
+
Galileo only uses the month associated with each timestamp, so we take the midpoint
|
|
423
|
+
the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
|
|
424
|
+
time so that start_time == end_time == mid_time.
|
|
425
|
+
"""
|
|
426
|
+
mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
|
|
427
|
+
# months are indexed 0-11
|
|
428
|
+
return torch.tensor(
|
|
429
|
+
[d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
|
|
430
|
+
)
|
|
431
|
+
|
|
432
|
+
def forward(self, context: ModelContext) -> FeatureMaps:
|
|
433
|
+
"""Compute feature maps from the Galileo backbone.
|
|
434
|
+
|
|
435
|
+
Args:
|
|
436
|
+
context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
|
|
437
|
+
(also documented below) and values are tensors of the following shapes,
|
|
438
|
+
per input key:
|
|
439
|
+
"s1": B C T H W
|
|
440
|
+
"s2": B C T H W
|
|
441
|
+
"era5": B C T H W (we will average over the H, W dimensions)
|
|
442
|
+
"tc": B C T H W (we will average over the H, W dimensions)
|
|
443
|
+
"viirs": B C T H W (we will average over the H, W dimensions)
|
|
444
|
+
"srtm": B C 1 H W (SRTM has no temporal dimension)
|
|
445
|
+
"dw": : B C 1 H W (Dynamic World should be averaged over time)
|
|
446
|
+
"wc": B C 1 H W (WorldCereal has no temporal dimension)
|
|
447
|
+
"landscan": B C 1 H W (we will average over the H, W dimensions)
|
|
448
|
+
"latlon": B C 1 H W (we will average over the H, W dimensions)
|
|
449
|
+
|
|
450
|
+
The output will be an embedding representing the pooled tokens. If there is
|
|
451
|
+
only a single token per h/w dimension (i.e. patch_size == h,w), then we will take
|
|
452
|
+
a pool of all the unmasked tokens.
|
|
453
|
+
|
|
454
|
+
If there are many spatial tokens per h/w dimension (patch_size > h,w), then we will
|
|
455
|
+
take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
|
|
456
|
+
"""
|
|
457
|
+
space_time_modalities = ["s1", "s2"]
|
|
458
|
+
time_modalities = ["era5", "tc", "viirs"]
|
|
459
|
+
stacked_inputs = {}
|
|
460
|
+
months: torch.Tensor | None = None
|
|
461
|
+
for key in context.inputs[0].keys():
|
|
462
|
+
# assume all the keys in an input are consistent
|
|
463
|
+
if key in self.input_keys:
|
|
464
|
+
stacked_inputs[key] = torch.stack(
|
|
465
|
+
[inp[key].image for inp in context.inputs], dim=0
|
|
466
|
+
)
|
|
467
|
+
if key in space_time_modalities + time_modalities:
|
|
468
|
+
if months is None:
|
|
469
|
+
if context.inputs[0][key].timestamps is not None:
|
|
470
|
+
months = torch.stack(
|
|
471
|
+
[
|
|
472
|
+
self.time_ranges_to_timestamps(
|
|
473
|
+
inp[key].timestamps, # type: ignore
|
|
474
|
+
device=stacked_inputs[key].device,
|
|
475
|
+
)
|
|
476
|
+
for inp in context.inputs
|
|
477
|
+
],
|
|
478
|
+
dim=0,
|
|
479
|
+
)
|
|
480
|
+
|
|
481
|
+
if months is not None:
|
|
482
|
+
stacked_inputs["months"] = months
|
|
483
|
+
|
|
484
|
+
s_t_channels = []
|
|
485
|
+
for space_time_modality in space_time_modalities:
|
|
486
|
+
if space_time_modality not in stacked_inputs:
|
|
487
|
+
continue
|
|
488
|
+
if space_time_modality == "s1":
|
|
489
|
+
s_t_channels += self.s_t_channels_s1
|
|
490
|
+
else:
|
|
491
|
+
s_t_channels += self.s_t_channels_s2
|
|
492
|
+
cur = stacked_inputs[space_time_modality]
|
|
493
|
+
cur = rearrange(cur, "b c t h w -> b h w t c")
|
|
494
|
+
stacked_inputs[space_time_modality] = cur
|
|
495
|
+
|
|
496
|
+
for space_modality in ["srtm", "dw", "wc"]:
|
|
497
|
+
if space_modality not in stacked_inputs:
|
|
498
|
+
continue
|
|
499
|
+
# take the first (and assumed only) timestep
|
|
500
|
+
stacked_inputs[space_modality] = stacked_inputs[space_modality][:, :, 0]
|
|
501
|
+
stacked_inputs[space_modality] = rearrange(
|
|
502
|
+
stacked_inputs[space_modality], "b c h w -> b h w c"
|
|
503
|
+
)
|
|
504
|
+
|
|
505
|
+
for time_modality in time_modalities:
|
|
506
|
+
if time_modality not in stacked_inputs:
|
|
507
|
+
continue
|
|
508
|
+
cur = stacked_inputs[time_modality]
|
|
509
|
+
# take the average over the h, w bands since Galileo
|
|
510
|
+
# treats it as a pixel-timeseries
|
|
511
|
+
cur = rearrange(
|
|
512
|
+
torch.nanmean(cur, dim=(-1, -2)),
|
|
513
|
+
"b c t -> b t c",
|
|
514
|
+
)
|
|
515
|
+
stacked_inputs[time_modality] = cur
|
|
516
|
+
|
|
517
|
+
for static_modality in ["landscan", "latlon"]:
|
|
518
|
+
if static_modality not in stacked_inputs:
|
|
519
|
+
continue
|
|
520
|
+
cur = stacked_inputs[static_modality]
|
|
521
|
+
stacked_inputs[static_modality] = torch.nanmean(cur, dim=(2, 3, 4))
|
|
522
|
+
|
|
523
|
+
galileo_input = self.construct_galileo_input(**stacked_inputs, normalize=True)
|
|
524
|
+
h = galileo_input.s_t_x.shape[1]
|
|
525
|
+
if h < self.patch_size:
|
|
526
|
+
logger.warning(
|
|
527
|
+
f"Given patch size {self.patch_size} < h {h}. Reducing patch size to {h}"
|
|
528
|
+
)
|
|
529
|
+
patch_size = h
|
|
530
|
+
else:
|
|
531
|
+
patch_size = self.patch_size
|
|
532
|
+
|
|
533
|
+
# Decide context based on self.autocast_dtype.
|
|
534
|
+
device = galileo_input.s_t_x.device
|
|
535
|
+
if self.autocast_dtype is None:
|
|
536
|
+
torch_context = nullcontext()
|
|
537
|
+
else:
|
|
538
|
+
assert device is not None
|
|
539
|
+
torch_context = torch.amp.autocast(
|
|
540
|
+
device_type=device.type, dtype=self.autocast_dtype
|
|
541
|
+
)
|
|
542
|
+
with torch_context:
|
|
543
|
+
outputs = self.model(
|
|
544
|
+
s_t_x=galileo_input.s_t_x,
|
|
545
|
+
s_t_m=galileo_input.s_t_m,
|
|
546
|
+
sp_x=galileo_input.sp_x,
|
|
547
|
+
sp_m=galileo_input.sp_m,
|
|
548
|
+
t_x=galileo_input.t_x,
|
|
549
|
+
t_m=galileo_input.t_m,
|
|
550
|
+
st_x=galileo_input.st_x,
|
|
551
|
+
st_m=galileo_input.st_m,
|
|
552
|
+
months=galileo_input.months,
|
|
553
|
+
patch_size=patch_size,
|
|
554
|
+
)
|
|
555
|
+
|
|
556
|
+
if h == patch_size:
|
|
557
|
+
# only one spatial patch, so we can just take an average
|
|
558
|
+
# of all the tokens to output b c_g 1 1
|
|
559
|
+
s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, _ = outputs
|
|
560
|
+
averaged = self.model.average_tokens(
|
|
561
|
+
s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
|
|
562
|
+
)
|
|
563
|
+
return FeatureMaps([repeat(averaged, "b d -> b d 1 1")])
|
|
564
|
+
else:
|
|
565
|
+
s_t_x = outputs[0]
|
|
566
|
+
# we will be assuming we only want s_t_x, and (for now) that we want s1 or s2 bands
|
|
567
|
+
# s_t_x has shape [b, h, w, t, c_g, d]
|
|
568
|
+
# and we want [b, d, h, w]
|
|
569
|
+
return FeatureMaps(
|
|
570
|
+
[
|
|
571
|
+
rearrange(
|
|
572
|
+
s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
|
|
573
|
+
"b h w c_g d -> b c_g d h w",
|
|
574
|
+
).mean(dim=1)
|
|
575
|
+
]
|
|
576
|
+
)
|
|
577
|
+
|
|
578
|
+
def get_backbone_channels(self) -> list:
|
|
579
|
+
"""Returns the output channels of this model when used as a backbone.
|
|
580
|
+
|
|
581
|
+
The output channels is a list of (patch_size, depth) that corresponds
|
|
582
|
+
to the feature maps that the backbone returns.
|
|
583
|
+
|
|
584
|
+
Returns:
|
|
585
|
+
the output channels of the backbone as a list of (patch_size, depth) tuples.
|
|
586
|
+
"""
|
|
587
|
+
if self.size == GalileoSize.BASE:
|
|
588
|
+
depth = 768
|
|
589
|
+
elif self.model_size == GalileoSize.TINY:
|
|
590
|
+
depth = 192
|
|
591
|
+
elif self.model_size == GalileoSize.NANO:
|
|
592
|
+
depth = 128
|
|
593
|
+
else:
|
|
594
|
+
raise ValueError(f"Invalid model size: {self.size}")
|
|
595
|
+
return [(self.patch_size, depth)]
|