rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/presto/single_file_presto.py
@@ -0,0 +1,926 @@
+"""Single file Presto.
+
+Copied from https://github.com/nasaharvest/presto/blob/main/single_file_presto.py
+with modifications.
+"""
+
+import math
+from collections import OrderedDict
+from copy import deepcopy
+from typing import cast
+
+import numpy as np
+import torch
+from einops import repeat
+from torch import nn
+from torch.jit import Final
+from torch.nn import functional as F
+
+# naming convention matches helios.data.constants
+PRESTO_S2_BANDS = [
+    "B02",
+    "B03",
+    "B04",
+    "B05",
+    "B06",
+    "B07",
+    "B08",
+    "B8A",
+    "B09",
+    "B11",
+    "B12",
+]
+PRESTO_S1_BANDS = ["vv", "vh"]
+ERA5_BANDS = ["temperature_2m", "total_precipitation"]
+SRTM_BANDS = ["elevation", "slope"]
+PRESTO_BANDS = PRESTO_S1_BANDS + PRESTO_S2_BANDS + ERA5_BANDS + SRTM_BANDS + ["NDVI"]
+
+# used in normalization
+PRESTO_ADD_BY = torch.Tensor(
+    [
+        25.0,
+        25.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        -272.15,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+    ]
+)
+PRESTO_DIV_BY = torch.Tensor(
+    [
+        25.0,
+        25.0,
+        1e4,
+        1e4,
+        1e4,
+        1e4,
+        1e4,
+        1e4,
+        1e4,
+        1e4,
+        1e4,
+        1e4,
+        35.0,
+        0.03,
+        2000.0,
+        50.0,
+        1.0,
+    ]
+)
+
+
+BANDS_GROUPS_IDX = OrderedDict(
+    [
+        ("S1", [0, 1]),
+        ("S2_RGB", [2, 3, 4]),
+        ("S2_Red_Edge", [5, 6, 7]),
+        ("S2_NIR_10m", [8]),
+        ("S2_NIR_20m", [9]),
+        ("S2_SWIR", [10, 11]),
+        ("ERA5", [12, 13]),
+        ("SRTM", [14, 15]),
+        ("NDVI", [16]),
+    ]
+)
+
+NUM_DYNAMIC_WORLD_CLASSES = 9
+
+
+class Attention(nn.Module):
+    """Attention."""
+
+    # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
+    fast_attn: Final[bool]
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        norm_layer: nn.Module = nn.LayerNorm,
+    ) -> None:
+        """Init."""
+        super().__init__()
+        assert dim % num_heads == 0, "dim should be divisible by num_heads"
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.fast_attn = hasattr(
+            torch.nn.functional, "scaled_dot_product_attention"
+        )  # FIXME
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward."""
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = qkv.unbind(0)
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fast_attn:
+            x = F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Mlp(nn.Module):
+    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int | None = None,
+        out_features: int | None = None,
+        act_layer: nn.Module = nn.GELU,
+        bias: bool = True,
+        drop: float = 0.0,
+    ) -> None:
+        """Init."""
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.drop1 = nn.Dropout(drop)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop2 = nn.Dropout(drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward."""
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+
+class LayerScale(nn.Module):
+    """LayerScale."""
+
+    def __init__(
+        self, dim: int, init_values: float = 1e-5, inplace: bool = False
+    ) -> None:
+        """__init__."""
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward."""
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class Block(nn.Module):
+    """Block."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        init_values: float | None = None,
+        act_layer: nn.Module = nn.GELU,
+        norm_layer: nn.Module = nn.LayerNorm,
+    ) -> None:
+        """Init."""
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            norm_layer=norm_layer,
+        )
+        self.ls1 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
+        self.ls2 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward."""
+        x = x + self.ls1(self.attn(self.norm1(x)))
+        x = x + self.ls2(self.mlp(self.norm2(x)))
+        return x
+
+
+def get_sinusoid_encoding_table(
+    positions: int | list[int], d_hid: int, T: int = 1000
+) -> torch.Tensor:
+    """Sinusoid position encoding table.
+
+    positions: int or list of integer, if int range(positions)
+    """
+    if isinstance(positions, int):
+        positions = list(range(positions))
+
+    def cal_angle(position: int, hid_idx: int) -> float:
+        return position / np.power(T, 2 * (hid_idx // 2) / d_hid)
+
+    def get_posi_angle_vec(position: int) -> list[float]:
+        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
+
+    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in positions])
+
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+    return torch.FloatTensor(sinusoid_table)
+
+
+def get_month_encoding_table(d_hid: int) -> torch.Tensor:
+    """Sinusoid month encoding table, for 12 months indexed from 0-11."""
+    assert d_hid % 2 == 0
+    angles = np.arange(0, 13) / (12 / (2 * np.pi))
+
+    sin_table = np.sin(np.stack([angles for _ in range(d_hid // 2)], axis=-1))
+    cos_table = np.cos(np.stack([angles for _ in range(d_hid // 2)], axis=-1))
+    month_table = np.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)
+
+    return torch.FloatTensor(month_table)
+
+
+def month_to_tensor(
+    month: torch.Tensor | int, batch_size: int, seq_len: int, device: torch.device
+) -> torch.Tensor:
+    """month_to_tensor."""
+    if isinstance(month, int):
+        assert cast(int, month) < 12
+    else:
+        assert max(cast(torch.Tensor, month.flatten())) < 12
+
+    if isinstance(month, int):
+        # >>> torch.fmod(torch.tensor([9., 10, 11, 12, 13, 14]), 12)
+        # tensor([ 9., 10., 11., 0., 1., 2.])
+        month = (
+            torch.fmod(torch.arange(month, month + seq_len, dtype=torch.long), 12)
+            .expand(batch_size, seq_len)
+            .to(device)
+        )
+    elif len(month.shape) == 1:
+        month = torch.stack(
+            [
+                torch.fmod(torch.arange(m, m + seq_len, dtype=torch.long), 12)
+                for m in month
+            ]
+        ).to(device)
+    return month
+
+
+class Encoder(nn.Module):
+    """Encoder."""
+
+    def __init__(
+        self,
+        embedding_size: int = 128,
+        channel_embed_ratio: float = 0.25,
+        month_embed_ratio: float = 0.25,
+        depth: int = 2,
+        mlp_ratio: float = 2,
+        num_heads: int = 8,
+        max_sequence_length: int = 24,
+    ) -> None:
+        """Init."""
+        super().__init__()
+
+        self.band_groups = BANDS_GROUPS_IDX
+        self.embedding_size = embedding_size
+
+        # this is used for the channel embedding
+        self.band_group_to_idx = {
+            group_name: idx
+            for idx, (group_name, _) in enumerate(self.band_groups.items())
+        }
+        self.band_group_to_idx["dynamic_world"] = (
+            max(self.band_group_to_idx.values()) + 1
+        )
+
+        self.eo_patch_embed = nn.ModuleDict(
+            {
+                group_name: nn.Linear(len(group), embedding_size)
+                for group_name, group in self.band_groups.items()
+            }
+        )
+        self.dw_embed = nn.Embedding(
+            num_embeddings=NUM_DYNAMIC_WORLD_CLASSES + 1, embedding_dim=embedding_size
+        )
+        self.latlon_embed = nn.Linear(3, embedding_size)
+
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    embedding_size,
+                    num_heads,
+                    mlp_ratio,
+                    qkv_bias=True,
+                    norm_layer=nn.LayerNorm,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.norm = nn.LayerNorm(embedding_size)
+
+        # the positional + monthly + channel embedding
+        self.max_sequence_length = max_sequence_length
+        pos_embedding_size = int(
+            embedding_size * (1 - (channel_embed_ratio + month_embed_ratio))
+        )
+        channel_embedding_size = int(embedding_size * channel_embed_ratio)
+        month_embedding_size = int(embedding_size * month_embed_ratio)
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, max_sequence_length, pos_embedding_size), requires_grad=False
+        )
+        month_tab = get_month_encoding_table(month_embedding_size)
+        self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
+        self.channel_embed = nn.Embedding(
+            num_embeddings=len(self.band_groups) + 1,
+            embedding_dim=channel_embedding_size,
+        )
+
+        self.initialize_weights()
+
+    def initialize_weights(self) -> None:
+        """initialize_weights."""
+        pos_embed = get_sinusoid_encoding_table(
+            self.pos_embed.shape[1], self.pos_embed.shape[-1]
+        ).to(device=self.pos_embed.device)
+        self.pos_embed.data.copy_(pos_embed)
+
+        # initialize nn.Linear and nn.LayerNorm
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m: nn.Module) -> None:
+        if isinstance(m, nn.Linear):
+            # we use xavier_uniform following official JAX ViT:
+            torch.nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @staticmethod
+    def cartesian(latlons: torch.Tensor) -> torch.Tensor:
+        """Cartesian."""
+        with torch.no_grad():
+            # an embedding is calculated for all timesteps. This is then expanded
+            # for each timestep in the sequence
+            latlon_radians = latlons * math.pi / 180
+            lats, lons = latlon_radians[:, 0], latlon_radians[:, 1]
+            x = torch.cos(lats) * torch.cos(lons)
+            y = torch.cos(lats) * torch.sin(lons)
+            z = torch.sin(lats)
+        return torch.stack([x, y, z], dim=-1)
+
+    @staticmethod
+    def mask_tokens(x: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """mask_tokens."""
+        summed = mask.sum(
+            dim=(1, 2)
+        )  # summed tells me the number of masked elements per batch idx
+        assert summed.max() == summed.min(), f"{summed.max()}, {summed.min()}"
+
+        batch_size = x.shape[0]
+        removed_elements_per_batch = int(summed.max() / mask.shape[2])
+        kept_elements_per_batch = x.shape[1] - removed_elements_per_batch
+        embedding_dim = x.shape[-1]
+
+        # we want the mask to just be the indices of the masked tokens
+        indices = repeat(
+            torch.arange(0, x.shape[1]).long().to(x.device), "d -> b d", b=x.shape[0]
+        )
+
+        x = x[~mask.bool()].view(batch_size, kept_elements_per_batch, embedding_dim)
+
+        mask = mask[:, :, 0]
+        kept_indices = indices[~mask.bool()].view(batch_size, kept_elements_per_batch)
+        removed_indices = indices[mask.bool()].view(
+            batch_size, removed_elements_per_batch
+        )
+
+        return x, kept_indices, removed_indices
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        dynamic_world: torch.Tensor,
+        # different from the original
+        # presto - latlons can be optionally ignored
+        latlons: torch.Tensor | None = None,
+        mask: torch.Tensor | None = None,
+        month: torch.Tensor | int = 0,
+        eval_task: bool = True,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward."""
+        device = x.device
+
+        if mask is None:
+            mask = torch.zeros_like(x, device=x.device).float()
+
+        months = month_to_tensor(month, x.shape[0], x.shape[1], device)
+        month_embedding = self.month_embed(months)
+        positional_embedding = repeat(
+            self.pos_embed[:, : x.shape[1], :],
+            "b t d -> (repeat b) t d",
+            repeat=x.shape[0],
+        )
+
+        # we assume the number of masked patches is the same
+        # for all items in the batch. Otherwise things become a headache
+        all_tokens, all_masks = [], []
+
+        for channel_group, channel_idxs in self.band_groups.items():
+            tokens = self.eo_patch_embed[channel_group](x[:, :, channel_idxs])
+            channel_embedding = self.channel_embed(
+                torch.tensor(self.band_group_to_idx[channel_group]).long().to(device)
+            )
+            channel_embedding = repeat(
+                channel_embedding, "d -> b t d", b=x.shape[0], t=x.shape[1]
+            )
+            if channel_group == "SRTM":
+                # for SRTM, we reduce it to a single token instead of
+                # a token per timestep
+                channel_wise_positional_embedding = torch.cat(
+                    (
+                        torch.zeros_like(month_embedding[:, 0:1]),
+                        channel_embedding[:, 0:1],
+                        torch.zeros_like(positional_embedding[:, 0:1]),
+                    ),
+                    dim=-1,
+                )
+                indices = slice(0, 1)
+            else:
+                channel_wise_positional_embedding = torch.cat(
+                    (month_embedding, channel_embedding, positional_embedding), dim=-1
+                )
+                indices = slice(None)
+
+            tokens = tokens[:, indices]
+            tokens += channel_wise_positional_embedding
+            all_tokens.append(tokens)
+            group_mask = repeat(
+                torch.max(mask[:, indices, channel_idxs], dim=-1)[0],
+                "b t -> b t d",
+                d=tokens.shape[-1],
+            )
+            all_masks.append(group_mask)
+
+        # then, dynamic world
+        tokens = self.dw_embed(dynamic_world)
+        channel_embedding = self.channel_embed(
+            torch.tensor(self.band_group_to_idx["dynamic_world"]).long().to(device)
+        )
+        channel_embedding = repeat(
+            channel_embedding, "d -> b t d", b=x.shape[0], t=x.shape[1]
+        )
+        positional_embedding = torch.cat(
+            (month_embedding, channel_embedding, positional_embedding), dim=-1
+        )
+        tokens += positional_embedding
+        all_tokens.append(tokens)
+
+        # now we calculate the mask for these [b, t] tokens
+        group_mask = repeat(
+            dynamic_world == NUM_DYNAMIC_WORLD_CLASSES,
+            "b t -> b t d",
+            d=tokens.shape[-1],
+        )
+        all_masks.append(group_mask)
+
+        x = torch.cat(all_tokens, dim=1)  # [batch, timesteps, embedding_dim]
+        mask = torch.cat(all_masks, dim=1)  # [batch, timesteps, embedding_dim]
+        x, kept_indices, removed_indices = self.mask_tokens(x, mask)
+
+        # append latlon tokens
+        if latlons is not None:
+            latlon_tokens = self.latlon_embed(self.cartesian(latlons)).unsqueeze(1)
+            x = torch.cat((latlon_tokens, x), dim=1)
+
+        # apply Transformer blocks
+        for blk in self.blocks:
+            x = blk(x)
+
+        # mask will be a boolean of shape [batch, total_num_tokens]
+        if eval_task:
+            return self.norm(x.mean(dim=1))
+        return self.norm(x), kept_indices, removed_indices
+
+
+class Decoder(nn.Module):
+    """Decoder."""
+
+    def __init__(
+        self,
+        channel_embeddings: nn.Embedding,
+        encoder_embed_dim: int = 128,
+        decoder_embed_dim: int = 128,
+        decoder_depth: int = 2,
+        decoder_num_heads: int = 8,
+        mlp_ratio: float = 2,
+        max_sequence_length: int = 24,
+    ) -> None:
+        """Init."""
+        super().__init__()
+
+        self.band_groups = BANDS_GROUPS_IDX
+
+        # this is used for the channel embedding
+        self.band_group_to_idx = {
+            group_name: idx
+            for idx, (group_name, _) in enumerate(self.band_groups.items())
+        }
+        self.band_group_to_idx["dynamic_world"] = (
+            max(self.band_group_to_idx.values()) + 1
+        )
+
+        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
+
+        self.mask_token = nn.Parameter(torch.zeros(decoder_embed_dim))
+
+        self.decoder_blocks = nn.ModuleList(
+            [
+                Block(
+                    decoder_embed_dim,
+                    decoder_num_heads,
+                    mlp_ratio,
+                    qkv_bias=True,
+                    norm_layer=nn.LayerNorm,
+                )
+                for _ in range(decoder_depth)
+            ]
+        )
+
+        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
+
+        self.eo_decoder_pred = nn.ModuleDict(
+            {
+                group_name: nn.Linear(decoder_embed_dim, len(group))
+                for group_name, group in self.band_groups.items()
+            }
+        )
+        self.dw_decoder_pred = nn.Linear(decoder_embed_dim, NUM_DYNAMIC_WORLD_CLASSES)
+
+        self.channel_embeddings = channel_embeddings
+        channel_embedding_dims = channel_embeddings.weight.shape[-1]
+        remaining_embeddings = decoder_embed_dim - channel_embedding_dims
+        # the positional + monthly + channel embedding
+        self.max_sequence_length = max_sequence_length
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, max_sequence_length, int(remaining_embeddings) // 2),
+            requires_grad=False,
+        )
+        month_tab = get_month_encoding_table(int(remaining_embeddings) // 2)
+        self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
+
+        self.initialize_weights()
+
+    def initialize_weights(self) -> None:
+        """initialize_weights."""
+        pos_embed = get_sinusoid_encoding_table(
+            self.pos_embed.shape[1], self.pos_embed.shape[-1]
+        ).to(device=self.pos_embed.device)
+        self.pos_embed.data.copy_(pos_embed)
+
+        # initialize nn.Linear and nn.LayerNorm
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m: nn.Module) -> None:
+        if isinstance(m, nn.Linear):
+            # we use xavier_uniform following official JAX ViT:
+            torch.nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def add_masked_tokens(
+        self, x: torch.Tensor, kept_indices: torch.Tensor, removed_indices: torch.Tensor
+    ) -> torch.Tensor:
+        """add_masked_tokens."""
+        mask_tokens = repeat(
+            self.mask_token, "d -> b t d", b=x.shape[0], t=removed_indices.shape[1]
+        )
+
+        x = torch.cat([x, mask_tokens], dim=1)
+
+        # sort according to their indices. Shape is [batch, index]
+        combined_indices = torch.cat([kept_indices, removed_indices], dim=1) + 1
+        # 0 for latlon index
+        combined_indices = torch.sort(
+            torch.cat(
+                [torch.zeros_like(combined_indices[:, 0:1]), combined_indices], dim=1
+            )
+        )[1]
+        # and then tile for each dimension
+        combined_indices = repeat(combined_indices, "b t -> b t d", d=x.shape[-1])
+        x = torch.gather(x, 1, combined_indices)
+        return x
+
+    def add_embeddings(
+        self, x: torch.Tensor, month: torch.Tensor | int
+    ) -> torch.Tensor:
+        """add_embeddings."""
+        num_channel_groups = len(self.band_group_to_idx)
+        # -2 since we remove srtm and latlon, and -1 since the srtm
+        # channel group doesn't have timesteps
+        num_timesteps = int((x.shape[1] - 2) / (num_channel_groups - 1))
+        srtm_index = self.band_group_to_idx["SRTM"] * num_timesteps
+        months = month_to_tensor(month, x.shape[0], num_timesteps, x.device)
+
+        # when we expand the encodings, each channel_group gets num_timesteps
+        # encodings. However, there is only one SRTM token so we remove the
+        # excess SRTM encodings
+        remove_mask = torch.full(
+            size=(num_timesteps * num_channel_groups,), fill_value=False
+        )
+        remove_mask[torch.arange(num_timesteps - 1) + srtm_index] = True
+
+        month_embedding = repeat(
+            self.month_embed(months),
+            "b t d -> b (repeat t) d",
+            repeat=num_channel_groups,
+        )
+        month_embedding = month_embedding[:, ~remove_mask]
+        month_embedding[:, srtm_index] = 0
+
+        positional_embedding = repeat(
+            self.pos_embed[:, :num_timesteps, :],
+            "b t d -> (b2 b) (t2 t) d",
+            b2=x.shape[0],
+            t2=num_channel_groups,
+        )
+        positional_embedding = positional_embedding[:, ~remove_mask]
+        positional_embedding[:, srtm_index] = 0
+
+        channel_embeddings = torch.repeat_interleave(
+            self.channel_embeddings.weight, repeats=num_timesteps, dim=0
+        )
+        channel_embeddings = repeat(channel_embeddings, "c d -> b c d", b=x.shape[0])
+        channel_embeddings = channel_embeddings[:, ~remove_mask]
+
+        positional_embedding = torch.cat(
+            (month_embedding, channel_embeddings, positional_embedding), dim=-1
+        )
+
+        # add the zero embedding for the latlon token
+        positional_embedding = torch.cat(
+            [torch.zeros_like(positional_embedding[:, 0:1, :]), positional_embedding],
+            dim=1,
+        )
+
+        x += positional_embedding
+        return x
+
+    def reconstruct_inputs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """reconstruct_inputs."""
+        # remove the latlon token
+        x = x[:, 1:, :]
+
+        # split into channel groups
+        num_channel_groups = len(self.band_group_to_idx) - 1
+        num_timesteps = int((x.shape[1] - 1) / num_channel_groups)
+        srtm_index = self.band_group_to_idx["SRTM"] * num_timesteps
+        srtm_token = x[:, srtm_index : srtm_index + 1, :]
+
+        mask = torch.full((x.shape[1],), True, device=x.device)
+        mask[torch.tensor(srtm_index)] = False
+        x = x[:, mask]
+
+        x = x.view(x.shape[0], num_channel_groups, num_timesteps, x.shape[-1])
+
+        eo_output, dw_output = [], None
+        for group_name, idx in self.band_group_to_idx.items():
+            if group_name == "SRTM":
+                eo_output.append(
+                    repeat(
+                        self.eo_decoder_pred[group_name](srtm_token),
+                        "b t d -> b (t2 t) d",
+                        t2=num_timesteps,
+                    )
+                )
+            else:
+                if idx > self.band_group_to_idx["SRTM"]:
+                    idx -= 1
+                group_tokens = x[:, idx]
+                if group_name == "dynamic_world":
+                    dw_output = self.dw_decoder_pred(group_tokens)
+                else:
+                    eo_output.append(self.eo_decoder_pred[group_name](group_tokens))
+
+        # we can just do this concatenation because the BANDS_GROUP_IDX
+        # is ordered
+        return torch.cat(eo_output, dim=-1), cast(torch.Tensor, dw_output)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        kept_indices: torch.Tensor,
+        removed_indices: torch.Tensor,
+        month: torch.Tensor | int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Forward."""
+        x = self.decoder_embed(x)
+        x = self.add_masked_tokens(x, kept_indices, removed_indices)
+        x = self.add_embeddings(x, month)
+
+        # apply Transformer blocks
+        for blk in self.decoder_blocks:
+            x = blk(x)
+        x = self.decoder_norm(x)
+        return self.reconstruct_inputs(x)
+
+
+class PrestoFineTuningModel(nn.Module):
+    """PrestoFineTuningModel."""
+
+    def __init__(self, encoder: Encoder, head: nn.Module) -> None:
+        """Init."""
+        super().__init__()
+        self.encoder: Encoder = deepcopy(encoder)
+        # make sure the model is trainable, since we can call
+        # this having called requires_grad_(False)
+        self.encoder.requires_grad_(True)
+        # but don't unfreeze the position encoder, which
+        # shouldn't be trainable
+        self.encoder.pos_embed.requires_grad_(False)
+        self.encoder.month_embed.requires_grad_(False)
+        self.head = head
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        dynamic_world: torch.Tensor,
+        latlons: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        month: torch.Tensor | int = 0,
+    ) -> torch.Tensor:
+        """Forward."""
+        return self.head(
+            self.encoder(
+                x=x,
+                dynamic_world=dynamic_world,
+                latlons=latlons,
+                mask=mask,
+                month=month,
+                eval_task=True,
+            )
+        )
+
+
+class FinetuningHead(nn.Module):
+    """FinetuningHead."""
+
+    def __init__(self, hidden_size: int, num_outputs: int, regression: bool) -> None:
+        """__init__."""
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.num_outputs = num_outputs
+        self.regression = regression
+        self.linear = nn.Linear(hidden_size, num_outputs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward."""
+        x = self.linear(x)
+        if (not self.regression) & (self.num_outputs == 1):
+            x = torch.sigmoid(x)
+        return x
+
+
+class Presto(nn.Module):
+    """Presto."""
+
+    def __init__(self, encoder: Encoder, decoder: Decoder):
+        """Init."""
+        super().__init__()
+        self.encoder: Encoder = encoder
+        self.decoder: Decoder = decoder
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        dynamic_world: torch.Tensor,
+        latlons: torch.Tensor,
+        mask: torch.Tensor | None = None,
+        month: torch.Tensor | int = 0,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Forward."""
+        x, kept_indices, removed_indices = self.encoder(
+            x=x,
+            dynamic_world=dynamic_world,
+            latlons=latlons,
+            mask=mask,
+            month=month,
+            eval_task=False,
+        )
+
+        return self.decoder(x, kept_indices, removed_indices, month)
+
+    @classmethod
+    def construct(
+        cls,
+        encoder_embedding_size: int = 128,
+        channel_embed_ratio: float = 0.25,
+        month_embed_ratio: float = 0.25,
+        encoder_depth: int = 2,
+        mlp_ratio: float = 4,
+        encoder_num_heads: int = 8,
+        decoder_embedding_size: int = 128,
+        decoder_depth: int = 2,
+        decoder_num_heads: int = 8,
+        max_sequence_length: int = 24,
+    ) -> "Presto":
+        """Construct."""
+        encoder = Encoder(
+            embedding_size=encoder_embedding_size,
+            channel_embed_ratio=channel_embed_ratio,
+            month_embed_ratio=month_embed_ratio,
+            depth=encoder_depth,
+            mlp_ratio=mlp_ratio,
+            num_heads=encoder_num_heads,
+            max_sequence_length=max_sequence_length,
+        )
+        decoder = Decoder(
+            channel_embeddings=encoder.channel_embed,
+            encoder_embed_dim=encoder_embedding_size,
+            decoder_embed_dim=decoder_embedding_size,
+            decoder_depth=decoder_depth,
+            decoder_num_heads=decoder_num_heads,
+            mlp_ratio=mlp_ratio,
+            max_sequence_length=max_sequence_length,
+        )
+        return cls(encoder, decoder)
+
+    def construct_finetuning_model(
+        self,
+        num_outputs: int,
+        regression: bool = False,
+    ) -> PrestoFineTuningModel:
+        """construct_finetuning_model."""
+        head = FinetuningHead(
+            num_outputs=num_outputs,
+            hidden_size=self.encoder.embedding_size,
+            regression=regression,
+        )
+        model = PrestoFineTuningModel(self.encoder, head).to(
+            self.encoder.pos_embed.device
+        )
+        model.train()
+        return model
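
For orientation, here is a minimal usage sketch of the masked-autoencoding path defined above. This is an editor's illustration, not part of the published diff: it assumes the EO inputs have already been normalized (presumably as (x + PRESTO_ADD_BY) / PRESTO_DIV_BY, per the "used in normalization" comment), and it marks every Dynamic World label with NUM_DYNAMIC_WORLD_CLASSES, the "missing" index that Encoder.forward treats as masked.

    import torch

    # Default configuration: 128-dim encoder and decoder, depth 2 each.
    model = Presto.construct()

    batch, timesteps = 4, 12
    # 17 channels per timestep, laid out per BANDS_GROUPS_IDX (indices 0-16).
    eo = torch.randn(batch, timesteps, 17)
    # Index 9 == NUM_DYNAMIC_WORLD_CLASSES marks missing Dynamic World labels,
    # so all of these tokens are masked out by the encoder.
    dynamic_world = torch.full((batch, timesteps), 9, dtype=torch.long)
    latlons = torch.tensor([[47.6, -122.3]]).repeat(batch, 1)  # degrees

    eo_recon, dw_logits = model(eo, dynamic_world, latlons, month=0)
    # eo_recon: [4, 12, 17] band reconstructions; dw_logits: [4, 12, 9].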
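
The fine-tuning path replaces the decoder with a task head: construct_finetuning_model deep-copies the encoder (re-enabling gradients everywhere except the frozen positional and month embeddings) and attaches a FinetuningHead, which applies a sigmoid only in the single-output classification case. A sketch continuing from the model above:

    finetuned = model.construct_finetuning_model(num_outputs=1, regression=False)
    probs = finetuned(eo, dynamic_world, latlons, month=0)  # [4, 1], values in (0, 1)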