rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/galileo/single_file_galileo.py (new file)
@@ -0,0 +1,1678 @@
"""Galileo models."""

import collections.abc
import itertools
import json
import math
from abc import abstractmethod
from collections import OrderedDict
from collections import OrderedDict as OrderedDictType
from collections.abc import Sequence
from copy import deepcopy
from pathlib import Path
from typing import NamedTuple, cast

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, vmap
from torch.jit import Final
from typing_extensions import override

from rslearn.log_utils import get_logger

logger = get_logger(__name__)


# constants
CONFIG_FILENAME = "config.json"
ENCODER_FILENAME = "encoder.pt"
BASE_GSD = 10
DEFAULT_MONTH = 5

# band information
S1_BANDS = ["VV", "VH"]
S1_SHIFT_VALUES = [25.0, 25.0]
S1_DIV_VALUES = [25.0, 25.0]
S2_BANDS = [
    "B2",
    "B3",
    "B4",
    "B5",
    "B6",
    "B7",
    "B8",
    "B8A",
    "B11",
    "B12",
]
S2_SHIFT_VALUES = [0.0] * len(S2_BANDS)
S2_DIV_VALUES = [1e4] * len(S2_BANDS)
ERA5_BANDS = ["temperature_2m", "total_precipitation_sum"]
# for temperature, shift to celcius and then divide by 35 based on notebook (ranges from)
# 37 to -22 degrees celcius
# For rainfall, based on
# https://github.com/nasaharvest/presto/blob/main/notebooks/exploratory_data_analysis.ipynb
ERA5_SHIFT_VALUES = [-272.15, 0.0]
ERA5_DIV_VALUES = [35.0, 0.03]
TC_BANDS = ["def", "soil", "aet"]
TC_SHIFT_VALUES = [0.0, 0.0, 0.0]
TC_DIV_VALUES = [4548, 8882, 2000]
VIIRS_BANDS = ["avg_rad"]
VIIRS_SHIFT_VALUES = [0.0]
# visually checked - this seems much more reasonable than
# the GEE estimate
VIIRS_DIV_VALUES = [100]
SRTM_BANDS = ["elevation", "slope"]
# visually gauged 90th percentile from
# https://github.com/nasaharvest/presto/blob/main/notebooks/exploratory_data_analysis.ipynb
SRTM_SHIFT_VALUES = [0.0, 0.0]
SRTM_DIV_VALUES = [2000.0, 50.0]
DW_BANDS = [
    "DW_water",
    "DW_trees",
    "DW_grass",
    "DW_flooded_vegetation",
    "DW_crops",
    "DW_shrub_and_scrub",
    "DW_built",
    "DW_bare",
    "DW_snow_and_ice",
]
DW_SHIFT_VALUES = [0] * len(DW_BANDS)
DW_DIV_VALUES = [1] * len(DW_BANDS)

WC_BANDS = [
    "WC_temporarycrops",
    "WC_maize",
    "WC_wintercereals",
    "WC_springcereals",
    "WC_irrigation",
]
WC_SHIFT_VALUES = [0] * len(WC_BANDS)
WC_DIV_VALUES = [100] * len(WC_BANDS)
STATIC_DW_BANDS = [f"{x}_static" for x in DW_BANDS]
STATIC_WC_BANDS = [f"{x}_static" for x in WC_BANDS]

LANDSCAN_BANDS = ["b1"]
# LANDSCAN values range from approximately 0 to 185000 in 2022: https://code.earthengine.google.com/?scriptPath=users/sat-io/awesome-gee-catalog-examples:population-socioeconomics/LANDSCAN-GLOBAL
LANDSCAN_SHIFT_VALUES = [92500]
LANDSCAN_DIV_VALUES = [92500]
LOCATION_BANDS = ["x", "y", "z"]

SPACE_TIME_BANDS = S1_BANDS + S2_BANDS + ["NDVI"]
TIME_BANDS = ERA5_BANDS + TC_BANDS + VIIRS_BANDS
SPACE_BANDS = SRTM_BANDS + DW_BANDS + WC_BANDS
STATIC_BANDS = LANDSCAN_BANDS + LOCATION_BANDS + STATIC_DW_BANDS + STATIC_WC_BANDS

# 0 for NDVI
SPACE_TIME_SHIFT_VALUES = np.array(S1_SHIFT_VALUES + S2_SHIFT_VALUES + [0])
SPACE_TIME_DIV_VALUES = np.array(S1_DIV_VALUES + S2_DIV_VALUES + [1])
TIME_SHIFT_VALUES = np.array(ERA5_SHIFT_VALUES + TC_SHIFT_VALUES + VIIRS_SHIFT_VALUES)
TIME_DIV_VALUES = np.array(ERA5_DIV_VALUES + TC_DIV_VALUES + VIIRS_DIV_VALUES)
SPACE_SHIFT_VALUES = np.array(SRTM_SHIFT_VALUES + DW_SHIFT_VALUES + WC_SHIFT_VALUES)
SPACE_DIV_VALUES = np.array(SRTM_DIV_VALUES + DW_DIV_VALUES + WC_DIV_VALUES)
# [0s, 1s] for the locations
STATIC_SHIFT_VALUES = np.array(
    LANDSCAN_SHIFT_VALUES + [0, 0, 0] + DW_SHIFT_VALUES + WC_SHIFT_VALUES
)
STATIC_DIV_VALUES = np.array(
    LANDSCAN_DIV_VALUES + [1, 1, 1] + DW_DIV_VALUES + WC_DIV_VALUES
)

SPACE_TIME_BANDS_GROUPS_IDX: OrderedDictType[str, list[int]] = OrderedDict(
    {
        "S1": [SPACE_TIME_BANDS.index(b) for b in S1_BANDS],
        "S2_RGB": [SPACE_TIME_BANDS.index(b) for b in ["B2", "B3", "B4"]],
        "S2_Red_Edge": [SPACE_TIME_BANDS.index(b) for b in ["B5", "B6", "B7"]],
        "S2_NIR_10m": [SPACE_TIME_BANDS.index(b) for b in ["B8"]],
        "S2_NIR_20m": [SPACE_TIME_BANDS.index(b) for b in ["B8A"]],
        "S2_SWIR": [SPACE_TIME_BANDS.index(b) for b in ["B11", "B12"]],
        "NDVI": [SPACE_TIME_BANDS.index("NDVI")],
    }
)

TIME_BAND_GROUPS_IDX: OrderedDictType[str, list[int]] = OrderedDict(
    {
        "ERA5": [TIME_BANDS.index(b) for b in ERA5_BANDS],
        "TC": [TIME_BANDS.index(b) for b in TC_BANDS],
        "VIIRS": [TIME_BANDS.index(b) for b in VIIRS_BANDS],
    }
)

SPACE_BAND_GROUPS_IDX: OrderedDictType[str, list[int]] = OrderedDict(
    {
        "SRTM": [SPACE_BANDS.index(b) for b in SRTM_BANDS],
        "DW": [SPACE_BANDS.index(b) for b in DW_BANDS],
        "WC": [SPACE_BANDS.index(b) for b in WC_BANDS],
    }
)

STATIC_BAND_GROUPS_IDX: OrderedDictType[str, list[int]] = OrderedDict(
    {
        "LS": [STATIC_BANDS.index(b) for b in LANDSCAN_BANDS],
        "location": [STATIC_BANDS.index(b) for b in LOCATION_BANDS],
        "DW_static": [STATIC_BANDS.index(b) for b in STATIC_DW_BANDS],
        "WC_static": [STATIC_BANDS.index(b) for b in STATIC_WC_BANDS],
    }
)


# this normalizing dict is sourced from
# https://github.com/nasaharvest/galileo/blob/main/config/normalization.json
# its used to normalize the data. The keys (e.g. "13") are used to query
# which tensor (e.g. space_time_x) is associated to the means and stds,
# where the key represents the number of dimensions in the tensor (i.e. x.shape[-1])
NORMALIZING_DICT = {
    "total_n": 127155,
    "sampled_n": 10000,
    "13": {
        "mean": [
            -11.728724389184965,
            -18.85558188024017,
            1395.3408730676722,
            1338.4026921784578,
            1343.09883810357,
            1543.8607982512297,
            2186.2022069512263,
            2525.0932853316694,
            2410.3377187373408,
            2750.2854646886753,
            2234.911100061487,
            1474.5311266077113,
            0.2892116502999044,
        ],
        "std": [
            4.887145774840316,
            5.730270320384293,
            917.7041440370853,
            913.2988423581528,
            1092.678723527555,
            1047.2206083460424,
            1048.0101611156767,
            1143.6903026819996,
            1098.979177731649,
            1204.472755085893,
            1145.9774063078878,
            980.2429840007796,
            0.2720939024500081,
        ],
    },
    "16": {
        "mean": [
            673.0152819503361,
            5.930092668915115,
            0.10470439140978786,
            0.23965913270066183,
            0.08158044385860364,
            0.04246976254259546,
            0.11304392863520317,
            0.17329647890362473,
            0.0698981691616277,
            0.12130267132802142,
            0.04671318615236216,
            10.973119802517362,
            1.0927069179958768,
            1.6991394232855903,
            0.03720594618055555,
            1.3671352688259548,
        ],
        "std": [
            983.0697298296237,
            8.167406789813247,
            0.18771647977504985,
            0.2368313455675914,
            0.08024268534756586,
            0.04045374496146404,
            0.11350342472061795,
            0.1279898111718168,
            0.12042341550438586,
            0.13602408145504347,
            0.043971116096060345,
            31.255340146970997,
            10.395974878206689,
            12.92380617159917,
            1.9285254295940466,
            11.612179775408928,
        ],
    },
    "6": {
        "mean": [
            271.5674963541667,
            0.08554303677156568,
            657.3181260091111,
            692.1291795806885,
            562.781331880633,
            1.5647115934036673,
        ],
        "std": [
            79.80828940314429,
            0.11669547098151486,
            704.0008695557707,
            925.0116126406431,
            453.2434022278578,
            7.513020170832818,
        ],
    },
    "18": {
        "mean": [
            188.20315880851746,
            0.2804946561574936,
            0.11371652073860168,
            0.058778801321983334,
            0.10474256777763366,
            0.2396918488264084,
            0.08152248692512512,
            0.04248040814399719,
            0.11303179881572724,
            0.17326324067115784,
            0.06998309404850006,
            0.12122812910079957,
            0.04671641788482666,
            10.98456594619751,
            1.0968475807189941,
            1.6947754135131836,
            0.03320046615600586,
            1.3602827312469483,
        ],
        "std": [
            1154.5919128300602,
            0.5276998078079327,
            0.7021637331734328,
            0.36528892213195063,
            0.17470213191865785,
            0.20411195416718833,
            0.0660782470089761,
            0.03380702424871257,
            0.09809195568521663,
            0.11292471052124119,
            0.09720748930233268,
            0.12912217763726777,
            0.0399973913151906,
            23.725471823867462,
            5.715238079725388,
            9.030481416228302,
            0.9950220242487364,
            7.754429123862099,
        ],
    },
}
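For orientation (an illustration, not part of the package diff): the group-index tables above map band names to positions within the concatenated band lists, and the list lengths are exactly what the "13", "16", "6" and "18" keys of NORMALIZING_DICT refer to. A minimal check, assuming the definitions above are in scope:

```python
# Illustration only: band-group indices and the lengths behind the NORMALIZING_DICT keys.
assert len(SPACE_TIME_BANDS) == 13  # S1 (2) + S2 (10) + NDVI (1)
assert len(SPACE_BANDS) == 16       # SRTM (2) + DW (9) + WC (5)
assert len(TIME_BANDS) == 6         # ERA5 (2) + TC (3) + VIIRS (1)
assert len(STATIC_BANDS) == 18      # LandScan (1) + location (3) + DW (9) + WC (5)
assert SPACE_TIME_BANDS_GROUPS_IDX["S1"] == [0, 1]
assert SPACE_TIME_BANDS_GROUPS_IDX["S2_RGB"] == [2, 3, 4]  # B2, B3, B4
assert SPACE_TIME_BANDS_GROUPS_IDX["NDVI"] == [12]
```

The diff continues below.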
class Normalizer:
    """Normalize Galileo inputs."""

    std_bands: dict[int, list] = {
        # we exclude NDVI because its already between 0 and 1, so we don't
        # want to apply further normalization to it.
        len(SPACE_TIME_BANDS): [b for b in SPACE_TIME_BANDS if b != "NDVI"],
        len(SPACE_BANDS): SRTM_BANDS,
        len(TIME_BANDS): TIME_BANDS,
        len(STATIC_BANDS): LANDSCAN_BANDS,
    }

    def __init__(self, std_multiplier: float = 2):
        """Normalize Galileo inputs.

        Args:
            std_multiplier: std_multiplier to apply
        """
        name_to_bands = {
            len(SPACE_TIME_BANDS): SPACE_TIME_BANDS,
            len(SPACE_BANDS): SPACE_BANDS,
            len(TIME_BANDS): TIME_BANDS,
            len(STATIC_BANDS): STATIC_BANDS,
        }
        self.shift_div_dict = {
            len(SPACE_TIME_BANDS): {
                "shift": deepcopy(SPACE_TIME_SHIFT_VALUES),
                "div": deepcopy(SPACE_TIME_DIV_VALUES),
            },
            len(SPACE_BANDS): {
                "shift": deepcopy(SPACE_SHIFT_VALUES),
                "div": deepcopy(SPACE_DIV_VALUES),
            },
            len(TIME_BANDS): {
                "shift": deepcopy(TIME_SHIFT_VALUES),
                "div": deepcopy(TIME_DIV_VALUES),
            },
            len(STATIC_BANDS): {
                "shift": deepcopy(STATIC_SHIFT_VALUES),
                "div": deepcopy(STATIC_DIV_VALUES),
            },
        }
        for key_as_str, val in NORMALIZING_DICT.items():
            if "n" in key_as_str:
                continue
            key = int(key_as_str)
            bands_to_replace = self.std_bands[key]
            for band in bands_to_replace:
                band_idx = name_to_bands[key].index(band)
                mean = cast(dict[str, list], val)["mean"][band_idx]
                std = cast(dict[str, list], val)["std"][band_idx]
                min_value = mean - (std_multiplier * std)
                max_value = mean + (std_multiplier * std)
                div = max_value - min_value
                if div == 0:
                    raise ValueError(f"{band} has div value of 0")
                self.shift_div_dict[key]["shift"][band_idx] = min_value
                self.shift_div_dict[key]["div"][band_idx] = div

    @staticmethod
    def _normalize(
        x: torch.Tensor, shift_values: torch.Tensor, div_values: torch.Tensor
    ) -> torch.Tensor:
        x = (x - shift_values) / div_values
        return x

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the normalizer."""
        div_values = self.shift_div_dict[x.shape[-1]]["div"]
        shift_values = self.shift_div_dict[x.shape[-1]]["shift"]
        return self._normalize(x, shift_values, div_values)
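As a hedged sketch (not part of the diff) of what Normalizer does: it keys its statistics off x.shape[-1], and for each band listed in std_bands it replaces the default shift/div with shift = mean - std_multiplier*std and div = 2*std_multiplier*std, so with std_multiplier=2 a value near the mean lands around 0.5:

```python
# Sketch only: the shift/div arithmetic for one band, using the "13" statistics above.
# Sentinel-2 B2 is index 2 in SPACE_TIME_BANDS, with mean ~1395.34 and std ~917.70.
mean, std = 1395.3408730676722, 917.7041440370853
shift, div = mean - 2 * std, 4 * std   # what Normalizer(std_multiplier=2) stores
x = 1500.0                             # a raw B2 value
print((x - shift) / div)               # ~0.53, i.e. roughly centred in [0, 1]
```

The diff continues below.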
class MaskedOutput(NamedTuple):
    """A masked output (i.e. an input to Galileo).

    A mask can take 3 values:
    0: seen by the encoder (i.e. makes the key and value tokens in the decoder)
    1: not seen by the encoder, and ignored by the decoder
    2: not seen by the encoder, and processed by the decoder (the decoder's query values)
    """

    s_t_x: torch.Tensor  # [B, H, W, T, len(SPACE_TIME_BANDS)]
    sp_x: torch.Tensor  # [B, H, W, len(SPACE_BANDS)]
    t_x: torch.Tensor  # [B, T, len(TIME_BANDS)]
    st_x: torch.Tensor  # [B, len(STATIC_BANDS)]
    s_t_m: torch.Tensor  # [B, H, W, T, len(SPACE_TIME_BANDS_GROUPS_IDX)]
    sp_m: torch.Tensor  # [B, H, W, len(SPACE_BAND_GROUPS_IDX)]
    t_m: torch.Tensor  # [B, T, len(TIME_BAND_GROUPS_IDX)]
    st_m: torch.Tensor  # [B, len(STATIC_BAND_GROUPS_IDX)]
    months: torch.Tensor  # [B, T]


def get_2d_sincos_pos_embed_with_resolution(
    embed_dim: int,
    grid_size: int,
    res: torch.Tensor,
    cls_token: bool = False,
    device: str = "cpu",
) -> torch.Tensor:
    """Create 2d sincos embeddings with resolution.

    grid_size: int of the grid height and width
    res: array of size n, representing the resolution of a pixel (say, in meters),

    Return:
    pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    res = res.to(device)
    grid_h = torch.arange(grid_size, device=device)
    grid_w = torch.arange(grid_size, device=device)
    grid = torch.meshgrid(
        grid_w, grid_h, indexing="xy"
    )  # here h goes first,direction reversed for numpy
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # grid = grid.reshape([2, 1, grid_size, grid_size])
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(
        embed_dim, grid
    )  # # (nxH*W, D/2)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        pos_embed = torch.cat(
            [
                torch.zeros([n, 1, embed_dim], device=pos_embed.device),
                pos_embed,
            ],
            dim=1,
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid_torch(
    embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
    """get_2d_sincos_pos_embed_from_grid_torch."""
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(
        embed_dim // 2, grid[0]
    )  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(
        embed_dim // 2, grid[1]
    )  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid_torch(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """get_1d_sincos_pos_embed_from_grid_torch.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, device=pos.device) / embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


def get_month_encoding_table(embed_dim: int) -> torch.Tensor:
    """Sinusoid month encoding table, for 12 months indexed from 0-11."""
    assert embed_dim % 2 == 0
    angles = torch.arange(0, 13) / (12 / (2 * np.pi))

    sin_table = torch.sin(torch.stack([angles for _ in range(embed_dim // 2)], axis=-1))
    cos_table = torch.cos(torch.stack([angles for _ in range(embed_dim // 2)], axis=-1))
    month_table = torch.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)

    return month_table  # (M, D)
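The three helpers above produce fixed (non-learned) sinusoidal tables. A shape-check sketch (not part of the diff), assuming the functions above are in scope; the 2D call mirrors how GalileoBase invokes it, with one grid of embeddings per entry of res:

```python
import torch

# Sketch only: expected shapes of the positional/month encodings.
pe_1d = get_1d_sincos_pos_embed_from_grid_torch(32, torch.arange(24).float())
print(pe_1d.shape)  # torch.Size([24, 32])

month_table = get_month_encoding_table(32)
print(month_table.shape)  # torch.Size([12, 32])

# one 4x4 grid of embeddings per entry in `res`
pe_2d = get_2d_sincos_pos_embed_with_resolution(64, 4, res=torch.ones(2))
print(pe_2d.shape)  # torch.Size([2, 16, 64])
```

The diff continues below.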
def adjust_learning_rate(
    optimizer: torch.optim.Optimizer,
    epoch: int,
    warmup_epochs: int,
    total_epochs: int,
    max_lr: float,
    min_lr: float,
) -> float:
    """Decay the learning rate with half-cycle cosine after warmup."""
    if epoch < warmup_epochs:
        lr = max_lr * epoch / warmup_epochs
    else:
        lr = min_lr + (max_lr - min_lr) * 0.5 * (
            1.0
            + math.cos(
                math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
            )
        )
    for group in optimizer.param_groups:
        group["lr"] = lr
    return lr
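adjust_learning_rate mutates the optimizer in place: linear warmup to max_lr over warmup_epochs, then a half-cycle cosine down to min_lr. A sketch of the resulting values (not part of the diff; the SGD optimizer and hyperparameters here are arbitrary):

```python
import torch

# Sketch only: warmup for 5 epochs, then cosine decay over the remaining 95.
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
for epoch in (0, 5, 50, 100):
    lr = adjust_learning_rate(
        opt, epoch, warmup_epochs=5, total_epochs=100, max_lr=1e-3, min_lr=1e-5
    )
    print(epoch, lr)
# 0 -> 0.0, 5 -> 0.001 (peak), 50 -> ~5.5e-4, 100 -> 1e-5
```

The diff continues below.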
# thanks to https://github.com/bwconrad/flexivit/ for this nice implementation
# of the FlexiPatchEmbed module
def to_2tuple(x: int | tuple[int, int]) -> tuple[int, int]:
    """to_2tuple."""
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return tuple(x)  # type: ignore
    return tuple(itertools.repeat(x, 2))  # type: ignore


class FlexiPatchEmbed(nn.Module):
    """FlexiPatchEmbed."""

    def __init__(
        self,
        patch_size: int | tuple[int, int],
        in_chans: int = 3,
        embed_dim: int = 128,
        norm_layer: nn.Module | None = None,
        bias: bool = True,
        patch_size_seq: Sequence[int] = (1, 2, 3, 4, 5, 6),
        interpolation: str = "bicubic",
        antialias: bool = True,
    ) -> None:
        """2D image to patch embedding w/ flexible patch sizes.

        Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24
        by https://github.com/bwconrad/flexivit/

        Args:
            patch_size: Base patch size. i.e the size of the parameter buffer
            in_chans: Number of input image channels
            embed_dim: Network embedding dimension size
            norm_layer: Optional normalization layer
            bias: Whether to use bias in convolution
            patch_size_seq: List of patch sizes to randomly sample from
            interpolation: Resize interpolation type
            antialias: Whether to apply antialiasing resizing
        """
        super().__init__()

        self.patch_size = to_2tuple(patch_size)

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=bias,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        # Flexi specific attributes
        self.interpolation = interpolation
        self.antialias = antialias

        self.patch_size_seq = patch_size_seq

        # Pre-calculate pinvs
        self.pinvs = self._cache_pinvs()

    def _cache_pinvs(self) -> dict:
        """Pre-calculate all pinv matrices."""
        pinvs = {}
        for ps in self.patch_size_seq:
            tuple_ps = to_2tuple(ps)
            pinvs[tuple_ps] = self._calculate_pinv(self.patch_size, tuple_ps)
        return pinvs

    def _resize(self, x: Tensor, shape: tuple[int, int]) -> Tensor:
        x_resized = F.interpolate(
            x[None, None, ...],
            shape,
            mode=self.interpolation,
            antialias=self.antialias,
        )
        return x_resized[0, 0, ...]

    def _calculate_pinv(
        self, old_shape: tuple[int, int], new_shape: tuple[int, int]
    ) -> Tensor:
        mat = []
        for i in range(np.prod(old_shape)):
            basis_vec = torch.zeros(old_shape)
            basis_vec[np.unravel_index(i, old_shape)] = 1.0
            mat.append(self._resize(basis_vec, new_shape).reshape(-1))
        resize_matrix = torch.stack(mat)
        return torch.linalg.pinv(resize_matrix)

    def resize_patch_embed(
        self, patch_embed: Tensor, new_patch_size: tuple[int, int]
    ) -> torch.Tensor:
        """Resize patch_embed to target resolution via pseudo-inverse resizing."""
        # Return original kernel if no resize is necessary
        if self.patch_size == new_patch_size:
            return patch_embed

        # Calculate pseudo-inverse of resize matrix
        if new_patch_size not in self.pinvs:
            self.pinvs[new_patch_size] = self._calculate_pinv(
                self.patch_size, new_patch_size
            )
        pinv = self.pinvs[new_patch_size]
        pinv = pinv.to(patch_embed.device)

        def resample_patch_embed(patch_embed: Tensor) -> torch.Tensor:
            h, w = new_patch_size
            resampled_kernel = pinv @ patch_embed.reshape(-1)
            return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w)

        v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)

        return v_resample_patch_embed(patch_embed)

    def forward(
        self,
        x: Tensor,
        patch_size: int | tuple[int, int] | None = None,
    ) -> Tensor | tuple[Tensor, tuple[int, int]]:
        """Forward pass."""
        # x has input shape [b, h, w, (t), c]
        batch_size = x.shape[0]
        has_time_dimension = False
        num_timesteps = 0  # ignored if has_time_dimension is False
        if len(x.shape) == 5:
            has_time_dimension = True
            num_timesteps = x.shape[3]
            x = rearrange(x, "b h w t c -> (b t) c h w")
        else:
            x = rearrange(x, "b h w c -> b c h w")

        if not patch_size:
            # During evaluation use base patch size if not specified
            patch_size = self.patch_size

        patch_size = to_2tuple(patch_size)

        # Resize conv weights
        if patch_size == self.patch_size:
            weight = self.proj.weight
        else:
            weight = self.resize_patch_embed(self.proj.weight, patch_size)
        # Apply conv with resized weights
        x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)

        if has_time_dimension:
            x = rearrange(x, "(b t) c h w -> b h w t c", b=batch_size, t=num_timesteps)
        else:
            x = rearrange(x, "b c h w -> b h w c")
        x = self.norm(x)

        return x
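FlexiPatchEmbed stores one convolutional kernel at the base patch size and resamples it through the cached pseudo-inverses when a different patch size is requested, so the same weights tokenize at several resolutions. A usage sketch (not part of the diff), assuming the class above is in scope:

```python
import torch

# Sketch only: the same embedding module tokenizing at two patch sizes.
embed = FlexiPatchEmbed(patch_size=8, in_chans=10, embed_dim=128)
x = torch.randn(2, 32, 32, 10)       # [B, H, W, C]
print(embed(x).shape)                # torch.Size([2, 4, 4, 128]): 32/8 tokens per side
print(embed(x, patch_size=4).shape)  # torch.Size([2, 8, 8, 128]): 32/4 tokens per side
```

The diff continues below.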
class Attention(nn.Module):
    """Attention."""

    # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
    fast_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        cross_attn: bool = False,
    ) -> None:
        """Initialize attention."""
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fast_attn = hasattr(
            torch.nn.functional, "scaled_dot_product_attention"
        )  # FIXME

        self.cross_attn = cross_attn

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass."""
        B, N, C = x.shape

        q = self.q(x)

        if y is None:
            assert not self.cross_attn
            k = self.k(x)
            v = self.v(x)
        else:
            assert self.cross_attn
            k = self.k(y)
            v = self.v(y)

        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)

        q, k = self.q_norm(q), self.k_norm(k)
        if self.fast_attn:
            if attn_mask is not None:
                attn_mask = attn_mask[:, None, None].repeat((1, self.num_heads, N, 1))
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                # a value of True indicates that the element should take part in attention
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p,
            )
        else:
            if attn_mask is not None:
                raise NotImplementedError
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
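The Attention module doubles as self- and cross-attention: with cross_attn=True the keys and values are projected from a second sequence y while the queries still come from x. A sketch (not part of the diff), assuming the class above is in scope:

```python
import torch

# Sketch only: self-attention vs. cross-attention with the Attention module.
self_attn = Attention(dim=128, num_heads=8)
x = torch.randn(2, 10, 128)    # [B, N, D]
print(self_attn(x).shape)      # torch.Size([2, 10, 128])

cross_attn = Attention(dim=128, num_heads=8, cross_attn=True)
y = torch.randn(2, 25, 128)    # keys/values come from y
print(cross_attn(x, y).shape)  # torch.Size([2, 10, 128])
```

The diff continues below.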
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int | None = None,
        out_features: int | None = None,
        act_layer: nn.Module = nn.GELU,
        bias: bool = True,
        drop: float = 0.0,
    ) -> None:
        """Initialize the MLP."""
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class LayerScale(nn.Module):
    """LayerScale."""

    def __init__(
        self, dim: int, init_values: float = 1e-5, inplace: bool = False
    ) -> None:
        """Init layerscale."""
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Drop path."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float) -> None:
        """Init."""
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward."""
        return drop_path(x, self.drop_prob, self.training)


class Block(nn.Module):
    """An Attention block."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        init_values: float | None = None,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        cross_attn: bool = False,
    ) -> None:
        """Init."""
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=drop,
            norm_layer=norm_layer,
            cross_attn=cross_attn,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )

    def forward(
        self, x: torch.Tensor, y: torch.Tensor, attn_mask: torch.Tensor | None
    ) -> torch.Tensor:
        """Forward."""
        x = x + self.drop_path(self.ls1(self.attn(self.norm1(x), y, attn_mask)))
        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
        return x
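A Block is the usual pre-norm transformer layer (attention, then MLP, each with optional LayerScale and DropPath); it preserves the token shape. A sketch (not part of the diff), assuming the class above is in scope:

```python
import torch

# Sketch only: one transformer Block in self-attention mode (y=None).
block = Block(dim=128, num_heads=8, mlp_ratio=4.0, qkv_bias=True)
tokens = torch.randn(2, 10, 128)
out = block(tokens, y=None, attn_mask=None)
print(out.shape)  # torch.Size([2, 10, 128])
```

The diff continues below.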
class ModuleListWithInit(nn.ModuleList):
    """module list with an init function."""

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)


class GalileoBase(nn.Module):
    """Galileo Base."""

    def __init__(
        self,
        embedding_size: int = 128,
        depth: int = 2,
        mlp_ratio: int = 2,
        num_heads: int = 8,
        max_sequence_length: int = 24,
        base_patch_size: int = 4,
        use_channel_embs: bool = True,
        drop_path: float = 0.0,
    ) -> None:
        """Init."""
        super().__init__()

        self.space_time_groups = SPACE_TIME_BANDS_GROUPS_IDX
        self.space_groups = SPACE_BAND_GROUPS_IDX
        self.time_groups = TIME_BAND_GROUPS_IDX
        self.static_groups = STATIC_BAND_GROUPS_IDX
        self.embedding_size = embedding_size
        self.base_patch_size = base_patch_size

        self.blocks = ModuleListWithInit(
            [
                Block(
                    embedding_size,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=nn.LayerNorm,
                    cross_attn=self.cross_attn,
                    drop_path=drop_path,
                )
                for _ in range(depth)
            ]
        )

        self.max_sequence_length = max_sequence_length
        # we have 4 embeddings (pos_in_time, pos_in_space, month, channel) so each get
        # 0.25 of the dimension. This will change soon anyway
        self.pos_embed = nn.Parameter(
            get_1d_sincos_pos_embed_from_grid_torch(
                int(embedding_size * 0.25), torch.arange(max_sequence_length)
            ),
            requires_grad=False,
        )
        month_tab = get_month_encoding_table(int(embedding_size * 0.25))
        self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
        if use_channel_embs:
            args = {"requires_grad": True}
        else:
            args = {"requires_grad": False}
        self.s_t_channel_embed = nn.Parameter(
            torch.zeros(len(SPACE_TIME_BANDS_GROUPS_IDX), int(embedding_size * 0.25)),
            **args,
        )
        self.sp_channel_embed = nn.Parameter(
            torch.zeros(len(SPACE_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args
        )
        self.t_channel_embed = nn.Parameter(
            torch.zeros(len(TIME_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args
        )
        self.st_channel_embed = nn.Parameter(
            torch.zeros(len(STATIC_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args
        )

        self.apply(self._init_weights)

    @property
    @abstractmethod
    def cross_attn(self) -> bool:
        """Whether to use cross attention."""
        pass

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @classmethod
    def collapse_and_combine_hwtc(
        cls,
        s_t_x: torch.Tensor,
        sp_x: torch.Tensor,
        t_x: torch.Tensor,
        st_x: torch.Tensor,
        s_t_m: torch.Tensor,
        sp_m: torch.Tensor,
        t_m: torch.Tensor,
        st_m: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """collapse_and_combine_hwtc."""
        s_t_x = rearrange(s_t_x, "b h w t c_g d -> b (h w t c_g) d")
        sp_x = rearrange(sp_x, "b h w c_g d -> b (h w c_g) d")
        t_x = rearrange(t_x, "b t c_g d -> b (t c_g) d")

        s_t_m = rearrange(s_t_m, "b h w t c_g-> b (h w t c_g)")
        sp_m = rearrange(sp_m, "b h w c_g-> b (h w c_g)")
        t_m = rearrange(t_m, "b t c_g -> b (t c_g)")

        x = torch.cat(
            [
                s_t_x,
                sp_x,
                t_x,
                st_x,
            ],
            dim=1,
        )
        m = torch.cat([s_t_m, sp_m, t_m, st_m], dim=1)
        return x, m

    @classmethod
    def split_and_expand_hwtc(
        cls,
        x: torch.Tensor,
        h: int,
        w: int,
        t: int,
        s_t_c_g: int,
        sp_c_g: int,
        t_c_g: int,
        st_c_g: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """split_and_expand_hwtc."""
        n_s_t_t = h * w * t * s_t_c_g
        n_t_t = t * t_c_g

        s_t_x = rearrange(
            x[:, :n_s_t_t], "b (h w t c) d -> b h w t c d", h=h, w=w, t=t, c=s_t_c_g
        )
        sp_x = rearrange(
            x[:, n_s_t_t : -(n_t_t + st_c_g)],
            "b (h w c) d -> b h w c d",
            h=h,
            w=w,
            c=sp_c_g,
        )
        t_x = rearrange(
            x[:, -(n_t_t + st_c_g) : -st_c_g], "b (t c) d -> b t c d", t=t, c=t_c_g
        )
        st_x = x[:, -st_c_g:]

        return s_t_x, sp_x, t_x, st_x

    def apply_encodings(
        self,
        s_t_x: torch.Tensor,
        sp_x: torch.Tensor,
        t_x: torch.Tensor,
        st_x: torch.Tensor,
        months: torch.Tensor,
        patch_size: int,
        input_res: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """apply_encodings."""
        b, h, w, t, s_t_c_g, _ = s_t_x.shape
        sp_c_g, t_c_g = sp_x.shape[-2], t_x.shape[-2]
        st_c_g = st_x.shape[-2]

        s_t_channel = repeat(
            self.s_t_channel_embed, "c_g d -> b h w t c_g d", b=b, h=h, w=w, t=t
        )
        t_channel = repeat(self.t_channel_embed, "c_g d -> b t c_g d", b=b, t=t)
        st_channel = repeat(self.st_channel_embed, "c_g d -> b c_g d", b=b)
        sp_channel = repeat(
            self.sp_channel_embed, "c_g d -> b h w c_g d", b=b, h=h, w=w
        )

        pos_embed_s_t = repeat(
            self.pos_embed[:t], "t d -> b h w t c_g d", b=b, h=h, w=w, c_g=s_t_c_g
        )
        m_embed_s_t = repeat(
            self.month_embed(months), "b t d -> b h w t c_g d", h=h, w=w, c_g=s_t_c_g
        )

        pos_embed_t = repeat(self.pos_embed[:t], "t d -> b t c_g d", b=b, c_g=t_c_g)
        m_embed_t = repeat(self.month_embed(months), "b t d -> b t c_g d", c_g=t_c_g)
        t_zeros = torch.zeros(
            b, t, t_c_g, int(self.embedding_size * 0.25), device=t_x.device
        )

        sp_zeros = torch.zeros(
            b,
            h,
            w,
            sp_c_g,
            sp_channel.shape[-1] * 2,
            device=sp_channel.device,
        )

        st_zeros = torch.zeros(
            b, st_c_g, st_channel.shape[-1] * 3, device=st_channel.device
        )

        # find the resolution that each token represents, which will be
        # the number of pixels in a patch * the resolution of each pixel
        if patch_size is None:
            patch_size = self.base_patch_size
        token_res = input_res * patch_size
        gsd_ratio = token_res / BASE_GSD

        assert h == w, (
            "get_2d_sincos_pos_embed_with_resolution currently requires that h==w"
        )
        spatial_embed = get_2d_sincos_pos_embed_with_resolution(
            int(self.embedding_size * 0.25),
            h,
            torch.ones(b).to(s_t_x.device) * gsd_ratio,
            device=s_t_x.device,
        )
        spatial_embed = rearrange(spatial_embed, "b (h w) d -> b h w d", h=h, w=w)
        spatial_embed_s_t = repeat(
            spatial_embed, "b h w d -> b h w t c_g d", h=h, w=w, t=t, c_g=s_t_c_g
        )
        spatial_embed_s = repeat(
            spatial_embed, "b h w d -> b h w c_g d", h=h, w=w, c_g=sp_c_g
        )

        s_t_embed = torch.cat(
            [s_t_channel, pos_embed_s_t, m_embed_s_t, spatial_embed_s_t], dim=-1
        )
        sp_embed = torch.cat([sp_channel, sp_zeros, spatial_embed_s], dim=-1)
        t_embed = torch.cat([t_channel, pos_embed_t, m_embed_t, t_zeros], dim=-1)
        st_embed = torch.cat([st_channel, st_zeros], dim=-1)
        return s_t_x + s_t_embed, sp_x + sp_embed, t_x + t_embed, st_x + st_embed
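collapse_and_combine_hwtc flattens the four per-group token grids (and their masks) into one sequence, and split_and_expand_hwtc inverts it. A token-count sketch (not part of the diff) with a 2x2 grid, 3 timesteps, and the default channel groups (7 space-time, 3 space, 3 time, 4 static), assuming the definitions above are in scope:

```python
import torch

# Sketch only: 2*2*3*7 + 2*2*3 + 3*3 + 4 = 109 tokens after flattening.
b, h, w, t, d = 1, 2, 2, 3, 16
s_t_x = torch.randn(b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX), d)
sp_x = torch.randn(b, h, w, len(SPACE_BAND_GROUPS_IDX), d)
t_x = torch.randn(b, t, len(TIME_BAND_GROUPS_IDX), d)
st_x = torch.randn(b, len(STATIC_BAND_GROUPS_IDX), d)
s_t_m = torch.zeros(b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX))
sp_m = torch.zeros(b, h, w, len(SPACE_BAND_GROUPS_IDX))
t_m = torch.zeros(b, t, len(TIME_BAND_GROUPS_IDX))
st_m = torch.zeros(b, len(STATIC_BAND_GROUPS_IDX))
x, m = GalileoBase.collapse_and_combine_hwtc(
    s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
)
print(x.shape, m.shape)  # torch.Size([1, 109, 16]) torch.Size([1, 109])
```

The diff continues below.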
|
1134
|
+
class Encoder(GalileoBase):
|
|
1135
|
+
"""Galileo Encoder."""
|
|
1136
|
+
|
|
1137
|
+
def __init__(
|
|
1138
|
+
self,
|
|
1139
|
+
max_patch_size: int = 8,
|
|
1140
|
+
embedding_size: int = 128,
|
|
1141
|
+
depth: int = 2,
|
|
1142
|
+
mlp_ratio: int = 2,
|
|
1143
|
+
num_heads: int = 8,
|
|
1144
|
+
max_sequence_length: int = 24,
|
|
1145
|
+
freeze_projections: bool = False,
|
|
1146
|
+
drop_path: float = 0.0,
|
|
1147
|
+
) -> None:
|
|
1148
|
+
"""Init."""
|
|
1149
|
+
super().__init__(
|
|
1150
|
+
embedding_size,
|
|
1151
|
+
depth,
|
|
1152
|
+
mlp_ratio,
|
|
1153
|
+
num_heads,
|
|
1154
|
+
max_sequence_length,
|
|
1155
|
+
max_patch_size,
|
|
1156
|
+
use_channel_embs=True,
|
|
1157
|
+
drop_path=drop_path,
|
|
1158
|
+
)
|
|
1159
|
+
|
|
1160
|
+
self.space_time_embed = nn.ModuleDict(
|
|
1161
|
+
{
|
|
1162
|
+
group_name: FlexiPatchEmbed(
|
|
1163
|
+
in_chans=len(group),
|
|
1164
|
+
embed_dim=embedding_size,
|
|
1165
|
+
patch_size=max_patch_size,
|
|
1166
|
+
)
|
|
1167
|
+
for group_name, group in self.space_time_groups.items()
|
|
1168
|
+
}
|
|
1169
|
+
)
|
|
1170
|
+
self.space_embed = nn.ModuleDict(
|
|
1171
|
+
{
|
|
1172
|
+
group_name: FlexiPatchEmbed(
|
|
1173
|
+
in_chans=len(group),
|
|
1174
|
+
embed_dim=embedding_size,
|
|
1175
|
+
patch_size=max_patch_size,
|
|
1176
|
+
)
|
|
1177
|
+
for group_name, group in self.space_groups.items()
|
|
1178
|
+
}
|
|
1179
|
+
)
|
|
1180
|
+
self.time_embed = nn.ModuleDict(
|
|
1181
|
+
{
|
|
1182
|
+
group_name: nn.Linear(
|
|
1183
|
+
in_features=len(group), out_features=embedding_size
|
|
1184
|
+
)
|
|
1185
|
+
for group_name, group in self.time_groups.items()
|
|
1186
|
+
}
|
|
1187
|
+
)
|
|
1188
|
+
self.static_embed = nn.ModuleDict(
|
|
1189
|
+
{
|
|
1190
|
+
group_name: nn.Linear(
|
|
1191
|
+
in_features=len(group), out_features=embedding_size
|
|
1192
|
+
)
|
|
1193
|
+
for group_name, group in self.static_groups.items()
|
|
1194
|
+
}
|
|
1195
|
+
)
|
|
1196
|
+
if freeze_projections:
|
|
1197
|
+
self.space_time_embed.requires_grad_(False)
|
|
1198
|
+
self.space_embed.requires_grad_(False)
|
|
1199
|
+
self.time_embed.requires_grad_(False)
|
|
1200
|
+
self.static_embed.requires_grad_(False)
|
|
1201
|
+
self.norm = nn.LayerNorm(embedding_size)
|
|
1202
|
+
|
|
1203
|
+
self.apply(self._init_weights)
|
|
1204
|
+
|
|
1205
|
+
@property
|
|
1206
|
+
@override
|
|
1207
|
+
def cross_attn(self) -> bool:
|
|
1208
|
+
"""Whether to use cross attention."""
|
|
1209
|
+
return False
|
|
1210
|
+
|
|
1211
|
+
def _init_weights(self, m: nn.Module) -> None:
|
|
1212
|
+
if isinstance(m, nn.Linear):
|
|
1213
|
+
# we use xavier_uniform following official JAX ViT:
|
|
1214
|
+
torch.nn.init.xavier_uniform_(m.weight)
|
|
1215
|
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
1216
|
+
nn.init.constant_(m.bias, 0)
|
|
1217
|
+
|
|
1218
|
+    def apply_linear_projection(
+        self,
+        s_t_x: torch.Tensor,
+        sp_x: torch.Tensor,
+        t_x: torch.Tensor,
+        st_x: torch.Tensor,
+        s_t_m: torch.Tensor,
+        sp_m: torch.Tensor,
+        t_m: torch.Tensor,
+        st_m: torch.Tensor,
+        patch_size: int,
+    ) -> tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        """apply_linear_projection.
+
+        Given a [B, H, W, (T), C] inputs, returns a [B, H, W, (T), C_G, D] output.
+        We assume that the spatial masks are consistent for the given patch size,
+        so that if patch_size == 2 then one possible mask would be
+        [0, 0, 1, 1]
+        [0, 0, 1, 1]
+        [1, 1, 0, 0]
+        [1, 1, 0, 0]
+        for the H, W dimensions
+        """
+        b, h, w, t, _ = s_t_x.shape
+        new_h, new_w = h // patch_size, w // patch_size
+
+        s_t_l, sp_l, t_l, st_l, s_t_m_l, sp_m_l, t_m_l, st_m_l = (
+            [],
+            [],
+            [],
+            [],
+            [],
+            [],
+            [],
+            [],
+        )
+        for idx, (channel_group, channel_idxs) in enumerate(
+            self.space_time_groups.items()
+        ):
+            s_t_m_l.append(s_t_m[:, 0::patch_size, 0::patch_size, :, idx])
+            if s_t_m_l[-1].min() == 0:
+                s_t_l.append(
+                    self.space_time_embed[channel_group](
+                        s_t_x[:, :, :, :, channel_idxs], patch_size=patch_size
+                    )
+                )
+            else:
+                s_t_l.append(
+                    torch.zeros(
+                        b,
+                        new_h,
+                        new_w,
+                        t,
+                        self.embedding_size,
+                        dtype=s_t_x.dtype,
+                        device=s_t_x.device,
+                    )
+                )
+        for idx, (channel_group, channel_idxs) in enumerate(self.space_groups.items()):
+            sp_m_l.append(sp_m[:, 0::patch_size, 0::patch_size, idx])
+            if sp_m_l[-1].min() == 0:
+                sp_l.append(
+                    self.space_embed[channel_group](
+                        sp_x[:, :, :, channel_idxs], patch_size=patch_size
+                    )
+                )
+            else:
+                sp_l.append(
+                    torch.zeros(
+                        b,
+                        new_h,
+                        new_w,
+                        self.embedding_size,
+                        dtype=sp_x.dtype,
+                        device=sp_x.device,
+                    )
+                )
+
+        for idx, (channel_group, channel_idxs) in enumerate(self.time_groups.items()):
+            t_m_l.append(t_m[:, :, idx])
+            if t_m_l[-1].min() == 0:
+                t_l.append(self.time_embed[channel_group](t_x[:, :, channel_idxs]))
+            else:
+                t_l.append(
+                    torch.zeros(
+                        b, t, self.embedding_size, dtype=t_x.dtype, device=t_x.device
+                    )
+                )
+
+        for idx, (channel_group, channel_idxs) in enumerate(self.static_groups.items()):
+            st_m_l.append(st_m[:, idx])
+            if st_m_l[-1].min() == 0:
+                st_l.append(self.static_embed[channel_group](st_x[:, channel_idxs]))
+            else:
+                st_l.append(
+                    torch.zeros(
+                        b, self.embedding_size, dtype=st_x.dtype, device=st_x.device
+                    )
+                )
+
+        return (
+            torch.stack(s_t_l, dim=-2),
+            torch.stack(sp_l, dim=-2),
+            torch.stack(t_l, dim=-2),
+            torch.stack(st_l, dim=-2),
+            torch.stack(s_t_m_l, dim=-1),
+            torch.stack(sp_m_l, dim=-1),
+            torch.stack(t_m_l, dim=-1),
+            torch.stack(st_m_l, dim=-1),
+        )
+
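A minimal sketch of the patch-consistency assumption stated in the docstring above: because every patch_size x patch_size block of the spatial mask holds a single value, striding with [0::patch_size] recovers the per-patch mask without losing information. The variable names below are illustrative only, not rslearn code.

import torch

patch_size = 2
# a 4x4 mask where each 2x2 block is uniformly masked (1) or kept (0)
mask = torch.tensor(
    [
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 0, 0],
    ]
)
# one value per patch, which is what apply_linear_projection relies on
per_patch_mask = mask[0::patch_size, 0::patch_size]
print(per_patch_mask)  # tensor([[0, 1], [1, 0]])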
+    @staticmethod
+    def remove_masked_tokens(
+        x: torch.Tensor, mask: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Remove masked tokens."""
+        org_mask_dtype = mask.dtype
+        mask = mask.bool()
+        # https://stackoverflow.com/a/68621610/2332296
+        # move all non-masked values to the front of their rows
+        sorted_mask, indices = torch.sort(
+            (~mask).int(), dim=1, descending=True, stable=True
+        )
+        x = x.gather(1, indices[:, :, None].expand_as(x))
+        # set masked values to 0 (not really necessary since we'll ignore them anyway)
+        x = x * sorted_mask.unsqueeze(-1)
+
+        # cut off to the length of the longest sequence
+        max_length = sorted_mask.sum(-1).max()
+        x = x[:, :max_length]
+        updated_mask = 1 - sorted_mask[:, :max_length]
+
+        return x, indices, updated_mask.to(dtype=org_mask_dtype)
+
+    @staticmethod
+    def add_removed_tokens(
+        x: torch.Tensor, indices: torch.Tensor, mask: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """add_removed_tokens."""
+        masked_tokens = repeat(
+            torch.zeros_like(x[0, 0, :]), "d -> b t d", b=x.shape[0], t=indices.shape[1]
+        )
+        full_mask = torch.cat(
+            (
+                mask,
+                torch.ones(
+                    (x.shape[0], indices.shape[1] - x.shape[1]),
+                    device=x.device,
+                    dtype=mask.dtype,
+                ),
+            ),
+            dim=-1,
+        )
+        # can't set value on leaf variable
+        out = masked_tokens.clone()
+        # put tokens in full masked tensor (at the first N positions in every row)
+        out[~full_mask.bool()] = x[~mask.bool()]
+        # then move them to their original positions
+        out = out.scatter(1, indices[:, :, None].expand_as(out), out)
+        full_mask = full_mask.scatter(1, indices.expand_as(full_mask), full_mask)
+        return out, full_mask
+
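The two static methods above form a pack/unpack pair: non-masked tokens are sorted to the front of each row and the sequence is truncated to the longest unmasked length before attention, then the results are scattered back to their original positions afterwards. A standalone sketch of the packing step on a toy tensor (illustrative only, not calling the encoder class):

import torch

x = torch.arange(8, dtype=torch.float32).reshape(2, 4, 1)  # (batch, seq, dim)
mask = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 1]])          # 1 = masked out

# move unmasked tokens to the front of each row, as in remove_masked_tokens
sorted_keep, indices = torch.sort(
    (~mask.bool()).int(), dim=1, descending=True, stable=True
)
packed = x.gather(1, indices[:, :, None].expand_as(x))
packed = packed[:, : sorted_keep.sum(-1).max()]
print(packed.squeeze(-1))  # unmasked values appear first in every row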
+    def apply_attn(
+        self,
+        s_t_x: torch.Tensor,
+        sp_x: torch.Tensor,
+        t_x: torch.Tensor,
+        st_x: torch.Tensor,
+        s_t_m: torch.Tensor,
+        sp_m: torch.Tensor,
+        t_m: torch.Tensor,
+        st_m: torch.Tensor,
+        months: torch.Tensor,
+        patch_size: int,
+        input_res: int,
+        exit_after: int | None,
+        token_exit_cfg: dict | None,
+    ) -> tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        """apply_attn."""
+        if token_exit_cfg:
+            exit_s_t, exit_sp, exit_t, exit_st = self.create_token_exit_ids(
+                s_t_x, sp_x, t_x, st_x, token_exit_cfg
+            )
+            exit_ids_seq, _ = self.collapse_and_combine_hwtc(
+                exit_s_t, exit_sp, exit_t, exit_st, s_t_m, sp_m, t_m, st_m
+            )
+            # exited_tokens starts as linear projections!
+            exited_tokens, _ = self.collapse_and_combine_hwtc(
+                s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
+            )
+        else:
+            exit_ids_seq = None
+            exited_tokens = None
+
+        _, h, w, t, s_t_c_g, _ = s_t_x.shape
+        sp_c_g, t_c_g, st_c_g = sp_x.shape[3], t_x.shape[-2], st_x.shape[-2]
+        s_t_x, sp_x, t_x, st_x = self.apply_encodings(
+            s_t_x, sp_x, t_x, st_x, months, patch_size, input_res
+        )
+        x, m = self.collapse_and_combine_hwtc(
+            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
+        )
+
+        # we only care about the values >= 1 for this mask, since 2 just tells the decoder
+        # to decode those tokens. From the perspective of the encoder, 1 and 2 are equivalent
+        # since they both represent masked values
+        new_m = m >= 1
+        x, indices, new_m = self.remove_masked_tokens(
+            x, new_m
+        )  # new_m is shape (bsz, seq_len)
+
+        if exit_ids_seq is not None:
+            exit_ids_seq, _, _ = self.remove_masked_tokens(exit_ids_seq, m >= 1)
+            # still linear projections
+            exited_tokens, _, _ = self.remove_masked_tokens(exited_tokens, m >= 1)
+
+        for i_blk, blk in enumerate(self.blocks):
+            if (exit_after is not None) and ((i_blk + 1) > exit_after):
+                # if exit_after is N, then we exit after the Nth layer
+                # if exit_after is 0, then all layers are skipped
+                break
+
+            # skip the 0th block since this is just the linear
+            # projection
+            if (exit_ids_seq is not None) and (i_blk > 0):
+                assert exited_tokens is not None
+                # half depth
+                exited_tokens = torch.where(
+                    condition=(exit_ids_seq == i_blk),
+                    input=x.detach(),
+                    other=exited_tokens.detach(),
+                )
+
+            # we take the inverse of the mask because a value
+            # of True indicates the value *should* take part in
+            # attention
+            temp_mask = ~new_m.bool()
+            if temp_mask.all():
+                # if all the tokens are used in attention we can pass a None mask
+                # to the attention block
+                temp_mask = None
+
+            x = blk(x=x, y=None, attn_mask=temp_mask)
+
+        if exit_ids_seq is not None:
+            assert exited_tokens is not None
+            # full depth
+            # IMPORTANT: write this to x
+            x = torch.where(
+                condition=(exit_ids_seq == (i_blk + 1)),  # 2 for full depth
+                input=x.detach(),
+                other=exited_tokens.detach(),
+            )
+
+        # we don't care about the mask returned by add_removed_tokens, since we will
+        # just use the original, unclipped mask here
+        x, _ = self.add_removed_tokens(x, indices, new_m)
+        return (
+            *self.split_and_expand_hwtc(x, h, w, t, s_t_c_g, sp_c_g, t_c_g, st_c_g),
+            s_t_m,
+            sp_m,
+            t_m,
+            st_m,
+        )
+
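The exit_after argument above limits how many transformer blocks run: with exit_after = N the loop breaks once (i_blk + 1) > N, so exit_after=0 skips every block and exit_after=None runs them all. A toy illustration of that loop condition (hypothetical block list, not rslearn code):

blocks = ["block0", "block1", "block2", "block3"]
exit_after = 2
ran = []
for i_blk, blk in enumerate(blocks):
    if (exit_after is not None) and ((i_blk + 1) > exit_after):
        break
    ran.append(blk)
print(ran)  # ['block0', 'block1'] -- exactly exit_after blocks execute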
+    @classmethod
+    def average_tokens(
+        cls,
+        s_t_x: torch.Tensor,
+        sp_x: torch.Tensor,
+        t_x: torch.Tensor,
+        st_x: torch.Tensor,
+        s_t_m: torch.Tensor,
+        sp_m: torch.Tensor,
+        t_m: torch.Tensor,
+        st_m: torch.Tensor,
+    ) -> torch.Tensor:
+        """average_tokens."""
+        x, m = cls.collapse_and_combine_hwtc(
+            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
+        )
+        x, _, m = cls.remove_masked_tokens(x, m)
+        x_for_mean = x * (1 - m.unsqueeze(-1))
+        return x_for_mean.sum(dim=1) / torch.sum(1 - m, -1, keepdim=True)
+
+    @classmethod
+    def apply_mask_and_average_tokens_per_patch(
+        cls,
+        s_t_x: torch.Tensor,
+        sp_x: torch.Tensor,
+        t_x: torch.Tensor,
+        st_x: torch.Tensor,
+        s_t_m: torch.Tensor,
+        sp_m: torch.Tensor,
+        t_m: torch.Tensor,
+        st_m: torch.Tensor,
+    ) -> torch.Tensor:
+        """apply_mask_and_average_tokens_per_patch."""
+        s_t_x = rearrange(s_t_x, "b t_h t_w t c_g d -> b (t_h t_w) (t c_g) d")
+        sp_x = rearrange(sp_x, "b t_h t_w c_g d -> b (t_h t_w) c_g d")
+        # repeat time tokens over space
+        t_x = repeat(
+            rearrange(t_x, "b t c_g d -> b (t c_g) d"),
+            "b n d -> b s n d",
+            s=sp_x.shape[1],
+        )
+        st_x = repeat(st_x, "b c_g d -> b s c_g d", s=sp_x.shape[1])
+        s_t_m = rearrange(s_t_m, "b t_h t_w t c_g-> b (t_h t_w) (t c_g)")
+        sp_m = rearrange(sp_m, "b t_h t_w c_g-> b (t_h t_w) c_g")
+        t_m = repeat(
+            rearrange(t_m, "b t c_g -> b (t c_g)"), "b n -> b s n", s=sp_x.shape[1]
+        )
+        st_m = repeat(st_m, "b c_g -> b s c_g", s=sp_x.shape[1])
+
+        x = torch.cat([s_t_x, sp_x, t_x, st_x], dim=2)  # B, S, N, D
+        m = torch.cat([s_t_m, sp_m, t_m, st_m], dim=2)  # B, S, N
+
+        x_for_mean = x * (1 - m.unsqueeze(-1))
+
+        return x_for_mean.sum(dim=2) / torch.sum(1 - m, -1, keepdim=True)
+
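Both pooling helpers above compute a masked mean: masked tokens (mask value 1) are zeroed out and excluded from the denominator. A minimal standalone sketch of the same arithmetic on a toy tensor:

import torch

x = torch.tensor([[[1.0], [2.0], [3.0]]])  # (batch=1, tokens=3, dim=1)
m = torch.tensor([[0.0, 0.0, 1.0]])        # last token is masked

x_for_mean = x * (1 - m.unsqueeze(-1))
pooled = x_for_mean.sum(dim=1) / torch.sum(1 - m, -1, keepdim=True)
print(pooled)  # tensor([[1.5]]) -- mean of the two unmasked tokens only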
+    def create_token_exit_ids(
+        self,
+        s_t_x: torch.Tensor,
+        sp_x: torch.Tensor,
+        t_x: torch.Tensor,
+        st_x: torch.Tensor,
+        token_exit_cfg: dict,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """create_token_exit_ids."""
+        exit_s_t = torch.zeros_like(s_t_x)
+        exit_sp = torch.zeros_like(sp_x)
+        exit_t = torch.zeros_like(t_x)
+        exit_st = torch.zeros_like(st_x)
+
+        for idx, (key, _) in enumerate(self.space_time_groups.items()):
+            exit_s_t[:, :, :, :, idx, :] = token_exit_cfg[key]
+
+        for idx, (key, _) in enumerate(self.space_groups.items()):
+            exit_sp[:, :, :, idx, :] = token_exit_cfg[key]
+
+        for idx, (key, _) in enumerate(self.time_groups.items()):
+            exit_t[:, :, idx, :] = token_exit_cfg[key]
+
+        for idx, (key, _) in enumerate(self.static_groups.items()):
+            exit_st[:, idx, :] = token_exit_cfg[key]
+        return exit_s_t, exit_sp, exit_t, exit_st
+
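As the loops above show, token_exit_cfg maps each channel-group name to an integer block index, which create_token_exit_ids broadcasts over that group's token positions so apply_attn can stop updating those tokens at the matching block. A hypothetical config (the group names below are illustrative assumptions, not a fixed rslearn list):

# hypothetical group names; in practice the dict needs one entry for every group
# in space_time_groups, space_groups, time_groups and static_groups
token_exit_cfg = {
    "s2_rgb": 2,  # tokens from this space-time group stop updating after block 2
    "srtm": 1,    # hypothetical space group, stops updating after block 1
}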
+    def forward(
+        self,
+        s_t_x: torch.Tensor,
+        sp_x: torch.Tensor,
+        t_x: torch.Tensor,
+        st_x: torch.Tensor,
+        s_t_m: torch.Tensor,
+        sp_m: torch.Tensor,
+        t_m: torch.Tensor,
+        st_m: torch.Tensor,
+        months: torch.Tensor,
+        patch_size: int,
+        input_resolution_m: int = BASE_GSD,
+        exit_after: int | None = None,
+        token_exit_cfg: dict | None = None,
+        add_layernorm_on_exit: bool = True,
+    ) -> tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        """Forward."""
+        (
+            s_t_x,
+            sp_x,
+            t_x,
+            st_x,
+            s_t_m,
+            sp_m,
+            t_m,
+            st_m,
+        ) = self.apply_linear_projection(
+            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, patch_size
+        )
+
+        if (exit_after is None) or (exit_after > 0):
+            s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m = self.apply_attn(
+                s_t_x,
+                sp_x,
+                t_x,
+                st_x,
+                s_t_m,
+                sp_m,
+                t_m,
+                st_m,
+                months,
+                patch_size,
+                input_resolution_m,
+                exit_after=exit_after,
+                token_exit_cfg=token_exit_cfg,
+            )
+
+        if add_layernorm_on_exit:
+            s_t_x = self.norm(s_t_x)
+            sp_x = self.norm(sp_x)
+            t_x = self.norm(t_x)
+            st_x = self.norm(st_x)
+
+        return (s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, months)
+
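For reference, the indexing in apply_linear_projection implies the following input layouts for forward; the sizes below are hypothetical placeholders, and the per-timestep months layout is an assumption rather than something shown in this diff:

import torch

B, H, W, T = 2, 8, 8, 4             # batch, height, width, timesteps (hypothetical)
C_st, C_sp, C_t, C_s = 13, 2, 3, 1  # channels per modality family (hypothetical)
G_st, G_sp, G_t, G_s = 5, 1, 2, 1   # channel groups per family (hypothetical)

s_t_x = torch.zeros(B, H, W, T, C_st)  # space-time inputs
sp_x = torch.zeros(B, H, W, C_sp)      # space-only inputs
t_x = torch.zeros(B, T, C_t)           # time-only inputs
st_x = torch.zeros(B, C_s)             # static inputs
s_t_m = torch.zeros(B, H, W, T, G_st)  # one mask channel per space-time group
sp_m = torch.zeros(B, H, W, G_sp)
t_m = torch.zeros(B, T, G_t)
st_m = torch.zeros(B, G_s)
months = torch.zeros(B, T, dtype=torch.long)  # assumed: one month index per timestep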
+    @classmethod
+    def load_from_folder(cls, folder: Path, device: torch.device) -> "Encoder":
+        """Load a model from a folder containing an encoder.pt and config.json."""
+        if not (folder / CONFIG_FILENAME).exists():
+            all_files_in_folder = [f.name for f in folder.glob("*")]
+            raise ValueError(
+                f"Expected {CONFIG_FILENAME} in {folder}, found {all_files_in_folder}"
+            )
+        if not (folder / ENCODER_FILENAME).exists():
+            all_files_in_folder = [f.name for f in folder.glob("*")]
+            raise ValueError(
+                f"Expected {ENCODER_FILENAME} in {folder}, found {all_files_in_folder}"
+            )
+
+        with (folder / CONFIG_FILENAME).open("r") as f:
+            config = json.load(f)
+        model_config = config["model"]
+        encoder_config = model_config["encoder"]
+        encoder = cls(**encoder_config)
+
+        state_dict = torch.load(
+            folder / ENCODER_FILENAME, map_location=device, weights_only=True
+        )
+        for key in list(state_dict.keys()):
+            # this cleans the state dict, which occasionally had an extra
+            # ".backbone" included in the key names
+            state_dict[key.replace(".backbone", "")] = state_dict.pop(key)
+        encoder.load_state_dict(state_dict)
+        return encoder
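A minimal usage sketch for the loader above, assuming a local checkpoint directory that already contains the expected config.json and encoder.pt (the directory path is a hypothetical placeholder):

from pathlib import Path

import torch

# hypothetical checkpoint directory containing config.json and encoder.pt
checkpoint_dir = Path("/path/to/encoder_checkpoint")
encoder = Encoder.load_from_folder(checkpoint_dir, device=torch.device("cpu"))
encoder.eval()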