rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/galileo/single_file_galileo.py @@ -0,0 +1,1678 @@
1
+ """Galileo models."""
2
+
3
+ import collections.abc
4
+ import itertools
5
+ import json
6
+ import math
7
+ from abc import abstractmethod
8
+ from collections import OrderedDict
9
+ from collections import OrderedDict as OrderedDictType
10
+ from collections.abc import Sequence
11
+ from copy import deepcopy
12
+ from pathlib import Path
13
+ from typing import NamedTuple, cast
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from einops import rearrange, repeat
20
+ from torch import Tensor, vmap
21
+ from torch.jit import Final
22
+ from typing_extensions import override
23
+
24
+ from rslearn.log_utils import get_logger
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
+ # constants
30
+ CONFIG_FILENAME = "config.json"
31
+ ENCODER_FILENAME = "encoder.pt"
32
+ BASE_GSD = 10
33
+ DEFAULT_MONTH = 5
34
+
35
+ # band information
36
+ S1_BANDS = ["VV", "VH"]
37
+ S1_SHIFT_VALUES = [25.0, 25.0]
38
+ S1_DIV_VALUES = [25.0, 25.0]
39
+ S2_BANDS = [
40
+ "B2",
41
+ "B3",
42
+ "B4",
43
+ "B5",
44
+ "B6",
45
+ "B7",
46
+ "B8",
47
+ "B8A",
48
+ "B11",
49
+ "B12",
50
+ ]
51
+ S2_SHIFT_VALUES = [0.0] * len(S2_BANDS)
52
+ S2_DIV_VALUES = [1e4] * len(S2_BANDS)
53
+ ERA5_BANDS = ["temperature_2m", "total_precipitation_sum"]
54
+ # For temperature, shift to Celsius and then divide by 35 based on the notebook (values range from
55
+ # 37 to -22 degrees Celsius).
56
+ # For rainfall, based on
57
+ # https://github.com/nasaharvest/presto/blob/main/notebooks/exploratory_data_analysis.ipynb
58
+ ERA5_SHIFT_VALUES = [-272.15, 0.0]
59
+ ERA5_DIV_VALUES = [35.0, 0.03]
60
+ TC_BANDS = ["def", "soil", "aet"]
61
+ TC_SHIFT_VALUES = [0.0, 0.0, 0.0]
62
+ TC_DIV_VALUES = [4548, 8882, 2000]
63
+ VIIRS_BANDS = ["avg_rad"]
64
+ VIIRS_SHIFT_VALUES = [0.0]
65
+ # visually checked - this seems much more reasonable than
66
+ # the GEE estimate
67
+ VIIRS_DIV_VALUES = [100]
68
+ SRTM_BANDS = ["elevation", "slope"]
69
+ # visually gauged 90th percentile from
70
+ # https://github.com/nasaharvest/presto/blob/main/notebooks/exploratory_data_analysis.ipynb
71
+ SRTM_SHIFT_VALUES = [0.0, 0.0]
72
+ SRTM_DIV_VALUES = [2000.0, 50.0]
73
+ DW_BANDS = [
74
+ "DW_water",
75
+ "DW_trees",
76
+ "DW_grass",
77
+ "DW_flooded_vegetation",
78
+ "DW_crops",
79
+ "DW_shrub_and_scrub",
80
+ "DW_built",
81
+ "DW_bare",
82
+ "DW_snow_and_ice",
83
+ ]
84
+ DW_SHIFT_VALUES = [0] * len(DW_BANDS)
85
+ DW_DIV_VALUES = [1] * len(DW_BANDS)
86
+
87
+ WC_BANDS = [
88
+ "WC_temporarycrops",
89
+ "WC_maize",
90
+ "WC_wintercereals",
91
+ "WC_springcereals",
92
+ "WC_irrigation",
93
+ ]
94
+ WC_SHIFT_VALUES = [0] * len(WC_BANDS)
95
+ WC_DIV_VALUES = [100] * len(WC_BANDS)
96
+ STATIC_DW_BANDS = [f"{x}_static" for x in DW_BANDS]
97
+ STATIC_WC_BANDS = [f"{x}_static" for x in WC_BANDS]
98
+
99
+ LANDSCAN_BANDS = ["b1"]
100
+ # LANDSCAN values range from approximately 0 to 185000 in 2022: https://code.earthengine.google.com/?scriptPath=users/sat-io/awesome-gee-catalog-examples:population-socioeconomics/LANDSCAN-GLOBAL
101
+ LANDSCAN_SHIFT_VALUES = [92500]
102
+ LANDSCAN_DIV_VALUES = [92500]
103
+ LOCATION_BANDS = ["x", "y", "z"]
104
+
105
+ SPACE_TIME_BANDS = S1_BANDS + S2_BANDS + ["NDVI"]
106
+ TIME_BANDS = ERA5_BANDS + TC_BANDS + VIIRS_BANDS
107
+ SPACE_BANDS = SRTM_BANDS + DW_BANDS + WC_BANDS
108
+ STATIC_BANDS = LANDSCAN_BANDS + LOCATION_BANDS + STATIC_DW_BANDS + STATIC_WC_BANDS
109
+
110
+ # 0 for NDVI
111
+ SPACE_TIME_SHIFT_VALUES = np.array(S1_SHIFT_VALUES + S2_SHIFT_VALUES + [0])
112
+ SPACE_TIME_DIV_VALUES = np.array(S1_DIV_VALUES + S2_DIV_VALUES + [1])
113
+ TIME_SHIFT_VALUES = np.array(ERA5_SHIFT_VALUES + TC_SHIFT_VALUES + VIIRS_SHIFT_VALUES)
114
+ TIME_DIV_VALUES = np.array(ERA5_DIV_VALUES + TC_DIV_VALUES + VIIRS_DIV_VALUES)
115
+ SPACE_SHIFT_VALUES = np.array(SRTM_SHIFT_VALUES + DW_SHIFT_VALUES + WC_SHIFT_VALUES)
116
+ SPACE_DIV_VALUES = np.array(SRTM_DIV_VALUES + DW_DIV_VALUES + WC_DIV_VALUES)
117
+ # [0s, 1s] for the locations
118
+ STATIC_SHIFT_VALUES = np.array(
119
+ LANDSCAN_SHIFT_VALUES + [0, 0, 0] + DW_SHIFT_VALUES + WC_SHIFT_VALUES
120
+ )
121
+ STATIC_DIV_VALUES = np.array(
122
+ LANDSCAN_DIV_VALUES + [1, 1, 1] + DW_DIV_VALUES + WC_DIV_VALUES
123
+ )
124
+
125
+ SPACE_TIME_BANDS_GROUPS_IDX: OrderedDictType[str, list[int]] = OrderedDict(
126
+ {
127
+ "S1": [SPACE_TIME_BANDS.index(b) for b in S1_BANDS],
128
+ "S2_RGB": [SPACE_TIME_BANDS.index(b) for b in ["B2", "B3", "B4"]],
129
+ "S2_Red_Edge": [SPACE_TIME_BANDS.index(b) for b in ["B5", "B6", "B7"]],
130
+ "S2_NIR_10m": [SPACE_TIME_BANDS.index(b) for b in ["B8"]],
131
+ "S2_NIR_20m": [SPACE_TIME_BANDS.index(b) for b in ["B8A"]],
132
+ "S2_SWIR": [SPACE_TIME_BANDS.index(b) for b in ["B11", "B12"]],
133
+ "NDVI": [SPACE_TIME_BANDS.index("NDVI")],
134
+ }
135
+ )
136
+
137
+ TIME_BAND_GROUPS_IDX: OrderedDictType[str, list[int]] = OrderedDict(
138
+ {
139
+ "ERA5": [TIME_BANDS.index(b) for b in ERA5_BANDS],
140
+ "TC": [TIME_BANDS.index(b) for b in TC_BANDS],
141
+ "VIIRS": [TIME_BANDS.index(b) for b in VIIRS_BANDS],
142
+ }
143
+ )
144
+
145
+ SPACE_BAND_GROUPS_IDX: OrderedDictType[str, list[int]] = OrderedDict(
146
+ {
147
+ "SRTM": [SPACE_BANDS.index(b) for b in SRTM_BANDS],
148
+ "DW": [SPACE_BANDS.index(b) for b in DW_BANDS],
149
+ "WC": [SPACE_BANDS.index(b) for b in WC_BANDS],
150
+ }
151
+ )
152
+
153
+ STATIC_BAND_GROUPS_IDX: OrderedDictType[str, list[int]] = OrderedDict(
154
+ {
155
+ "LS": [STATIC_BANDS.index(b) for b in LANDSCAN_BANDS],
156
+ "location": [STATIC_BANDS.index(b) for b in LOCATION_BANDS],
157
+ "DW_static": [STATIC_BANDS.index(b) for b in STATIC_DW_BANDS],
158
+ "WC_static": [STATIC_BANDS.index(b) for b in STATIC_WC_BANDS],
159
+ }
160
+ )
161
+
162
+
163
+ # this normalizing dict is sourced from
164
+ # https://github.com/nasaharvest/galileo/blob/main/config/normalization.json
165
+ # It is used to normalize the data. The keys (e.g. "13") are used to query
166
+ # which tensor (e.g. space_time_x) is associated with the means and stds,
167
+ # where the key is the size of the tensor's last dimension (i.e. x.shape[-1])
168
+ NORMALIZING_DICT = {
169
+ "total_n": 127155,
170
+ "sampled_n": 10000,
171
+ "13": {
172
+ "mean": [
173
+ -11.728724389184965,
174
+ -18.85558188024017,
175
+ 1395.3408730676722,
176
+ 1338.4026921784578,
177
+ 1343.09883810357,
178
+ 1543.8607982512297,
179
+ 2186.2022069512263,
180
+ 2525.0932853316694,
181
+ 2410.3377187373408,
182
+ 2750.2854646886753,
183
+ 2234.911100061487,
184
+ 1474.5311266077113,
185
+ 0.2892116502999044,
186
+ ],
187
+ "std": [
188
+ 4.887145774840316,
189
+ 5.730270320384293,
190
+ 917.7041440370853,
191
+ 913.2988423581528,
192
+ 1092.678723527555,
193
+ 1047.2206083460424,
194
+ 1048.0101611156767,
195
+ 1143.6903026819996,
196
+ 1098.979177731649,
197
+ 1204.472755085893,
198
+ 1145.9774063078878,
199
+ 980.2429840007796,
200
+ 0.2720939024500081,
201
+ ],
202
+ },
203
+ "16": {
204
+ "mean": [
205
+ 673.0152819503361,
206
+ 5.930092668915115,
207
+ 0.10470439140978786,
208
+ 0.23965913270066183,
209
+ 0.08158044385860364,
210
+ 0.04246976254259546,
211
+ 0.11304392863520317,
212
+ 0.17329647890362473,
213
+ 0.0698981691616277,
214
+ 0.12130267132802142,
215
+ 0.04671318615236216,
216
+ 10.973119802517362,
217
+ 1.0927069179958768,
218
+ 1.6991394232855903,
219
+ 0.03720594618055555,
220
+ 1.3671352688259548,
221
+ ],
222
+ "std": [
223
+ 983.0697298296237,
224
+ 8.167406789813247,
225
+ 0.18771647977504985,
226
+ 0.2368313455675914,
227
+ 0.08024268534756586,
228
+ 0.04045374496146404,
229
+ 0.11350342472061795,
230
+ 0.1279898111718168,
231
+ 0.12042341550438586,
232
+ 0.13602408145504347,
233
+ 0.043971116096060345,
234
+ 31.255340146970997,
235
+ 10.395974878206689,
236
+ 12.92380617159917,
237
+ 1.9285254295940466,
238
+ 11.612179775408928,
239
+ ],
240
+ },
241
+ "6": {
242
+ "mean": [
243
+ 271.5674963541667,
244
+ 0.08554303677156568,
245
+ 657.3181260091111,
246
+ 692.1291795806885,
247
+ 562.781331880633,
248
+ 1.5647115934036673,
249
+ ],
250
+ "std": [
251
+ 79.80828940314429,
252
+ 0.11669547098151486,
253
+ 704.0008695557707,
254
+ 925.0116126406431,
255
+ 453.2434022278578,
256
+ 7.513020170832818,
257
+ ],
258
+ },
259
+ "18": {
260
+ "mean": [
261
+ 188.20315880851746,
262
+ 0.2804946561574936,
263
+ 0.11371652073860168,
264
+ 0.058778801321983334,
265
+ 0.10474256777763366,
266
+ 0.2396918488264084,
267
+ 0.08152248692512512,
268
+ 0.04248040814399719,
269
+ 0.11303179881572724,
270
+ 0.17326324067115784,
271
+ 0.06998309404850006,
272
+ 0.12122812910079957,
273
+ 0.04671641788482666,
274
+ 10.98456594619751,
275
+ 1.0968475807189941,
276
+ 1.6947754135131836,
277
+ 0.03320046615600586,
278
+ 1.3602827312469483,
279
+ ],
280
+ "std": [
281
+ 1154.5919128300602,
282
+ 0.5276998078079327,
283
+ 0.7021637331734328,
284
+ 0.36528892213195063,
285
+ 0.17470213191865785,
286
+ 0.20411195416718833,
287
+ 0.0660782470089761,
288
+ 0.03380702424871257,
289
+ 0.09809195568521663,
290
+ 0.11292471052124119,
291
+ 0.09720748930233268,
292
+ 0.12912217763726777,
293
+ 0.0399973913151906,
294
+ 23.725471823867462,
295
+ 5.715238079725388,
296
+ 9.030481416228302,
297
+ 0.9950220242487364,
298
+ 7.754429123862099,
299
+ ],
300
+ },
301
+ }
302
+
303
+
304
+ class Normalizer:
305
+ """Normalize Galileo inputs."""
306
+
307
+ std_bands: dict[int, list] = {
308
+ # we exclude NDVI because it is already between 0 and 1, so we don't
309
+ # want to apply further normalization to it.
310
+ len(SPACE_TIME_BANDS): [b for b in SPACE_TIME_BANDS if b != "NDVI"],
311
+ len(SPACE_BANDS): SRTM_BANDS,
312
+ len(TIME_BANDS): TIME_BANDS,
313
+ len(STATIC_BANDS): LANDSCAN_BANDS,
314
+ }
315
+
316
+ def __init__(self, std_multiplier: float = 2):
317
+ """Normalize Galileo inputs.
318
+
319
+ Args:
320
+ std_multiplier: std_multiplier to apply
321
+ """
322
+ name_to_bands = {
323
+ len(SPACE_TIME_BANDS): SPACE_TIME_BANDS,
324
+ len(SPACE_BANDS): SPACE_BANDS,
325
+ len(TIME_BANDS): TIME_BANDS,
326
+ len(STATIC_BANDS): STATIC_BANDS,
327
+ }
328
+ self.shift_div_dict = {
329
+ len(SPACE_TIME_BANDS): {
330
+ "shift": deepcopy(SPACE_TIME_SHIFT_VALUES),
331
+ "div": deepcopy(SPACE_TIME_DIV_VALUES),
332
+ },
333
+ len(SPACE_BANDS): {
334
+ "shift": deepcopy(SPACE_SHIFT_VALUES),
335
+ "div": deepcopy(SPACE_DIV_VALUES),
336
+ },
337
+ len(TIME_BANDS): {
338
+ "shift": deepcopy(TIME_SHIFT_VALUES),
339
+ "div": deepcopy(TIME_DIV_VALUES),
340
+ },
341
+ len(STATIC_BANDS): {
342
+ "shift": deepcopy(STATIC_SHIFT_VALUES),
343
+ "div": deepcopy(STATIC_DIV_VALUES),
344
+ },
345
+ }
346
+ for key_as_str, val in NORMALIZING_DICT.items():
347
+ if "n" in key_as_str:
348
+ continue
349
+ key = int(key_as_str)
350
+ bands_to_replace = self.std_bands[key]
351
+ for band in bands_to_replace:
352
+ band_idx = name_to_bands[key].index(band)
353
+ mean = cast(dict[str, list], val)["mean"][band_idx]
354
+ std = cast(dict[str, list], val)["std"][band_idx]
355
+ min_value = mean - (std_multiplier * std)
356
+ max_value = mean + (std_multiplier * std)
357
+ div = max_value - min_value
358
+ if div == 0:
359
+ raise ValueError(f"{band} has div value of 0")
360
+ self.shift_div_dict[key]["shift"][band_idx] = min_value
361
+ self.shift_div_dict[key]["div"][band_idx] = div
362
+
363
+ @staticmethod
364
+ def _normalize(
365
+ x: torch.Tensor, shift_values: torch.Tensor, div_values: torch.Tensor
366
+ ) -> torch.Tensor:
367
+ x = (x - shift_values) / div_values
368
+ return x
369
+
370
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
371
+ """Apply the normalizer."""
372
+ div_values = self.shift_div_dict[x.shape[-1]]["div"]
373
+ shift_values = self.shift_div_dict[x.shape[-1]]["shift"]
374
+ return self._normalize(x, shift_values, div_values)
375
+
376
+
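As a minimal illustrative sketch (not part of the released file), the Normalizer above selects its statistics by the size of the input's trailing channel dimension, so a space-time tensor with len(SPACE_TIME_BANDS) == 13 channels picks up the "13" entry of NORMALIZING_DICT:

    normalizer = Normalizer(std_multiplier=2)
    example_s_t_x = torch.randn(2, 8, 8, 3, len(SPACE_TIME_BANDS))  # [B, H, W, T, 13]
    example_s_t_x = normalizer(example_s_t_x)  # shift/div values chosen via shape[-1] == 13
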
377
+ class MaskedOutput(NamedTuple):
378
+ """A masked output (i.e. an input to Galileo).
379
+
380
+ A mask can take 3 values:
381
+ 0: seen by the encoder (i.e. makes the key and value tokens in the decoder)
382
+ 1: not seen by the encoder, and ignored by the decoder
383
+ 2: not seen by the encoder, and processed by the decoder (the decoder's query values)
384
+ """
385
+
386
+ s_t_x: torch.Tensor # [B, H, W, T, len(SPACE_TIME_BANDS)]
387
+ sp_x: torch.Tensor # [B, H, W, len(SPACE_BANDS)]
388
+ t_x: torch.Tensor # [B, T, len(TIME_BANDS)]
389
+ st_x: torch.Tensor # [B, len(STATIC_BANDS)]
390
+ s_t_m: torch.Tensor # [B, H, W, T, len(SPACE_TIME_BANDS_GROUPS_IDX)]
391
+ sp_m: torch.Tensor # [B, H, W, len(SPACE_BAND_GROUPS_IDX)]
392
+ t_m: torch.Tensor # [B, T, len(TIME_BAND_GROUPS_IDX)]
393
+ st_m: torch.Tensor # [B, len(STATIC_BAND_GROUPS_IDX)]
394
+ months: torch.Tensor # [B, T]
395
+
396
+
397
+ def get_2d_sincos_pos_embed_with_resolution(
398
+ embed_dim: int,
399
+ grid_size: int,
400
+ res: torch.Tensor,
401
+ cls_token: bool = False,
402
+ device: str = "cpu",
403
+ ) -> torch.Tensor:
404
+ """Create 2d sincos embeddings with resolution.
405
+
406
+ grid_size: int of the grid height and width
407
+ res: array of size n, representing the resolution of a pixel (say, in meters).
408
+
409
+ Returns:
410
+ pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
411
+ """
412
+ res = res.to(device)
413
+ grid_h = torch.arange(grid_size, device=device)
414
+ grid_w = torch.arange(grid_size, device=device)
415
+ grid = torch.meshgrid(
416
+ grid_w, grid_h, indexing="xy"
417
+ ) # here h goes first; direction is reversed relative to numpy
418
+ grid = torch.stack(grid, dim=0) # 2 x h x w
419
+
420
+ # grid = grid.reshape([2, 1, grid_size, grid_size])
421
+ grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w
422
+ _, n, h, w = grid.shape
423
+ pos_embed = get_2d_sincos_pos_embed_from_grid_torch(
424
+ embed_dim, grid
425
+ ) # (n*H*W, D)
426
+ pos_embed = pos_embed.reshape(n, h * w, embed_dim)
427
+ if cls_token:
428
+ pos_embed = torch.cat(
429
+ [
430
+ torch.zeros([n, 1, embed_dim], device=pos_embed.device),
431
+ pos_embed,
432
+ ],
433
+ dim=1,
434
+ )
435
+ return pos_embed
436
+
437
+
438
+ def get_2d_sincos_pos_embed_from_grid_torch(
439
+ embed_dim: int, grid: torch.Tensor
440
+ ) -> torch.Tensor:
441
+ """get_2d_sincos_pos_embed_from_grid_torch."""
442
+ assert embed_dim % 2 == 0
443
+
444
+ # use half of dimensions to encode grid_h
445
+ emb_h = get_1d_sincos_pos_embed_from_grid_torch(
446
+ embed_dim // 2, grid[0]
447
+ ) # (H*W, D/2)
448
+ emb_w = get_1d_sincos_pos_embed_from_grid_torch(
449
+ embed_dim // 2, grid[1]
450
+ ) # (H*W, D/2)
451
+
452
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D)
453
+ return emb
454
+
455
+
456
+ def get_1d_sincos_pos_embed_from_grid_torch(
457
+ embed_dim: int, pos: torch.Tensor
458
+ ) -> torch.Tensor:
459
+ """get_1d_sincos_pos_embed_from_grid_torch.
460
+
461
+ embed_dim: output dimension for each position
462
+ pos: a list of positions to be encoded: size (M,)
463
+ out: (M, D)
464
+ """
465
+ assert embed_dim % 2 == 0
466
+ omega = torch.arange(embed_dim // 2, device=pos.device) / embed_dim / 2.0
467
+ omega = 1.0 / 10000**omega # (D/2,)
468
+
469
+ pos = pos.reshape(-1) # (M,)
470
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
471
+
472
+ emb_sin = torch.sin(out) # (M, D/2)
473
+ emb_cos = torch.cos(out) # (M, D/2)
474
+
475
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
476
+ return emb
477
+
478
+
479
+ def get_month_encoding_table(embed_dim: int) -> torch.Tensor:
480
+ """Sinusoid month encoding table, for 12 months indexed from 0-11."""
481
+ assert embed_dim % 2 == 0
482
+ angles = torch.arange(0, 13) / (12 / (2 * np.pi))
483
+
484
+ sin_table = torch.sin(torch.stack([angles for _ in range(embed_dim // 2)], axis=-1))
485
+ cos_table = torch.cos(torch.stack([angles for _ in range(embed_dim // 2)], axis=-1))
486
+ month_table = torch.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)
487
+
488
+ return month_table # (M, D)
489
+
490
+
491
+ def adjust_learning_rate(
492
+ optimizer: torch.optim.Optimizer,
493
+ epoch: int,
494
+ warmup_epochs: int,
495
+ total_epochs: int,
496
+ max_lr: float,
497
+ min_lr: float,
498
+ ) -> float:
499
+ """Decay the learning rate with half-cycle cosine after warmup."""
500
+ if epoch < warmup_epochs:
501
+ lr = max_lr * epoch / warmup_epochs
502
+ else:
503
+ lr = min_lr + (max_lr - min_lr) * 0.5 * (
504
+ 1.0
505
+ + math.cos(
506
+ math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
507
+ )
508
+ )
509
+ for group in optimizer.param_groups:
510
+ group["lr"] = lr
511
+ return lr
512
+
513
+
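A hedged usage sketch of the schedule above (illustrative values only): linear warmup to max_lr over the first warmup_epochs, then half-cycle cosine decay towards min_lr:

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.AdamW(params, lr=1e-3)
    for epoch in range(100):
        lr = adjust_learning_rate(
            opt, epoch, warmup_epochs=10, total_epochs=100, max_lr=1e-3, min_lr=1e-5
        )
        # ... one training epoch would run here at learning rate `lr` ...
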
514
+ # thanks to https://github.com/bwconrad/flexivit/ for this nice implementation
515
+ # of the FlexiPatchEmbed module
516
+ def to_2tuple(x: int | tuple[int, int]) -> tuple[int, int]:
517
+ """to_2tuple."""
518
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
519
+ return tuple(x) # type: ignore
520
+ return tuple(itertools.repeat(x, 2)) # type: ignore
521
+
522
+
523
+ class FlexiPatchEmbed(nn.Module):
524
+ """FlexiPatchEmbed."""
525
+
526
+ def __init__(
527
+ self,
528
+ patch_size: int | tuple[int, int],
529
+ in_chans: int = 3,
530
+ embed_dim: int = 128,
531
+ norm_layer: nn.Module | None = None,
532
+ bias: bool = True,
533
+ patch_size_seq: Sequence[int] = (1, 2, 3, 4, 5, 6),
534
+ interpolation: str = "bicubic",
535
+ antialias: bool = True,
536
+ ) -> None:
537
+ """2D image to patch embedding w/ flexible patch sizes.
538
+
539
+ Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24
540
+ by https://github.com/bwconrad/flexivit/
541
+
542
+ Args:
543
+ patch_size: Base patch size, i.e. the size of the parameter buffer
544
+ in_chans: Number of input image channels
545
+ embed_dim: Network embedding dimension size
546
+ norm_layer: Optional normalization layer
547
+ bias: Whether to use bias in convolution
548
+ patch_size_seq: List of patch sizes to randomly sample from
549
+ interpolation: Resize interpolation type
550
+ antialias: Whether to apply antialiasing resizing
551
+ """
552
+ super().__init__()
553
+
554
+ self.patch_size = to_2tuple(patch_size)
555
+
556
+ self.proj = nn.Conv2d(
557
+ in_chans,
558
+ embed_dim,
559
+ kernel_size=self.patch_size,
560
+ stride=self.patch_size,
561
+ bias=bias,
562
+ )
563
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
564
+
565
+ # Flexi specific attributes
566
+ self.interpolation = interpolation
567
+ self.antialias = antialias
568
+
569
+ self.patch_size_seq = patch_size_seq
570
+
571
+ # Pre-calculate pinvs
572
+ self.pinvs = self._cache_pinvs()
573
+
574
+ def _cache_pinvs(self) -> dict:
575
+ """Pre-calculate all pinv matrices."""
576
+ pinvs = {}
577
+ for ps in self.patch_size_seq:
578
+ tuple_ps = to_2tuple(ps)
579
+ pinvs[tuple_ps] = self._calculate_pinv(self.patch_size, tuple_ps)
580
+ return pinvs
581
+
582
+ def _resize(self, x: Tensor, shape: tuple[int, int]) -> Tensor:
583
+ x_resized = F.interpolate(
584
+ x[None, None, ...],
585
+ shape,
586
+ mode=self.interpolation,
587
+ antialias=self.antialias,
588
+ )
589
+ return x_resized[0, 0, ...]
590
+
591
+ def _calculate_pinv(
592
+ self, old_shape: tuple[int, int], new_shape: tuple[int, int]
593
+ ) -> Tensor:
594
+ mat = []
595
+ for i in range(np.prod(old_shape)):
596
+ basis_vec = torch.zeros(old_shape)
597
+ basis_vec[np.unravel_index(i, old_shape)] = 1.0
598
+ mat.append(self._resize(basis_vec, new_shape).reshape(-1))
599
+ resize_matrix = torch.stack(mat)
600
+ return torch.linalg.pinv(resize_matrix)
601
+
602
+ def resize_patch_embed(
603
+ self, patch_embed: Tensor, new_patch_size: tuple[int, int]
604
+ ) -> torch.Tensor:
605
+ """Resize patch_embed to target resolution via pseudo-inverse resizing."""
606
+ # Return original kernel if no resize is necessary
607
+ if self.patch_size == new_patch_size:
608
+ return patch_embed
609
+
610
+ # Calculate pseudo-inverse of resize matrix
611
+ if new_patch_size not in self.pinvs:
612
+ self.pinvs[new_patch_size] = self._calculate_pinv(
613
+ self.patch_size, new_patch_size
614
+ )
615
+ pinv = self.pinvs[new_patch_size]
616
+ pinv = pinv.to(patch_embed.device)
617
+
618
+ def resample_patch_embed(patch_embed: Tensor) -> torch.Tensor:
619
+ h, w = new_patch_size
620
+ resampled_kernel = pinv @ patch_embed.reshape(-1)
621
+ return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w)
622
+
623
+ v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)
624
+
625
+ return v_resample_patch_embed(patch_embed)
626
+
627
+ def forward(
628
+ self,
629
+ x: Tensor,
630
+ patch_size: int | tuple[int, int] | None = None,
631
+ ) -> Tensor | tuple[Tensor, tuple[int, int]]:
632
+ """Forward pass."""
633
+ # x has input shape [b, h, w, (t), c]
634
+ batch_size = x.shape[0]
635
+ has_time_dimension = False
636
+ num_timesteps = 0 # ignored if has_time_dimension is False
637
+ if len(x.shape) == 5:
638
+ has_time_dimension = True
639
+ num_timesteps = x.shape[3]
640
+ x = rearrange(x, "b h w t c -> (b t) c h w")
641
+ else:
642
+ x = rearrange(x, "b h w c -> b c h w")
643
+
644
+ if not patch_size:
645
+ # During evaluation use base patch size if not specified
646
+ patch_size = self.patch_size
647
+
648
+ patch_size = to_2tuple(patch_size)
649
+
650
+ # Resize conv weights
651
+ if patch_size == self.patch_size:
652
+ weight = self.proj.weight
653
+ else:
654
+ weight = self.resize_patch_embed(self.proj.weight, patch_size)
655
+ # Apply conv with resized weights
656
+ x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
657
+
658
+ if has_time_dimension:
659
+ x = rearrange(x, "(b t) c h w -> b h w t c", b=batch_size, t=num_timesteps)
660
+ else:
661
+ x = rearrange(x, "b c h w -> b h w c")
662
+ x = self.norm(x)
663
+
664
+ return x
665
+
666
+
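A short, hedged sketch (shapes are illustrative assumptions) of calling FlexiPatchEmbed at a patch size other than the base one; the projection kernel is resized with the cached pseudo-inverse before the convolution is applied:

    embed = FlexiPatchEmbed(patch_size=8, in_chans=2, embed_dim=128)
    x = torch.randn(1, 16, 16, 4, 2)  # [B, H, W, T, C]
    tokens = embed(x, patch_size=4)   # -> [1, 4, 4, 4, 128]
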
667
+ class Attention(nn.Module):
668
+ """Attention."""
669
+
670
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
671
+ fast_attn: Final[bool]
672
+
673
+ def __init__(
674
+ self,
675
+ dim: int,
676
+ num_heads: int = 8,
677
+ qkv_bias: bool = False,
678
+ qk_norm: bool = False,
679
+ attn_drop: float = 0.0,
680
+ proj_drop: float = 0.0,
681
+ norm_layer: nn.Module = nn.LayerNorm,
682
+ cross_attn: bool = False,
683
+ ) -> None:
684
+ """Initialize attention."""
685
+ super().__init__()
686
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
687
+ self.num_heads = num_heads
688
+ self.head_dim = dim // num_heads
689
+ self.scale = self.head_dim**-0.5
690
+ self.fast_attn = hasattr(
691
+ torch.nn.functional, "scaled_dot_product_attention"
692
+ ) # FIXME
693
+
694
+ self.cross_attn = cross_attn
695
+
696
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
697
+ self.k = nn.Linear(dim, dim, bias=qkv_bias)
698
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
699
+
700
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
701
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
702
+ self.attn_drop = nn.Dropout(attn_drop)
703
+ self.proj = nn.Linear(dim, dim)
704
+ self.proj_drop = nn.Dropout(proj_drop)
705
+
706
+ def forward(
707
+ self,
708
+ x: torch.Tensor,
709
+ y: torch.Tensor | None = None,
710
+ attn_mask: torch.Tensor | None = None,
711
+ ) -> torch.Tensor:
712
+ """Forward pass."""
713
+ B, N, C = x.shape
714
+
715
+ q = self.q(x)
716
+
717
+ if y is None:
718
+ assert not self.cross_attn
719
+ k = self.k(x)
720
+ v = self.v(x)
721
+ else:
722
+ assert self.cross_attn
723
+ k = self.k(y)
724
+ v = self.v(y)
725
+
726
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
727
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
728
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
729
+
730
+ q, k = self.q_norm(q), self.k_norm(k)
731
+ if self.fast_attn:
732
+ if attn_mask is not None:
733
+ attn_mask = attn_mask[:, None, None].repeat((1, self.num_heads, N, 1))
734
+ x = F.scaled_dot_product_attention(
735
+ q,
736
+ k,
737
+ v,
738
+ # a value of True indicates that the element should take part in attention
739
+ attn_mask=attn_mask,
740
+ dropout_p=self.attn_drop.p,
741
+ )
742
+ else:
743
+ if attn_mask is not None:
744
+ raise NotImplementedError
745
+ q = q * self.scale
746
+ attn = q @ k.transpose(-2, -1)
747
+ attn = attn.softmax(dim=-1)
748
+ attn = self.attn_drop(attn)
749
+ x = attn @ v
750
+
751
+ x = x.transpose(1, 2).reshape(B, N, C)
752
+ x = self.proj(x)
753
+ x = self.proj_drop(x)
754
+ return x
755
+
756
+
757
+ class Mlp(nn.Module):
758
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks."""
759
+
760
+ def __init__(
761
+ self,
762
+ in_features: int,
763
+ hidden_features: int | None = None,
764
+ out_features: int | None = None,
765
+ act_layer: nn.Module = nn.GELU,
766
+ bias: bool = True,
767
+ drop: float = 0.0,
768
+ ) -> None:
769
+ """Initialize the MLP."""
770
+ super().__init__()
771
+ out_features = out_features or in_features
772
+ hidden_features = hidden_features or in_features
773
+
774
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
775
+ self.act = act_layer()
776
+ self.drop1 = nn.Dropout(drop)
777
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
778
+ self.drop2 = nn.Dropout(drop)
779
+
780
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
781
+ """Forward pass."""
782
+ x = self.fc1(x)
783
+ x = self.act(x)
784
+ x = self.drop1(x)
785
+ x = self.fc2(x)
786
+ x = self.drop2(x)
787
+ return x
788
+
789
+
790
+ class LayerScale(nn.Module):
791
+ """LayerScale."""
792
+
793
+ def __init__(
794
+ self, dim: int, init_values: float = 1e-5, inplace: bool = False
795
+ ) -> None:
796
+ """Init layerscale."""
797
+ super().__init__()
798
+ self.inplace = inplace
799
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
800
+
801
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
802
+ """Forward pass."""
803
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
804
+
805
+
806
+ def drop_path(
807
+ x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
808
+ ) -> torch.Tensor:
809
+ """Drop path."""
810
+ if drop_prob == 0.0 or not training:
811
+ return x
812
+ keep_prob = 1 - drop_prob
813
+ shape = (x.shape[0],) + (1,) * (
814
+ x.ndim - 1
815
+ ) # work with diff dim tensors, not just 2D ConvNets
816
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
817
+ random_tensor.floor_() # binarize
818
+ output = x.div(keep_prob) * random_tensor
819
+ return output
820
+
821
+
822
+ class DropPath(nn.Module):
823
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
824
+
825
+ def __init__(self, drop_prob: float) -> None:
826
+ """Init."""
827
+ super().__init__()
828
+ self.drop_prob = drop_prob
829
+
830
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
831
+ """Forward."""
832
+ return drop_path(x, self.drop_prob, self.training)
833
+
834
+
835
+ class Block(nn.Module):
836
+ """An Attention block."""
837
+
838
+ def __init__(
839
+ self,
840
+ dim: int,
841
+ num_heads: int,
842
+ mlp_ratio: float = 4.0,
843
+ qkv_bias: bool = False,
844
+ qk_norm: bool = False,
845
+ drop: float = 0.0,
846
+ attn_drop: float = 0.0,
847
+ drop_path: float = 0.0,
848
+ init_values: float | None = None,
849
+ act_layer: nn.Module = nn.GELU,
850
+ norm_layer: nn.Module = nn.LayerNorm,
851
+ cross_attn: bool = False,
852
+ ) -> None:
853
+ """Init."""
854
+ super().__init__()
855
+ self.norm1 = norm_layer(dim)
856
+ self.attn = Attention(
857
+ dim,
858
+ num_heads=num_heads,
859
+ qkv_bias=qkv_bias,
860
+ qk_norm=qk_norm,
861
+ attn_drop=attn_drop,
862
+ proj_drop=drop,
863
+ norm_layer=norm_layer,
864
+ cross_attn=cross_attn,
865
+ )
866
+ self.ls1 = (
867
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
868
+ )
869
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
870
+
871
+ self.norm2 = norm_layer(dim)
872
+ self.mlp = Mlp(
873
+ in_features=dim,
874
+ hidden_features=int(dim * mlp_ratio),
875
+ act_layer=act_layer,
876
+ drop=drop,
877
+ )
878
+ self.ls2 = (
879
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
880
+ )
881
+
882
+ def forward(
883
+ self, x: torch.Tensor, y: torch.Tensor, attn_mask: torch.Tensor | None
884
+ ) -> torch.Tensor:
885
+ """Forward."""
886
+ x = x + self.drop_path(self.ls1(self.attn(self.norm1(x), y, attn_mask)))
887
+ x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
888
+ return x
889
+
890
+
891
+ class ModuleListWithInit(nn.ModuleList):
892
+ """module list with an init function."""
893
+
894
+ def _init_weights(self, m: nn.Module) -> None:
895
+ if isinstance(m, nn.Linear):
896
+ # we use xavier_uniform following official JAX ViT:
897
+ torch.nn.init.xavier_uniform_(m.weight)
898
+ if isinstance(m, nn.Linear) and m.bias is not None:
899
+ nn.init.constant_(m.bias, 0)
900
+
901
+
902
+ class GalileoBase(nn.Module):
903
+ """Galileo Base."""
904
+
905
+ def __init__(
906
+ self,
907
+ embedding_size: int = 128,
908
+ depth: int = 2,
909
+ mlp_ratio: int = 2,
910
+ num_heads: int = 8,
911
+ max_sequence_length: int = 24,
912
+ base_patch_size: int = 4,
913
+ use_channel_embs: bool = True,
914
+ drop_path: float = 0.0,
915
+ ) -> None:
916
+ """Init."""
917
+ super().__init__()
918
+
919
+ self.space_time_groups = SPACE_TIME_BANDS_GROUPS_IDX
920
+ self.space_groups = SPACE_BAND_GROUPS_IDX
921
+ self.time_groups = TIME_BAND_GROUPS_IDX
922
+ self.static_groups = STATIC_BAND_GROUPS_IDX
923
+ self.embedding_size = embedding_size
924
+ self.base_patch_size = base_patch_size
925
+
926
+ self.blocks = ModuleListWithInit(
927
+ [
928
+ Block(
929
+ embedding_size,
930
+ num_heads,
931
+ mlp_ratio,
932
+ qkv_bias=True,
933
+ norm_layer=nn.LayerNorm,
934
+ cross_attn=self.cross_attn,
935
+ drop_path=drop_path,
936
+ )
937
+ for _ in range(depth)
938
+ ]
939
+ )
940
+
941
+ self.max_sequence_length = max_sequence_length
942
+ # we have 4 embeddings (pos_in_time, pos_in_space, month, channel) so each gets
943
+ # 0.25 of the dimension. This will change soon anyway
944
+ self.pos_embed = nn.Parameter(
945
+ get_1d_sincos_pos_embed_from_grid_torch(
946
+ int(embedding_size * 0.25), torch.arange(max_sequence_length)
947
+ ),
948
+ requires_grad=False,
949
+ )
950
+ month_tab = get_month_encoding_table(int(embedding_size * 0.25))
951
+ self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
952
+ if use_channel_embs:
953
+ args = {"requires_grad": True}
954
+ else:
955
+ args = {"requires_grad": False}
956
+ self.s_t_channel_embed = nn.Parameter(
957
+ torch.zeros(len(SPACE_TIME_BANDS_GROUPS_IDX), int(embedding_size * 0.25)),
958
+ **args,
959
+ )
960
+ self.sp_channel_embed = nn.Parameter(
961
+ torch.zeros(len(SPACE_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args
962
+ )
963
+ self.t_channel_embed = nn.Parameter(
964
+ torch.zeros(len(TIME_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args
965
+ )
966
+ self.st_channel_embed = nn.Parameter(
967
+ torch.zeros(len(STATIC_BAND_GROUPS_IDX), int(embedding_size * 0.25)), **args
968
+ )
969
+
970
+ self.apply(self._init_weights)
971
+
972
+ @property
973
+ @abstractmethod
974
+ def cross_attn(self) -> bool:
975
+ """Whether to use cross attention."""
976
+ pass
977
+
978
+ def _init_weights(self, m: nn.Module) -> None:
979
+ if isinstance(m, nn.Linear):
980
+ # we use xavier_uniform following official JAX ViT:
981
+ torch.nn.init.xavier_uniform_(m.weight)
982
+ if isinstance(m, nn.Linear) and m.bias is not None:
983
+ nn.init.constant_(m.bias, 0)
984
+
985
+ @classmethod
986
+ def collapse_and_combine_hwtc(
987
+ cls,
988
+ s_t_x: torch.Tensor,
989
+ sp_x: torch.Tensor,
990
+ t_x: torch.Tensor,
991
+ st_x: torch.Tensor,
992
+ s_t_m: torch.Tensor,
993
+ sp_m: torch.Tensor,
994
+ t_m: torch.Tensor,
995
+ st_m: torch.Tensor,
996
+ ) -> tuple[torch.Tensor, torch.Tensor]:
997
+ """collapse_and_combine_hwtc."""
998
+ s_t_x = rearrange(s_t_x, "b h w t c_g d -> b (h w t c_g) d")
999
+ sp_x = rearrange(sp_x, "b h w c_g d -> b (h w c_g) d")
1000
+ t_x = rearrange(t_x, "b t c_g d -> b (t c_g) d")
1001
+
1002
+ s_t_m = rearrange(s_t_m, "b h w t c_g-> b (h w t c_g)")
1003
+ sp_m = rearrange(sp_m, "b h w c_g-> b (h w c_g)")
1004
+ t_m = rearrange(t_m, "b t c_g -> b (t c_g)")
1005
+
1006
+ x = torch.cat(
1007
+ [
1008
+ s_t_x,
1009
+ sp_x,
1010
+ t_x,
1011
+ st_x,
1012
+ ],
1013
+ dim=1,
1014
+ )
1015
+ m = torch.cat([s_t_m, sp_m, t_m, st_m], dim=1)
1016
+ return x, m
1017
+
1018
+ @classmethod
1019
+ def split_and_expand_hwtc(
1020
+ cls,
1021
+ x: torch.Tensor,
1022
+ h: int,
1023
+ w: int,
1024
+ t: int,
1025
+ s_t_c_g: int,
1026
+ sp_c_g: int,
1027
+ t_c_g: int,
1028
+ st_c_g: int,
1029
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1030
+ """split_and_expand_hwtc."""
1031
+ n_s_t_t = h * w * t * s_t_c_g
1032
+ n_t_t = t * t_c_g
1033
+
1034
+ s_t_x = rearrange(
1035
+ x[:, :n_s_t_t], "b (h w t c) d -> b h w t c d", h=h, w=w, t=t, c=s_t_c_g
1036
+ )
1037
+ sp_x = rearrange(
1038
+ x[:, n_s_t_t : -(n_t_t + st_c_g)],
1039
+ "b (h w c) d -> b h w c d",
1040
+ h=h,
1041
+ w=w,
1042
+ c=sp_c_g,
1043
+ )
1044
+ t_x = rearrange(
1045
+ x[:, -(n_t_t + st_c_g) : -st_c_g], "b (t c) d -> b t c d", t=t, c=t_c_g
1046
+ )
1047
+ st_x = x[:, -st_c_g:]
1048
+
1049
+ return s_t_x, sp_x, t_x, st_x
1050
+
1051
+ def apply_encodings(
1052
+ self,
1053
+ s_t_x: torch.Tensor,
1054
+ sp_x: torch.Tensor,
1055
+ t_x: torch.Tensor,
1056
+ st_x: torch.Tensor,
1057
+ months: torch.Tensor,
1058
+ patch_size: int,
1059
+ input_res: int,
1060
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1061
+ """apply_encodings."""
1062
+ b, h, w, t, s_t_c_g, _ = s_t_x.shape
1063
+ sp_c_g, t_c_g = sp_x.shape[-2], t_x.shape[-2]
1064
+ st_c_g = st_x.shape[-2]
1065
+
1066
+ s_t_channel = repeat(
1067
+ self.s_t_channel_embed, "c_g d -> b h w t c_g d", b=b, h=h, w=w, t=t
1068
+ )
1069
+ t_channel = repeat(self.t_channel_embed, "c_g d -> b t c_g d", b=b, t=t)
1070
+ st_channel = repeat(self.st_channel_embed, "c_g d -> b c_g d", b=b)
1071
+ sp_channel = repeat(
1072
+ self.sp_channel_embed, "c_g d -> b h w c_g d", b=b, h=h, w=w
1073
+ )
1074
+
1075
+ pos_embed_s_t = repeat(
1076
+ self.pos_embed[:t], "t d -> b h w t c_g d", b=b, h=h, w=w, c_g=s_t_c_g
1077
+ )
1078
+ m_embed_s_t = repeat(
1079
+ self.month_embed(months), "b t d -> b h w t c_g d", h=h, w=w, c_g=s_t_c_g
1080
+ )
1081
+
1082
+ pos_embed_t = repeat(self.pos_embed[:t], "t d -> b t c_g d", b=b, c_g=t_c_g)
1083
+ m_embed_t = repeat(self.month_embed(months), "b t d -> b t c_g d", c_g=t_c_g)
1084
+ t_zeros = torch.zeros(
1085
+ b, t, t_c_g, int(self.embedding_size * 0.25), device=t_x.device
1086
+ )
1087
+
1088
+ sp_zeros = torch.zeros(
1089
+ b,
1090
+ h,
1091
+ w,
1092
+ sp_c_g,
1093
+ sp_channel.shape[-1] * 2,
1094
+ device=sp_channel.device,
1095
+ )
1096
+
1097
+ st_zeros = torch.zeros(
1098
+ b, st_c_g, st_channel.shape[-1] * 3, device=st_channel.device
1099
+ )
1100
+
1101
+ # find the resolution that each token represents, which will be
1102
+ # the number of pixels in a patch * the resolution of each pixel
1103
+ if patch_size is None:
1104
+ patch_size = self.base_patch_size
1105
+ token_res = input_res * patch_size
1106
+ gsd_ratio = token_res / BASE_GSD
1107
+
1108
+ assert h == w, (
1109
+ "get_2d_sincos_pos_embed_with_resolution currently requires that h==w"
1110
+ )
1111
+ spatial_embed = get_2d_sincos_pos_embed_with_resolution(
1112
+ int(self.embedding_size * 0.25),
1113
+ h,
1114
+ torch.ones(b).to(s_t_x.device) * gsd_ratio,
1115
+ device=s_t_x.device,
1116
+ )
1117
+ spatial_embed = rearrange(spatial_embed, "b (h w) d -> b h w d", h=h, w=w)
1118
+ spatial_embed_s_t = repeat(
1119
+ spatial_embed, "b h w d -> b h w t c_g d", h=h, w=w, t=t, c_g=s_t_c_g
1120
+ )
1121
+ spatial_embed_s = repeat(
1122
+ spatial_embed, "b h w d -> b h w c_g d", h=h, w=w, c_g=sp_c_g
1123
+ )
1124
+
1125
+ s_t_embed = torch.cat(
1126
+ [s_t_channel, pos_embed_s_t, m_embed_s_t, spatial_embed_s_t], dim=-1
1127
+ )
1128
+ sp_embed = torch.cat([sp_channel, sp_zeros, spatial_embed_s], dim=-1)
1129
+ t_embed = torch.cat([t_channel, pos_embed_t, m_embed_t, t_zeros], dim=-1)
1130
+ st_embed = torch.cat([st_channel, st_zeros], dim=-1)
1131
+ return s_t_x + s_t_embed, sp_x + sp_embed, t_x + t_embed, st_x + st_embed
1132
+
1133
+
1134
+ class Encoder(GalileoBase):
1135
+ """Galileo Encoder."""
1136
+
1137
+ def __init__(
1138
+ self,
1139
+ max_patch_size: int = 8,
1140
+ embedding_size: int = 128,
1141
+ depth: int = 2,
1142
+ mlp_ratio: int = 2,
1143
+ num_heads: int = 8,
1144
+ max_sequence_length: int = 24,
1145
+ freeze_projections: bool = False,
1146
+ drop_path: float = 0.0,
1147
+ ) -> None:
1148
+ """Init."""
1149
+ super().__init__(
1150
+ embedding_size,
1151
+ depth,
1152
+ mlp_ratio,
1153
+ num_heads,
1154
+ max_sequence_length,
1155
+ max_patch_size,
1156
+ use_channel_embs=True,
1157
+ drop_path=drop_path,
1158
+ )
1159
+
1160
+ self.space_time_embed = nn.ModuleDict(
1161
+ {
1162
+ group_name: FlexiPatchEmbed(
1163
+ in_chans=len(group),
1164
+ embed_dim=embedding_size,
1165
+ patch_size=max_patch_size,
1166
+ )
1167
+ for group_name, group in self.space_time_groups.items()
1168
+ }
1169
+ )
1170
+ self.space_embed = nn.ModuleDict(
1171
+ {
1172
+ group_name: FlexiPatchEmbed(
1173
+ in_chans=len(group),
1174
+ embed_dim=embedding_size,
1175
+ patch_size=max_patch_size,
1176
+ )
1177
+ for group_name, group in self.space_groups.items()
1178
+ }
1179
+ )
1180
+ self.time_embed = nn.ModuleDict(
1181
+ {
1182
+ group_name: nn.Linear(
1183
+ in_features=len(group), out_features=embedding_size
1184
+ )
1185
+ for group_name, group in self.time_groups.items()
1186
+ }
1187
+ )
1188
+ self.static_embed = nn.ModuleDict(
1189
+ {
1190
+ group_name: nn.Linear(
1191
+ in_features=len(group), out_features=embedding_size
1192
+ )
1193
+ for group_name, group in self.static_groups.items()
1194
+ }
1195
+ )
1196
+ if freeze_projections:
1197
+ self.space_time_embed.requires_grad_(False)
1198
+ self.space_embed.requires_grad_(False)
1199
+ self.time_embed.requires_grad_(False)
1200
+ self.static_embed.requires_grad_(False)
1201
+ self.norm = nn.LayerNorm(embedding_size)
1202
+
1203
+ self.apply(self._init_weights)
1204
+
1205
+ @property
1206
+ @override
1207
+ def cross_attn(self) -> bool:
1208
+ """Whether to use cross attention."""
1209
+ return False
1210
+
1211
+ def _init_weights(self, m: nn.Module) -> None:
1212
+ if isinstance(m, nn.Linear):
1213
+ # we use xavier_uniform following official JAX ViT:
1214
+ torch.nn.init.xavier_uniform_(m.weight)
1215
+ if isinstance(m, nn.Linear) and m.bias is not None:
1216
+ nn.init.constant_(m.bias, 0)
1217
+
1218
+ def apply_linear_projection(
1219
+ self,
1220
+ s_t_x: torch.Tensor,
1221
+ sp_x: torch.Tensor,
1222
+ t_x: torch.Tensor,
1223
+ st_x: torch.Tensor,
1224
+ s_t_m: torch.Tensor,
1225
+ sp_m: torch.Tensor,
1226
+ t_m: torch.Tensor,
1227
+ st_m: torch.Tensor,
1228
+ patch_size: int,
1229
+ ) -> tuple[
1230
+ torch.Tensor,
1231
+ torch.Tensor,
1232
+ torch.Tensor,
1233
+ torch.Tensor,
1234
+ torch.Tensor,
1235
+ torch.Tensor,
1236
+ torch.Tensor,
1237
+ torch.Tensor,
1238
+ ]:
1239
+ """apply_linear_projection.
1240
+
1241
+ Given [B, H, W, (T), C] inputs, returns a [B, H, W, (T), C_G, D] output.
1242
+ We assume that the spatial masks are consistent for the given patch size,
1243
+ so that if patch_size == 2 then one possible mask would be
1244
+ [0, 0, 1, 1]
1245
+ [0, 0, 1, 1]
1246
+ [1, 1, 0, 0]
1247
+ [1, 1, 0, 0]
1248
+ for the H, W dimensions
1249
+ """
1250
+ b, h, w, t, _ = s_t_x.shape
1251
+ new_h, new_w = h // patch_size, w // patch_size
1252
+
1253
+ s_t_l, sp_l, t_l, st_l, s_t_m_l, sp_m_l, t_m_l, st_m_l = (
1254
+ [],
1255
+ [],
1256
+ [],
1257
+ [],
1258
+ [],
1259
+ [],
1260
+ [],
1261
+ [],
1262
+ )
1263
+ for idx, (channel_group, channel_idxs) in enumerate(
1264
+ self.space_time_groups.items()
1265
+ ):
1266
+ s_t_m_l.append(s_t_m[:, 0::patch_size, 0::patch_size, :, idx])
1267
+ if s_t_m_l[-1].min() == 0:
1268
+ s_t_l.append(
1269
+ self.space_time_embed[channel_group](
1270
+ s_t_x[:, :, :, :, channel_idxs], patch_size=patch_size
1271
+ )
1272
+ )
1273
+ else:
1274
+ s_t_l.append(
1275
+ torch.zeros(
1276
+ b,
1277
+ new_h,
1278
+ new_w,
1279
+ t,
1280
+ self.embedding_size,
1281
+ dtype=s_t_x.dtype,
1282
+ device=s_t_x.device,
1283
+ )
1284
+ )
1285
+ for idx, (channel_group, channel_idxs) in enumerate(self.space_groups.items()):
1286
+ sp_m_l.append(sp_m[:, 0::patch_size, 0::patch_size, idx])
1287
+ if sp_m_l[-1].min() == 0:
1288
+ sp_l.append(
1289
+ self.space_embed[channel_group](
1290
+ sp_x[:, :, :, channel_idxs], patch_size=patch_size
1291
+ )
1292
+ )
1293
+ else:
1294
+ sp_l.append(
1295
+ torch.zeros(
1296
+ b,
1297
+ new_h,
1298
+ new_w,
1299
+ self.embedding_size,
1300
+ dtype=sp_x.dtype,
1301
+ device=sp_x.device,
1302
+ )
1303
+ )
1304
+
1305
+ for idx, (channel_group, channel_idxs) in enumerate(self.time_groups.items()):
1306
+ t_m_l.append(t_m[:, :, idx])
1307
+ if t_m_l[-1].min() == 0:
1308
+ t_l.append(self.time_embed[channel_group](t_x[:, :, channel_idxs]))
1309
+ else:
1310
+ t_l.append(
1311
+ torch.zeros(
1312
+ b, t, self.embedding_size, dtype=t_x.dtype, device=t_x.device
1313
+ )
1314
+ )
1315
+
1316
+ for idx, (channel_group, channel_idxs) in enumerate(self.static_groups.items()):
1317
+ st_m_l.append(st_m[:, idx])
1318
+ if st_m_l[-1].min() == 0:
1319
+ st_l.append(self.static_embed[channel_group](st_x[:, channel_idxs]))
1320
+ else:
1321
+ st_l.append(
1322
+ torch.zeros(
1323
+ b, self.embedding_size, dtype=st_x.dtype, device=st_x.device
1324
+ )
1325
+ )
1326
+
1327
+ return (
1328
+ torch.stack(s_t_l, dim=-2),
1329
+ torch.stack(sp_l, dim=-2),
1330
+ torch.stack(t_l, dim=-2),
1331
+ torch.stack(st_l, dim=-2),
1332
+ torch.stack(s_t_m_l, dim=-1),
1333
+ torch.stack(sp_m_l, dim=-1),
1334
+ torch.stack(t_m_l, dim=-1),
1335
+ torch.stack(st_m_l, dim=-1),
1336
+ )
1337
+
1338
+ @staticmethod
1339
+ def remove_masked_tokens(
1340
+ x: torch.Tensor, mask: torch.Tensor
1341
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1342
+ """Remove masked tokens."""
1343
+ org_mask_dtype = mask.dtype
1344
+ mask = mask.bool()
1345
+ # https://stackoverflow.com/a/68621610/2332296
1346
+ # move all non-masked values to the front of their rows
1347
+ sorted_mask, indices = torch.sort(
1348
+ (~mask).int(), dim=1, descending=True, stable=True
1349
+ )
1350
+ x = x.gather(1, indices[:, :, None].expand_as(x))
1351
+ # set masked values to 0 (not really necessary since we'll ignore them anyway)
1352
+ x = x * sorted_mask.unsqueeze(-1)
1353
+
1354
+ # cut off to the length of the longest sequence
1355
+ max_length = sorted_mask.sum(-1).max()
1356
+ x = x[:, :max_length]
1357
+ updated_mask = 1 - sorted_mask[:, :max_length]
1358
+
1359
+ return x, indices, updated_mask.to(dtype=org_mask_dtype)
1360
+
1361
+ @staticmethod
1362
+ def add_removed_tokens(
1363
+ x: torch.Tensor, indices: torch.Tensor, mask: torch.Tensor
1364
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1365
+ """add_removed_tokens."""
1366
+ masked_tokens = repeat(
1367
+ torch.zeros_like(x[0, 0, :]), "d -> b t d", b=x.shape[0], t=indices.shape[1]
1368
+ )
1369
+ full_mask = torch.cat(
1370
+ (
1371
+ mask,
1372
+ torch.ones(
1373
+ (x.shape[0], indices.shape[1] - x.shape[1]),
1374
+ device=x.device,
1375
+ dtype=mask.dtype,
1376
+ ),
1377
+ ),
1378
+ dim=-1,
1379
+ )
1380
+ # can't set value on leaf variable
1381
+ out = masked_tokens.clone()
1382
+ # put tokens in full masked tensor (at the first N positions in every row)
1383
+ out[~full_mask.bool()] = x[~mask.bool()]
1384
+ # then move them to their original positions
1385
+ out = out.scatter(1, indices[:, :, None].expand_as(out), out)
1386
+ full_mask = full_mask.scatter(1, indices.expand_as(full_mask), full_mask)
1387
+ return out, full_mask
1388
+
1389
+ def apply_attn(
1390
+ self,
1391
+ s_t_x: torch.Tensor,
1392
+ sp_x: torch.Tensor,
1393
+ t_x: torch.Tensor,
1394
+ st_x: torch.Tensor,
1395
+ s_t_m: torch.Tensor,
1396
+ sp_m: torch.Tensor,
1397
+ t_m: torch.Tensor,
1398
+ st_m: torch.Tensor,
1399
+ months: torch.Tensor,
1400
+ patch_size: int,
1401
+ input_res: int,
1402
+ exit_after: int | None,
1403
+ token_exit_cfg: dict | None,
1404
+ ) -> tuple[
1405
+ torch.Tensor,
1406
+ torch.Tensor,
1407
+ torch.Tensor,
1408
+ torch.Tensor,
1409
+ torch.Tensor,
1410
+ torch.Tensor,
1411
+ torch.Tensor,
1412
+ torch.Tensor,
1413
+ ]:
1414
+ """apply_attn."""
1415
+ if token_exit_cfg:
1416
+ exit_s_t, exit_sp, exit_t, exit_st = self.create_token_exit_ids(
1417
+ s_t_x, sp_x, t_x, st_x, token_exit_cfg
1418
+ )
1419
+ exit_ids_seq, _ = self.collapse_and_combine_hwtc(
1420
+ exit_s_t, exit_sp, exit_t, exit_st, s_t_m, sp_m, t_m, st_m
1421
+ )
1422
+ # exited_tokens starts as linear projections!
1423
+ exited_tokens, _ = self.collapse_and_combine_hwtc(
1424
+ s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
1425
+ )
1426
+ else:
1427
+ exit_ids_seq = None
1428
+ exited_tokens = None
1429
+
1430
+ _, h, w, t, s_t_c_g, _ = s_t_x.shape
1431
+ sp_c_g, t_c_g, st_c_g = sp_x.shape[3], t_x.shape[-2], st_x.shape[-2]
1432
+ s_t_x, sp_x, t_x, st_x = self.apply_encodings(
1433
+ s_t_x, sp_x, t_x, st_x, months, patch_size, input_res
1434
+ )
1435
+ x, m = self.collapse_and_combine_hwtc(
1436
+ s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
1437
+ )
1438
+
1439
+ # we only care about the values >= 1 for this mask, since 2 just tells the decoder
1440
+ # to decode those tokens. From the perspective of the encoder, 1 and 2 are equivalent
1441
+ # since they both represent masked values
1442
+ new_m = m >= 1
1443
+ x, indices, new_m = self.remove_masked_tokens(
1444
+ x, new_m
1445
+ ) # new_m is shape (bsz, seq_len)
1446
+
1447
+ if exit_ids_seq is not None:
1448
+ exit_ids_seq, _, _ = self.remove_masked_tokens(exit_ids_seq, m >= 1)
1449
+ # still linear projections
1450
+ exited_tokens, _, _ = self.remove_masked_tokens(exited_tokens, m >= 1)
1451
+
1452
+ for i_blk, blk in enumerate(self.blocks):
1453
+ if (exit_after is not None) and ((i_blk + 1) > exit_after):
1454
+ # if exit_after is N, then we exit after the Nth layer
1455
+ # if exit_after is 0, then all layers are skipped
1456
+ break
1457
+
1458
+ # skip the 0th block since this is just the linear
1459
+ # projection
1460
+ if (exit_ids_seq is not None) and (i_blk > 0):
1461
+ assert exited_tokens is not None
1462
+ # half depth
1463
+ exited_tokens = torch.where(
1464
+ condition=(exit_ids_seq == i_blk),
1465
+ input=x.detach(),
1466
+ other=exited_tokens.detach(),
1467
+ )
1468
+
1469
+ # we take the inverse of the mask because a value
1470
+ # of True indicates the value *should* take part in
1471
+ # attention
1472
+ temp_mask = ~new_m.bool()
1473
+ if temp_mask.all():
1474
+ # if all the tokens are used in attention we can pass a None mask
1475
+ # to the attention block
1476
+ temp_mask = None
1477
+
1478
+ x = blk(x=x, y=None, attn_mask=temp_mask)
1479
+
1480
+ if exit_ids_seq is not None:
1481
+ assert exited_tokens is not None
1482
+ # full depth
1483
+ # IMPORTANT: write this to x
1484
+ x = torch.where(
1485
+ condition=(exit_ids_seq == (i_blk + 1)), # 2 for full depth
1486
+ input=x.detach(),
1487
+ other=exited_tokens.detach(),
1488
+ )
1489
+
1490
+ # we don't care about the mask returned by add_removed_tokens, since we will
1491
+ # just use the original, unclipped mask here
1492
+ x, _ = self.add_removed_tokens(x, indices, new_m)
1493
+ return (
1494
+ *self.split_and_expand_hwtc(x, h, w, t, s_t_c_g, sp_c_g, t_c_g, st_c_g),
1495
+ s_t_m,
1496
+ sp_m,
1497
+ t_m,
1498
+ st_m,
1499
+ )
1500
+
1501
+ @classmethod
1502
+ def average_tokens(
1503
+ cls,
1504
+ s_t_x: torch.Tensor,
1505
+ sp_x: torch.Tensor,
1506
+ t_x: torch.Tensor,
1507
+ st_x: torch.Tensor,
1508
+ s_t_m: torch.Tensor,
1509
+ sp_m: torch.Tensor,
1510
+ t_m: torch.Tensor,
1511
+ st_m: torch.Tensor,
1512
+ ) -> torch.Tensor:
1513
+ """average_tokens."""
1514
+ x, m = cls.collapse_and_combine_hwtc(
1515
+ s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
1516
+ )
1517
+ x, _, m = cls.remove_masked_tokens(x, m)
1518
+ x_for_mean = x * (1 - m.unsqueeze(-1))
1519
+ return x_for_mean.sum(dim=1) / torch.sum(1 - m, -1, keepdim=True)
1520
+
1521
+ @classmethod
1522
+ def apply_mask_and_average_tokens_per_patch(
1523
+ cls,
1524
+ s_t_x: torch.Tensor,
1525
+ sp_x: torch.Tensor,
1526
+ t_x: torch.Tensor,
1527
+ st_x: torch.Tensor,
1528
+ s_t_m: torch.Tensor,
1529
+ sp_m: torch.Tensor,
1530
+ t_m: torch.Tensor,
1531
+ st_m: torch.Tensor,
1532
+ ) -> torch.Tensor:
1533
+ """apply_mask_and_average_tokens_per_patch."""
1534
+ s_t_x = rearrange(s_t_x, "b t_h t_w t c_g d -> b (t_h t_w) (t c_g) d")
1535
+ sp_x = rearrange(sp_x, "b t_h t_w c_g d -> b (t_h t_w) c_g d")
1536
+ # repeat time tokens over space
1537
+ t_x = repeat(
1538
+ rearrange(t_x, "b t c_g d -> b (t c_g) d"),
1539
+ "b n d -> b s n d",
1540
+ s=sp_x.shape[1],
1541
+ )
1542
+ st_x = repeat(st_x, "b c_g d -> b s c_g d", s=sp_x.shape[1])
1543
+ s_t_m = rearrange(s_t_m, "b t_h t_w t c_g-> b (t_h t_w) (t c_g)")
1544
+ sp_m = rearrange(sp_m, "b t_h t_w c_g-> b (t_h t_w) c_g")
1545
+ t_m = repeat(
1546
+ rearrange(t_m, "b t c_g -> b (t c_g)"), "b n -> b s n", s=sp_x.shape[1]
1547
+ )
1548
+ st_m = repeat(st_m, "b c_g -> b s c_g", s=sp_x.shape[1])
1549
+
1550
+ x = torch.cat([s_t_x, sp_x, t_x, st_x], dim=2) # B, S, N, D
1551
+ m = torch.cat([s_t_m, sp_m, t_m, st_m], dim=2) # B, S, N
1552
+
1553
+ x_for_mean = x * (1 - m.unsqueeze(-1))
1554
+
1555
+ return x_for_mean.sum(dim=2) / torch.sum(1 - m, -1, keepdim=True)
1556
+
1557
+ def create_token_exit_ids(
1558
+ self,
1559
+ s_t_x: torch.Tensor,
1560
+ sp_x: torch.Tensor,
1561
+ t_x: torch.Tensor,
1562
+ st_x: torch.Tensor,
1563
+ token_exit_cfg: dict,
1564
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1565
+ """create_token_exit_ids."""
1566
+ exit_s_t = torch.zeros_like(s_t_x)
1567
+ exit_sp = torch.zeros_like(sp_x)
1568
+ exit_t = torch.zeros_like(t_x)
1569
+ exit_st = torch.zeros_like(st_x)
1570
+
1571
+ for idx, (key, _) in enumerate(self.space_time_groups.items()):
1572
+ exit_s_t[:, :, :, :, idx, :] = token_exit_cfg[key]
1573
+
1574
+ for idx, (key, _) in enumerate(self.space_groups.items()):
1575
+ exit_sp[:, :, :, idx, :] = token_exit_cfg[key]
1576
+
1577
+ for idx, (key, _) in enumerate(self.time_groups.items()):
1578
+ exit_t[:, :, idx, :] = token_exit_cfg[key]
1579
+
1580
+ for idx, (key, _) in enumerate(self.static_groups.items()):
1581
+ exit_st[:, idx, :] = token_exit_cfg[key]
1582
+ return exit_s_t, exit_sp, exit_t, exit_st
1583
+
1584
+ def forward(
1585
+ self,
1586
+ s_t_x: torch.Tensor,
1587
+ sp_x: torch.Tensor,
1588
+ t_x: torch.Tensor,
1589
+ st_x: torch.Tensor,
1590
+ s_t_m: torch.Tensor,
1591
+ sp_m: torch.Tensor,
1592
+ t_m: torch.Tensor,
1593
+ st_m: torch.Tensor,
1594
+ months: torch.Tensor,
1595
+ patch_size: int,
1596
+ input_resolution_m: int = BASE_GSD,
1597
+ exit_after: int | None = None,
1598
+ token_exit_cfg: dict | None = None,
1599
+ add_layernorm_on_exit: bool = True,
1600
+ ) -> tuple[
1601
+ torch.Tensor,
1602
+ torch.Tensor,
1603
+ torch.Tensor,
1604
+ torch.Tensor,
1605
+ torch.Tensor,
1606
+ torch.Tensor,
1607
+ torch.Tensor,
1608
+ torch.Tensor,
1609
+ torch.Tensor,
1610
+ ]:
1611
+ """Forward."""
1612
+ (
1613
+ s_t_x,
1614
+ sp_x,
1615
+ t_x,
1616
+ st_x,
1617
+ s_t_m,
1618
+ sp_m,
1619
+ t_m,
1620
+ st_m,
1621
+ ) = self.apply_linear_projection(
1622
+ s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, patch_size
1623
+ )
1624
+
1625
+ if (exit_after is None) or (exit_after > 0):
1626
+ s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m = self.apply_attn(
1627
+ s_t_x,
1628
+ sp_x,
1629
+ t_x,
1630
+ st_x,
1631
+ s_t_m,
1632
+ sp_m,
1633
+ t_m,
1634
+ st_m,
1635
+ months,
1636
+ patch_size,
1637
+ input_resolution_m,
1638
+ exit_after=exit_after,
1639
+ token_exit_cfg=token_exit_cfg,
1640
+ )
1641
+
1642
+ if add_layernorm_on_exit:
1643
+ s_t_x = self.norm(s_t_x)
1644
+ sp_x = self.norm(sp_x)
1645
+ t_x = self.norm(t_x)
1646
+ st_x = self.norm(st_x)
1647
+
1648
+ return (s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, months)
1649
+
1650
+ @classmethod
1651
+ def load_from_folder(cls, folder: Path, device: torch.device) -> "Encoder":
1652
+ """Load a model from a folder containing an encoder.pt and config.json."""
1653
+ if not (folder / CONFIG_FILENAME).exists():
1654
+ all_files_in_folder = [f.name for f in folder.glob("*")]
1655
+ raise ValueError(
1656
+ f"Expected {CONFIG_FILENAME} in {folder}, found {all_files_in_folder}"
1657
+ )
1658
+ if not (folder / ENCODER_FILENAME).exists():
1659
+ all_files_in_folder = [f.name for f in folder.glob("*")]
1660
+ raise ValueError(
1661
+ f"Expected {ENCODER_FILENAME} in {folder}, found {all_files_in_folder}"
1662
+ )
1663
+
1664
+ with (folder / CONFIG_FILENAME).open("r") as f:
1665
+ config = json.load(f)
1666
+ model_config = config["model"]
1667
+ encoder_config = model_config["encoder"]
1668
+ encoder = cls(**encoder_config)
1669
+
1670
+ state_dict = torch.load(
1671
+ folder / ENCODER_FILENAME, map_location=device, weights_only=True
1672
+ )
1673
+ for key in list(state_dict.keys()):
1674
+ # this cleans the state dict, which occasionally had an extra
1675
+ # ".backbone" included in the key names
1676
+ state_dict[key.replace(".backbone", "")] = state_dict.pop(key)
1677
+ encoder.load_state_dict(state_dict)
1678
+ return encoder
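
Finally, an end-to-end hedged sketch (random weights and toy shapes, purely illustrative; a real checkpoint directory containing config.json and encoder.pt would instead be loaded with Encoder.load_from_folder) showing how the encoder is called with fully unmasked inputs:

    encoder = Encoder(embedding_size=128, depth=2)
    b, h, w, t = 1, 16, 16, 3
    s_t_x = torch.randn(b, h, w, t, len(SPACE_TIME_BANDS))
    sp_x = torch.randn(b, h, w, len(SPACE_BANDS))
    t_x = torch.randn(b, t, len(TIME_BANDS))
    st_x = torch.randn(b, len(STATIC_BANDS))
    s_t_m = torch.zeros(b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX))  # 0 = seen by the encoder
    sp_m = torch.zeros(b, h, w, len(SPACE_BAND_GROUPS_IDX))
    t_m = torch.zeros(b, t, len(TIME_BAND_GROUPS_IDX))
    st_m = torch.zeros(b, len(STATIC_BAND_GROUPS_IDX))
    months = torch.full((b, t), DEFAULT_MONTH, dtype=torch.long)
    out = encoder(s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m, months, patch_size=4)
    s_t_tokens = out[0]  # [b, h // 4, w // 4, t, len(SPACE_TIME_BANDS_GROUPS_IDX), 128]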