rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1147 @@
1
+ """Prithvi V2.
2
+
3
+ This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
4
+
5
+ The code is released under:
6
+
7
+ MIT License
8
+ Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
9
+ """
10
+
11
+ import json
12
+ import logging
13
+ import tempfile
14
+ import warnings
15
+ from enum import StrEnum
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ from einops import rearrange
23
+ from huggingface_hub import hf_hub_download
24
+ from timm.layers import to_2tuple
25
+ from timm.models.vision_transformer import Block
26
+ from torch.nn import functional as F
27
+
28
+ from rslearn.train.model_context import ModelContext
29
+ from rslearn.train.transforms.normalize import Normalize
30
+ from rslearn.train.transforms.transform import Transform
31
+
32
+ from .component import FeatureExtractor, FeatureMaps
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ class PrithviV2Models(StrEnum):
38
+ """Names for different Prithvi models on torch hub."""
39
+
40
+ VIT_300 = "VIT_300"
41
+ VIT_600 = "VIT_600"
42
+
43
+
44
+ MODEL_TO_HF_INFO = {
45
+ PrithviV2Models.VIT_300: {
46
+ "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
47
+ "weights": "Prithvi_EO_V2_300M.pt",
48
+ "revision": "b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
49
+ },
50
+ PrithviV2Models.VIT_600: {
51
+ "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M",
52
+ "weights": "Prithvi_EO_V2_600M.pt",
53
+ "revision": "87f15784813828dc37aa3197a143cd4689e4d080",
54
+ },
55
+ }
56
+
57
+
58
+ HF_HUB_CONFIG_FNAME = "config.json"
59
+ DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), "rslearn_cache", "prithvi_v2")
60
+
61
+
62
+ def get_config(cache_dir: Path, hf_hub_id: str, hf_hub_revision: str) -> dict[str, Any]:
63
+ """Get the JSON config dict.
64
+
65
+ Args:
66
+ cache_dir: the directory to cache the config.json file, which will be
67
+ downloaded from HF Hub.
68
+ hf_hub_id: the HF Hub ID from which to download the config.
69
+ hf_hub_revision: The revision (commit) to download the config from.
70
+ """
71
+ cache_fname = cache_dir / HF_HUB_CONFIG_FNAME
72
+ if not cache_fname.exists():
73
+ _ = hf_hub_download(
74
+ local_dir=cache_dir,
75
+ repo_id=hf_hub_id,
76
+ filename=HF_HUB_CONFIG_FNAME,
77
+ revision=hf_hub_revision,
78
+ ) # nosec
79
+ with cache_fname.open() as f:
80
+ return json.load(f)["pretrained_cfg"]
81
+
82
+
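A minimal sketch of how `get_config` is used further down in this file: the cache directory and HF Hub coordinates come from `MODEL_TO_HF_INFO`, and the returned dict carries the `pretrained_cfg` keys that the wrapper reads later (`img_size`, `patch_size`, `mean`, `std`, `bands`). The exact contents depend on the upstream config.json, so treat the printed keys as an assumption.

```python
# Sketch only: fetches config.json from HF Hub on first call, then reads the cached copy.
info = MODEL_TO_HF_INFO[PrithviV2Models.VIT_300]
config = get_config(DEFAULT_CACHE_DIR, info["hf_hub_id"], info["revision"])
# Keys accessed elsewhere in this module.
print(config["img_size"], config["patch_size"], len(config["mean"]))
```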
83
+ class PrithviV2(FeatureExtractor):
84
+ """An Rslearn wrapper for Prithvi 2.0."""
85
+
86
+ INPUT_KEY = "image"
87
+
88
+ def __init__(
89
+ self,
90
+ cache_dir: str | Path | None = None,
91
+ size: PrithviV2Models = PrithviV2Models.VIT_300,
92
+ num_frames: int = 1,
93
+ ):
94
+ """Create a new PrithviV2.
95
+
96
+ Args:
97
+ cache_dir: The local folder in which to download the prithvi config and
98
+ weights. If None, it downloads to a temporary folder.
99
+ size: the model size, see class for various models.
100
+ num_frames: The number of input frames (timesteps). The model was trained on 3,
101
+ but the upstream examples use 1 when there is only a single timestamp (e.g.
102
+ https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/
103
+ example_landslide4sense.ipynb)
104
+
105
+ """
106
+ super().__init__()
107
+ if cache_dir is None:
108
+ cache_dir = DEFAULT_CACHE_DIR
109
+ cache_dir = Path(cache_dir)
110
+
111
+ hub_id = MODEL_TO_HF_INFO[size]["hf_hub_id"]
112
+ revision = MODEL_TO_HF_INFO[size]["revision"]
113
+ checkpoint_fname = MODEL_TO_HF_INFO[size]["weights"]
114
+
115
+ config = get_config(cache_dir, hub_id, revision)
116
+ config["num_frames"] = num_frames
117
+ self.model = PrithviMAE(**config)
118
+
119
+ if not (cache_dir / checkpoint_fname).exists():
120
+ _ = hf_hub_download(
121
+ local_dir=cache_dir,
122
+ repo_id=hub_id,
123
+ filename=checkpoint_fname,
124
+ revision=revision,
125
+ ) # nosec
126
+
127
+ state_dict = torch.load(
128
+ cache_dir / checkpoint_fname,
129
+ map_location="cpu",
130
+ weights_only=True,
131
+ )
132
+ # discard fixed pos_embedding weight, following
133
+ # https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M/blob/e4aabdc440c8ee703a749def8af5bf4700dee35b/inference.py#L362
134
+ for k in list(state_dict.keys()):
135
+ if "pos_embed" in k:
136
+ del state_dict[k]
137
+ self.model.load_state_dict(state_dict, strict=False)
138
+ self.image_resolution = config["img_size"]
139
+ self.bands = config["bands"]
140
+ # patch size is a list [t, h, w], where h == w
141
+ self.patch_size = config["patch_size"][-1]
142
+
143
+ def _resize_data(self, data: torch.Tensor) -> torch.Tensor:
144
+ """Process individual modality data.
145
+
146
+ Args:
147
+ data: Input tensor of shape [B, C, T, H, W]
148
+
149
+ Returns:
150
+ a tensor of shape [B, C, T, H, W], resized to the model's expected resolution
151
+ """
152
+ # Get original dimensions
153
+ B, C, T, H, W = data.shape
154
+ data = rearrange(data, "b c t h w -> b (c t) h w")
155
+ original_height = H
156
+ new_height = self.patch_size if original_height == 1 else self.image_resolution
157
+ data = F.interpolate(
158
+ data,
159
+ size=(new_height, new_height),
160
+ mode="bilinear",
161
+ align_corners=False,
162
+ )
163
+ data = rearrange(data, "b (c t) h w -> b c t h w", c=C, t=T)
164
+ return data
165
+
166
+ def forward(self, context: ModelContext) -> FeatureMaps:
167
+ """Compute feature maps from the Prithvi V2 backbone.
168
+
169
+ Args:
170
+ context: the model context. Input dicts must include "image" key containing
171
+ HLS (Harmonized Landsat and Sentinel-2) data.
172
+
173
+ Returns:
174
+ a FeatureMaps with one map of shape [B, depth*embed_dim, H/p_s, W/p_s] containing the
175
+ per-block feature maps (averaged over time) stacked along the channel dimension.
176
+ """
177
+ # x has shape BCTHW
178
+ x = torch.stack([inp[self.INPUT_KEY].image for inp in context.inputs], dim=0)
179
+ x = self._resize_data(x)
180
+ features = self.model.encoder.forward_features(x)
181
+ # prepare_features_for_image_model was slightly modified since we already
182
+ # know the number of timesteps and don't need to recompute it.
183
+ # in addition we average along the time dimension (instead of concatenating)
184
+ # to keep the embeddings reasonably sized.
185
+ result = self.model.encoder.prepare_features_for_image_model(
186
+ features, x.shape[2]
187
+ )
188
+ return FeatureMaps([torch.cat(result, dim=1)])
189
+
190
+ def get_backbone_channels(self) -> list:
191
+ """Returns the output channels of this model when used as a backbone.
192
+
193
+ The output channels are a list of (patch_size, depth) tuples that correspond
194
+ to the feature maps that the backbone returns.
195
+
196
+ Returns:
197
+ the output channels of the backbone as a list of (patch_size, depth) tuples.
198
+ """
199
+ return [(1, 1024)]
200
+
201
+
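For orientation, a hedged construction sketch for the wrapper above; the cache path is hypothetical, and the first call downloads config.json plus the checkpoint from HF Hub. At runtime `forward` consumes a `ModelContext` whose inputs carry the HLS time series under the "image" key.

```python
# Sketch, not a canonical recipe: cache_dir is a hypothetical local path.
backbone = PrithviV2(
    cache_dir="/tmp/prithvi_v2",
    size=PrithviV2Models.VIT_300,
    num_frames=1,
)
print(backbone.bands)                     # band names from the HF config
print(backbone.patch_size, backbone.image_resolution)
print(backbone.get_backbone_channels())   # [(1, 1024)]
```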
202
+ class PrithviNormalize(Transform):
203
+ """Normalize inputs using Prithvi normalization.
204
+
205
+ Similar to the model, the input should be an image time series under the key
206
+ "image".
207
+ """
208
+
209
+ def __init__(
210
+ self,
211
+ cache_dir: str | Path | None = None,
212
+ size: PrithviV2Models = PrithviV2Models.VIT_300,
213
+ ) -> None:
214
+ """Initialize a new PrithviNormalize.
215
+
216
+ Args:
217
+ cache_dir: the local directory to cache the config.json which contains the
218
+ means and standard deviations used in the normalization.
219
+ size: the model size, see class for various models. In this case (and
220
+ for the current hf revision), the config values (mean and std) are the
221
+ same for both the 300M and 600M models, so it's safe to leave this unset.
222
+ """
223
+ super().__init__()
224
+ hub_id = MODEL_TO_HF_INFO[size]["hf_hub_id"]
225
+ revision = MODEL_TO_HF_INFO[size]["revision"]
226
+ if cache_dir is None:
227
+ cache_dir = DEFAULT_CACHE_DIR
228
+ cache_dir = Path(cache_dir)
229
+ config = get_config(cache_dir, hub_id, revision)
230
+ self.normalizer = Normalize(
231
+ mean=config["mean"],
232
+ std=config["std"],
233
+ num_bands=len(config["mean"]),
234
+ selectors=[PrithviV2.INPUT_KEY],
235
+ )
236
+
237
+ def forward(
238
+ self, input_dict: dict[str, Any], target_dict: dict[str, Any]
239
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
240
+ """Apply Prithvi normalization on the image.
241
+
242
+ Args:
243
+ input_dict: the input, which must contain the "image" key.
244
+ target_dict: the target
245
+
246
+ Returns:
247
+ normalized (input_dicts, target_dicts) tuple
248
+ """
249
+ return self.normalizer(input_dict, target_dict)
250
+
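The normalization transform is normally paired with the backbone; a sketch assuming the same hypothetical cache path as above. The mean/std come from the same HF config.json the model reads.

```python
# Sketch: builds the transform only; it delegates to rslearn's Normalize under the hood.
normalize = PrithviNormalize(cache_dir="/tmp/prithvi_v2", size=PrithviV2Models.VIT_300)
# At train time the dataset pipeline calls normalize(input_dict, target_dict),
# where input_dict carries the image time series under the "image" key.
```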
251
+
252
+ # Copyright (c) IBM Corp. 2024. All rights reserved.
253
+ #
254
+ # Licensed under the Apache License, Version 2.0 (the "License");
255
+ # you may not use this file except in compliance with the License.
256
+ # You may obtain a copy of the License at
257
+ #
258
+ # http://www.apache.org/licenses/LICENSE-2.0
259
+ #
260
+ # Unless required by applicable law or agreed to in writing, software
261
+ # distributed under the License is distributed on an "AS IS" BASIS,
262
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
263
+ # See the License for the specific language governing permissions and
264
+ # limitations under the License.
265
+ # --------------------------------------------------------
266
+ # References:
267
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
268
+ # transformers: https://github.com/huggingface/transformers
269
+ # --------------------------------------------------------
270
+
271
+
272
+ def get_3d_sincos_pos_embed(
273
+ embed_dim: int,
274
+ grid_size: tuple[int, int, int] | list[int],
275
+ add_cls_token: bool = False,
276
+ ) -> torch.Tensor:
277
+ """Create 3D sin/cos positional embeddings.
278
+
279
+ Args:
280
+ embed_dim (int):
281
+ Embedding dimension.
282
+ grid_size (tuple[int, int, int] | list[int]):
283
+ The grid depth, height and width.
284
+ add_cls_token (bool, *optional*, defaults to False):
285
+ Whether or not to add a classification (CLS) token.
286
+
287
+ Returns:
288
+ (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
289
+ (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
290
+ """
291
+ assert embed_dim % 16 == 0
292
+
293
+ t_size, h_size, w_size = grid_size
294
+
295
+ w_embed_dim = embed_dim // 16 * 6
296
+ h_embed_dim = embed_dim // 16 * 6
297
+ t_embed_dim = embed_dim // 16 * 4
298
+
299
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
300
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
301
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
302
+
303
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
304
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
305
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
306
+
307
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
308
+
309
+ if add_cls_token:
310
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
311
+ return pos_embed
312
+
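A small shape check for the two positional-embedding helpers above, assuming embed_dim divisible by 16; note the table is built with NumPy and callers convert it with torch.from_numpy.

```python
# embed_dim is split 6/16 + 6/16 + 4/16 between the width, height and time components.
pe = get_3d_sincos_pos_embed(embed_dim=1024, grid_size=(1, 14, 14), add_cls_token=True)
print(pe.shape)   # (1 + 1*14*14, 1024) -> (197, 1024), returned as a NumPy array
```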
313
+
314
+ def get_1d_sincos_pos_embed_from_grid(
315
+ embed_dim: int, pos: torch.Tensor
316
+ ) -> torch.Tensor:
317
+ """embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)."""
318
+ if embed_dim % 2 != 0:
319
+ raise ValueError("embed_dim must be even")
320
+
321
+ omega = np.arange(embed_dim // 2, dtype=float)
322
+ omega /= embed_dim / 2.0
323
+ omega = 1.0 / 10000**omega # (D/2,)
324
+
325
+ pos = pos.reshape(-1) # (M,)
326
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
327
+
328
+ emb_sin = np.sin(out) # (M, D/2)
329
+ emb_cos = np.cos(out) # (M, D/2)
330
+
331
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
332
+ return emb
333
+
334
+
335
+ def _get_1d_sincos_embed_from_grid_torch(
336
+ embed_dim: int, pos: torch.Tensor
337
+ ) -> torch.Tensor:
338
+ """Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.
339
+
340
+ embed_dim: output dimension for each position
341
+ pos: a list of positions to be encoded: size (M,) - must be float dtype!
342
+ out: (M, D)
343
+ """
344
+ assert embed_dim % 2 == 0
345
+ assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
346
+
347
+ omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
348
+ omega /= embed_dim / 2.0
349
+ omega = 1.0 / 10000**omega # (D/2,)
350
+
351
+ pos = pos.reshape(-1) # (M,)
352
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
353
+
354
+ emb_sin = torch.sin(out) # (M, D/2)
355
+ emb_cos = torch.cos(out) # (M, D/2)
356
+
357
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
358
+
359
+ return emb
360
+
361
+
362
+ def _init_weights(module: nn.Module) -> None:
363
+ """Initialize the weights."""
364
+ if isinstance(module, nn.Linear):
365
+ nn.init.xavier_uniform_(module.weight)
366
+ if module.bias is not None:
367
+ module.bias.data.zero_()
368
+ elif isinstance(module, nn.LayerNorm):
369
+ module.bias.data.zero_()
370
+ module.weight.data.fill_(1.0)
371
+
372
+
373
+ def _interpolate_pos_encoding(
374
+ pos_embed: torch.Tensor,
375
+ grid_size: tuple[int, int, int] | list[int],
376
+ patch_size: tuple[int, int, int] | list[int],
377
+ shape: tuple[int, int, int] | list[int],
378
+ embed_dim: int,
379
+ ) -> torch.Tensor:
380
+ """_interpolate_pos_encoding.
381
+
382
+ Adapted from:
383
+ - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
384
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
385
+ """
386
+ t, h, w = shape
387
+ t_patches = t // patch_size[0]
388
+ h_patches = h // patch_size[1]
389
+ w_patches = w // patch_size[2]
390
+
391
+ if [t_patches, h_patches, w_patches] == grid_size:
392
+ # No interpolation needed
393
+ return pos_embed
394
+ if t_patches != grid_size[0]:
395
+ # Re-compute pos embedding to handle changed num_frames
396
+ new_grid_size = (t_patches, *grid_size[1:])
397
+ new_pos_embed = get_3d_sincos_pos_embed(
398
+ pos_embed.shape[-1], new_grid_size, add_cls_token=True
399
+ )
400
+ new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
401
+ else:
402
+ new_grid_size = grid_size # type: ignore
403
+ new_pos_embed = pos_embed
404
+
405
+ class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]
406
+
407
+ patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(
408
+ 0, 3, 1, 2
409
+ )
410
+
411
+ patch_pos_embed = nn.functional.interpolate(
412
+ patch_pos_embed,
413
+ size=(h_patches, w_patches),
414
+ mode="bicubic",
415
+ align_corners=True,
416
+ )
417
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
418
+
419
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
420
+
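A hedged example of the interpolation path above: a table built for a 14x14 grid is resized for a 256x256 input with (1, 16, 16) patches, so the patch tokens go from 196 to 256 while the cls entry is preserved.

```python
pe = get_3d_sincos_pos_embed(1024, (1, 14, 14), add_cls_token=True)   # (197, 1024) NumPy
pe = torch.from_numpy(pe).float().unsqueeze(0)                        # (1, 197, 1024)
resized = _interpolate_pos_encoding(
    pos_embed=pe,
    grid_size=[1, 14, 14],
    patch_size=(1, 16, 16),
    shape=(1, 256, 256),      # T, H, W of the incoming sample
    embed_dim=1024,
)
print(resized.shape)          # torch.Size([1, 257, 1024]) = cls + 16*16 patch positions
```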
421
+
422
+ class PatchEmbed(nn.Module):
423
+ """3D version of timm.models.vision_transformer.PatchEmbed."""
424
+
425
+ def __init__(
426
+ self,
427
+ input_size: tuple[int, int, int] = (1, 224, 224),
428
+ patch_size: tuple[int, int, int] = (1, 16, 16),
429
+ in_chans: int = 3,
430
+ embed_dim: int = 768,
431
+ norm_layer: nn.Module | None = None,
432
+ flatten: bool = True,
433
+ bias: bool = True,
434
+ ) -> None:
435
+ """Init."""
436
+ super().__init__()
437
+ self.input_size = input_size
438
+ self.patch_size = patch_size
439
+ self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
440
+ assert all(g >= 1 for g in self.grid_size), "Patch size is bigger than input size."
441
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
442
+ self.flatten = flatten
443
+
444
+ self.proj = nn.Conv3d(
445
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
446
+ )
447
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
448
+
449
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
450
+ """Forward."""
451
+ B, C, T, H, W = x.shape
452
+
453
+ if (
454
+ T / self.patch_size[0] % 1
455
+ or H / self.patch_size[1] % 1
456
+ or W / self.patch_size[2] % 1
457
+ ):
458
+ warnings.warn(
459
+ f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
460
+ f"The border will be ignored, add backbone_padding for pixel-wise tasks."
461
+ )
462
+
463
+ x = self.proj(x)
464
+ if self.flatten:
465
+ x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
466
+ x = self.norm(x)
467
+ return x
468
+
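A quick shape walk-through for PatchEmbed with an HLS-style setup (6 bands, single frame, 224x224 input, 16x16 patches); the numbers follow directly from grid_size.

```python
patch_embed = PatchEmbed(
    input_size=(1, 224, 224), patch_size=(1, 16, 16), in_chans=6, embed_dim=1024
)
x = torch.randn(2, 6, 1, 224, 224)       # B, C, T, H, W
tokens = patch_embed(x)
print(patch_embed.grid_size, patch_embed.num_patches)   # [1, 14, 14] 196
print(tokens.shape)                                     # torch.Size([2, 196, 1024])
```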
469
+
470
+ class TemporalEncoder(nn.Module):
471
+ """TemporalEncoder."""
472
+
473
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
474
+ """Init."""
475
+ super().__init__()
476
+ self.embed_dim = embed_dim
477
+ self.year_embed_dim = embed_dim // 2
478
+ self.julian_day_embed_dim = embed_dim - self.year_embed_dim
479
+
480
+ # If trainable, initialize scale with small number
481
+ if trainable_scale:
482
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
483
+ else:
484
+ self.register_buffer("scale", torch.ones(1))
485
+
486
+ def forward(
487
+ self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None
488
+ ) -> torch.Tensor:
489
+ """Forward.
490
+
491
+ temporal_coords: year and day-of-year info with shape (B, T, 2).
492
+ tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
493
+ repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
494
+ """
495
+ shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
496
+
497
+ year = _get_1d_sincos_embed_from_grid_torch(
498
+ self.year_embed_dim, temporal_coords[:, :, 0].flatten()
499
+ ).reshape(shape)
500
+ julian_day = _get_1d_sincos_embed_from_grid_torch(
501
+ self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
502
+ ).reshape(shape)
503
+
504
+ embedding = self.scale * torch.cat([year, julian_day], dim=-1)
505
+
506
+ if tokens_per_frame is not None:
507
+ embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
508
+
509
+ return embedding # B, T*tokens_per_frame, embed_dim
510
+
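A sketch of the temporal-encoding inputs: one (year, day-of-year) pair per frame, optionally repeated per token; the coordinate values here are illustrative.

```python
temporal_enc = TemporalEncoder(embed_dim=1024)
coords = torch.tensor([[[2020.0, 152.0], [2020.0, 182.0], [2020.0, 213.0]]])  # (B=1, T=3, 2)
print(temporal_enc(coords).shape)                        # torch.Size([1, 3, 1024])
print(temporal_enc(coords, tokens_per_frame=196).shape)  # torch.Size([1, 588, 1024])
```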
511
+
512
+ class LocationEncoder(nn.Module):
513
+ """LocationEncoder."""
514
+
515
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
516
+ """Init."""
517
+ super().__init__()
518
+ self.embed_dim = embed_dim
519
+ self.lat_embed_dim = embed_dim // 2
520
+ self.lon_embed_dim = embed_dim - self.lat_embed_dim
521
+
522
+ # If trainable, initialize scale with small number
523
+ if trainable_scale:
524
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
525
+ else:
526
+ self.register_buffer("scale", torch.ones(1))
527
+
528
+ def forward(self, location_coords: torch.Tensor) -> torch.Tensor:
529
+ """location_coords: lat and lon info with shape (B, 2)."""
530
+ shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
531
+
532
+ lat = _get_1d_sincos_embed_from_grid_torch(
533
+ self.lat_embed_dim, location_coords[:, 0].flatten()
534
+ ).reshape(shape)
535
+ lon = _get_1d_sincos_embed_from_grid_torch(
536
+ self.lon_embed_dim, location_coords[:, 1].flatten()
537
+ ).reshape(shape)
538
+
539
+ embedding = self.scale * torch.cat([lat, lon], dim=-1)
540
+
541
+ return embedding # B, 1, embed_dim
542
+
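The location branch works the same way on a single (lat, lon) pair per sample; the coordinates below are illustrative.

```python
location_enc = LocationEncoder(embed_dim=1024)
latlon = torch.tensor([[47.6, -122.3]])     # (B, 2)
print(location_enc(latlon).shape)           # torch.Size([1, 1, 1024])
```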
543
+
544
+ class PrithviViT(nn.Module):
545
+ """Prithvi ViT Encoder."""
546
+
547
+ def __init__(
548
+ self,
549
+ img_size: int | tuple[int, int] = 224,
550
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
551
+ num_frames: int = 1,
552
+ in_chans: int = 3,
553
+ embed_dim: int = 1024,
554
+ depth: int = 24,
555
+ num_heads: int = 16,
556
+ mlp_ratio: float = 4.0,
557
+ norm_layer: nn.Module = nn.LayerNorm,
558
+ coords_encoding: list[str] | None = None,
559
+ coords_scale_learn: bool = False,
560
+ drop_path: float = 0.0,
561
+ **kwargs: Any,
562
+ ) -> None:
563
+ """Init."""
564
+ super().__init__()
565
+
566
+ self.in_chans = in_chans
567
+ self.num_frames = num_frames
568
+ self.embed_dim = embed_dim
569
+ self.img_size = to_2tuple(img_size)
570
+ if isinstance(patch_size, int):
571
+ patch_size = (1, patch_size, patch_size)
572
+
573
+ # 3D patch embedding
574
+ self.patch_embed = PatchEmbed(
575
+ input_size=(num_frames,) + self.img_size,
576
+ patch_size=patch_size,
577
+ in_chans=in_chans,
578
+ embed_dim=embed_dim,
579
+ )
580
+ self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth
581
+
582
+ # Optional temporal and location embedding
583
+ coords_encoding = coords_encoding or []
584
+ self.temporal_encoding = "time" in coords_encoding
585
+ self.location_encoding = "location" in coords_encoding
586
+ if self.temporal_encoding:
587
+ assert patch_size[0] == 1, (
588
+ f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
589
+ )
590
+ self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
591
+ if self.location_encoding:
592
+ self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
593
+
594
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
595
+ self.register_buffer(
596
+ "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
597
+ )
598
+
599
+ # Transformer layers
600
+ self.blocks = []
601
+ for i in range(depth):
602
+ self.blocks.append(
603
+ Block(
604
+ embed_dim,
605
+ num_heads,
606
+ mlp_ratio,
607
+ qkv_bias=True,
608
+ norm_layer=norm_layer,
609
+ drop_path=drop_path,
610
+ )
611
+ )
612
+ self.blocks = nn.ModuleList(self.blocks)
613
+
614
+ self.norm = norm_layer(embed_dim)
615
+
616
+ self.initialize_weights()
617
+
618
+ def initialize_weights(self) -> None:
619
+ """initialize_weights."""
620
+ # initialize (and freeze) position embeddings by sin-cos embedding
621
+ pos_embed = get_3d_sincos_pos_embed(
622
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
623
+ )
624
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
625
+
626
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
627
+ w = self.patch_embed.proj.weight.data
628
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
629
+
630
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
631
+ torch.nn.init.normal_(self.cls_token, std=0.02)
632
+ self.apply(_init_weights)
633
+
634
+ def random_masking(
635
+ self,
636
+ sequence: torch.Tensor,
637
+ mask_ratio: float,
638
+ noise: None | torch.Tensor = None,
639
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
640
+ """Perform per-sample random masking by per-sample shuffling.
641
+
642
+ Per-sample shuffling is done by argsort random
643
+ noise.
644
+
645
+ Args:
646
+ sequence: (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
647
+ mask_ratio: (float): mask ratio to use.
648
+ noise: (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
649
+ mainly used for testing purposes to control randomness and maintain the reproducibility
650
+ """
651
+ batch_size, seq_length, dim = sequence.shape
652
+ len_keep = int(seq_length * (1 - mask_ratio))
653
+
654
+ if noise is None:
655
+ noise = torch.rand(
656
+ batch_size, seq_length, device=sequence.device
657
+ ) # noise in [0, 1]
658
+
659
+ # sort noise for each sample
660
+ ids_shuffle = torch.argsort(noise, dim=1).to(
661
+ sequence.device
662
+ ) # ascend: small is keep, large is remove
663
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
664
+
665
+ # keep the first subset
666
+ ids_keep = ids_shuffle[:, :len_keep]
667
+ sequence_unmasked = torch.gather(
668
+ sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)
669
+ )
670
+
671
+ # generate the binary mask: 0 is keep, 1 is remove
672
+ mask = torch.ones([batch_size, seq_length], device=sequence.device)
673
+ mask[:, :len_keep] = 0
674
+ # unshuffle to get the binary mask
675
+ mask = torch.gather(mask, dim=1, index=ids_restore)
676
+
677
+ return sequence_unmasked, mask, ids_restore
678
+
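A standalone re-implementation of the argsort-of-noise trick used above (not a call into the class method), showing how the keep/restore indices and the binary mask relate; the fixed noise values are only for illustration.

```python
seq = torch.arange(8, dtype=torch.float32).reshape(1, 8, 1)            # (B, L, D)
noise = torch.tensor([[0.7, 0.1, 0.9, 0.3, 0.5, 0.2, 0.8, 0.4]])
ids_shuffle = torch.argsort(noise, dim=1)    # ascending: lowest noise is kept first
ids_restore = torch.argsort(ids_shuffle, dim=1)
len_keep = int(8 * (1 - 0.75))               # mask_ratio=0.75 keeps 2 of 8 tokens
kept = torch.gather(seq, 1, ids_shuffle[:, :len_keep].unsqueeze(-1))   # tokens 1 and 5
mask = torch.ones(1, 8)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)    # 0 = kept, 1 = masked, in original order
```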
679
+ def interpolate_pos_encoding(
680
+ self, sample_shape: tuple[int, int, int] | list[int]
681
+ ) -> torch.Tensor:
682
+ """interpolate_pos_encoding."""
683
+ pos_embed = _interpolate_pos_encoding(
684
+ pos_embed=self.pos_embed,
685
+ grid_size=self.patch_embed.grid_size,
686
+ patch_size=self.patch_embed.patch_size,
687
+ shape=sample_shape,
688
+ embed_dim=self.embed_dim,
689
+ )
690
+ return pos_embed
691
+
692
+ def forward(
693
+ self,
694
+ x: torch.Tensor,
695
+ temporal_coords: None | torch.Tensor = None,
696
+ location_coords: None | torch.Tensor = None,
697
+ mask_ratio: float = 0.75,
698
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
699
+ """Forward."""
700
+ if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
701
+ # add time dim
702
+ x = x.unsqueeze(2)
703
+ sample_shape = x.shape[-3:]
704
+
705
+ # embed patches
706
+ x = self.patch_embed(x)
707
+
708
+ pos_embed = self.interpolate_pos_encoding(sample_shape)
709
+ # add pos embed w/o cls token
710
+ x = x + pos_embed[:, 1:, :]
711
+
712
+ if self.temporal_encoding and temporal_coords is not None:
713
+ num_tokens_per_frame = x.shape[1] // self.num_frames
714
+ temporal_encoding = self.temporal_embed_enc(
715
+ temporal_coords, num_tokens_per_frame
716
+ )
717
+ x = x + temporal_encoding
718
+ if self.location_encoding and location_coords is not None:
719
+ location_encoding = self.location_embed_enc(location_coords)
720
+ x = x + location_encoding
721
+
722
+ # masking: length -> length * mask_ratio
723
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
724
+
725
+ # append cls token
726
+ cls_token = self.cls_token + pos_embed[:, :1, :]
727
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
728
+ x = torch.cat((cls_tokens, x), dim=1)
729
+
730
+ # apply Transformer blocks
731
+ for block in self.blocks:
732
+ x = block(x)
733
+ x = self.norm(x)
734
+
735
+ return x, mask, ids_restore
736
+
737
+ def forward_features(
738
+ self,
739
+ x: torch.Tensor,
740
+ temporal_coords: None | torch.Tensor = None,
741
+ location_coords: None | torch.Tensor = None,
742
+ ) -> list[torch.Tensor]:
743
+ """forward_features."""
744
+ if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
745
+ # add time dim
746
+ x = x.unsqueeze(2)
747
+ sample_shape = x.shape[-3:]
748
+
749
+ # embed patches
750
+ x = self.patch_embed(x)
751
+
752
+ pos_embed = self.interpolate_pos_encoding(sample_shape)
753
+ # add pos embed w/o cls token
754
+ x = x + pos_embed[:, 1:, :]
755
+
756
+ if self.temporal_encoding and temporal_coords is not None:
757
+ num_tokens_per_frame = x.shape[1] // self.num_frames
758
+ temporal_encoding = self.temporal_embed_enc(
759
+ temporal_coords, num_tokens_per_frame
760
+ )
761
+ x = x + temporal_encoding
762
+ if self.location_encoding and location_coords is not None:
763
+ location_encoding = self.location_embed_enc(location_coords)
764
+ x = x + location_encoding
765
+
766
+ # append cls token
767
+ cls_token = self.cls_token + pos_embed[:, :1, :]
768
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
769
+ x = torch.cat((cls_tokens, x), dim=1)
770
+
771
+ # apply Transformer blocks
772
+ out = []
773
+ for block in self.blocks:
774
+ x = block(x)
775
+ out.append(x.clone())
776
+
777
+ x = self.norm(x)
778
+ out[-1] = x
779
+ return out
780
+
781
+ def prepare_features_for_image_model(
782
+ self, features: list[torch.Tensor], t: int
783
+ ) -> list[torch.Tensor]:
784
+ """prepare_features_for_image_model."""
785
+ out = []
786
+ for x in features:
787
+ x_no_token = x[:, 1:, :]
788
+ number_of_tokens = x_no_token.shape[1]
789
+ tokens_per_timestep = number_of_tokens // t
790
+ h = int(np.sqrt(tokens_per_timestep))
791
+ encoded = rearrange(
792
+ x_no_token,
793
+ "batch (t h w) e -> batch t e h w",
794
+ e=self.embed_dim,
795
+ t=t,
796
+ h=h,
797
+ )
798
+ # mean along the time dimension
799
+ out.append(encoded.mean(dim=1))
800
+ return out
801
+
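A worked example of the token-to-feature-map reshape performed in prepare_features_for_image_model, written standalone with einops; with one frame and a 14x14 token grid, each block's output becomes a (B, E, 14, 14) map and the time mean is a no-op.

```python
tokens = torch.randn(2, 1 + 196, 1024)                  # cls token + T*h*w patch tokens
x_no_token = tokens[:, 1:, :]
t = 1
h = int(np.sqrt(x_no_token.shape[1] // t))              # 14
fmap = rearrange(x_no_token, "b (t h w) e -> b t e h w", t=t, h=h).mean(dim=1)
print(fmap.shape)                                       # torch.Size([2, 1024, 14, 14])
```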
802
+
803
+ class MAEDecoder(nn.Module):
804
+ """Transformer Decoder used in the Prithvi MAE."""
805
+
806
+ def __init__(
807
+ self,
808
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
809
+ grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
810
+ in_chans: int = 3,
811
+ encoder_embed_dim: int = 1024,
812
+ decoder_embed_dim: int = 512,
813
+ depth: int = 8,
814
+ num_heads: int = 16,
815
+ mlp_ratio: float = 4.0,
816
+ norm_layer: nn.Module = nn.LayerNorm,
817
+ coords_encoding: list[str] | None = None,
818
+ coords_scale_learn: bool = False,
819
+ ) -> None:
820
+ """Init."""
821
+ super().__init__()
822
+
823
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
824
+ self.decoder_embed_dim = decoder_embed_dim
825
+ self.grid_size = grid_size
826
+ if isinstance(patch_size, int):
827
+ patch_size = (1, patch_size, patch_size)
828
+ self.patch_size = patch_size
829
+ self.num_frames = self.grid_size[0] * patch_size[0]
830
+ num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
831
+
832
+ # Optional temporal and location embedding
833
+ coords_encoding = coords_encoding or []
834
+ self.temporal_encoding = "time" in coords_encoding
835
+ self.location_encoding = "location" in coords_encoding
836
+ if self.temporal_encoding:
837
+ self.temporal_embed_dec = TemporalEncoder(
838
+ decoder_embed_dim, coords_scale_learn
839
+ )
840
+ if self.location_encoding:
841
+ self.location_embed_dec = LocationEncoder(
842
+ decoder_embed_dim, coords_scale_learn
843
+ )
844
+
845
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
846
+
847
+ self.register_buffer(
848
+ "decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)
849
+ )
850
+
851
+ self.decoder_blocks = nn.ModuleList(
852
+ [
853
+ Block(
854
+ decoder_embed_dim,
855
+ num_heads,
856
+ mlp_ratio,
857
+ qkv_bias=True,
858
+ norm_layer=norm_layer,
859
+ )
860
+ for _ in range(depth)
861
+ ]
862
+ )
863
+
864
+ self.decoder_norm = norm_layer(decoder_embed_dim)
865
+ self.decoder_pred = nn.Linear(
866
+ decoder_embed_dim,
867
+ patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
868
+ bias=True,
869
+ )
870
+
871
+ self.initialize_weights()
872
+
873
+ def initialize_weights(self) -> None:
874
+ """initialize_weights."""
875
+ # initialize (and freeze) position embeddings by sin-cos embedding
876
+ decoder_pos_embed = get_3d_sincos_pos_embed(
877
+ self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
878
+ )
879
+ self.decoder_pos_embed.data.copy_(
880
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
881
+ )
882
+
883
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
884
+ torch.nn.init.normal_(self.mask_token, std=0.02)
885
+ self.apply(_init_weights)
886
+
887
+ def interpolate_pos_encoding(
888
+ self, sample_shape: tuple[int, int, int]
889
+ ) -> torch.Tensor:
890
+ """interpolate_pos_encoding."""
891
+ pos_embed = _interpolate_pos_encoding(
892
+ pos_embed=self.decoder_pos_embed,
893
+ grid_size=self.grid_size,
894
+ patch_size=self.patch_size,
895
+ shape=sample_shape,
896
+ embed_dim=self.decoder_embed_dim,
897
+ )
898
+
899
+ return pos_embed
900
+
901
+ def forward(
902
+ self,
903
+ hidden_states: torch.Tensor,
904
+ ids_restore: torch.Tensor,
905
+ temporal_coords: None | torch.Tensor = None,
906
+ location_coords: None | torch.Tensor = None,
907
+ input_size: list[int] | None = None,
908
+ ) -> torch.Tensor:
909
+ """Forward."""
910
+ # embed tokens
911
+ x = self.decoder_embed(hidden_states)
912
+ cls_token = x[:, :1, :]
913
+
914
+ # append mask tokens to sequence
915
+ mask_tokens = self.mask_token.repeat(
916
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
917
+ )
918
+ x = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
919
+ # unshuffle
920
+ x = torch.gather(
921
+ x,
922
+ dim=1,
923
+ index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device),
924
+ )
925
+
926
+ # add pos embed
927
+ decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:]) # type: ignore
928
+ cls_token = cls_token + decoder_pos_embed[:, :1, :]
929
+ x = x + decoder_pos_embed[:, 1:, :]
930
+
931
+ if self.temporal_encoding and temporal_coords is not None:
932
+ num_tokens_per_frame = x.shape[1] // self.num_frames
933
+ temporal_encoding = self.temporal_embed_dec(
934
+ temporal_coords, num_tokens_per_frame
935
+ )
936
+ # Add temporal encoding w/o cls token
937
+ x = x + temporal_encoding
938
+ if self.location_encoding and location_coords is not None:
939
+ location_encoding = self.location_embed_dec(location_coords)
940
+ # Add location encoding w/o cls token
941
+ x = x + location_encoding
942
+
943
+ # append cls token
944
+ x = torch.cat([cls_token, x], dim=1)
945
+
946
+ # apply Transformer layers (blocks)
947
+ for block in self.decoder_blocks:
948
+ x = block(x)
949
+ x = self.decoder_norm(x)
950
+
951
+ # predictor projection
952
+ pred = self.decoder_pred(x)
953
+
954
+ # remove cls token
955
+ pred = pred[:, 1:, :]
956
+
957
+ return pred
958
+
959
+
960
+ class PrithviMAE(nn.Module):
961
+ """Prithvi Masked Autoencoder."""
962
+
963
+ def __init__(
964
+ self,
965
+ img_size: int | tuple[int, int] = 224,
966
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
967
+ num_frames: int = 4,
968
+ in_chans: int = 6,
969
+ embed_dim: int = 768,
970
+ depth: int = 12,
971
+ num_heads: int = 12,
972
+ decoder_embed_dim: int = 512,
973
+ decoder_depth: int = 8,
974
+ decoder_num_heads: int = 16,
975
+ mlp_ratio: float = 4.0,
976
+ norm_layer: nn.Module = nn.LayerNorm,
977
+ norm_pix_loss: bool = False,
978
+ coords_encoding: list[str] | None = None,
979
+ coords_scale_learn: bool = False,
980
+ drop_path: float = 0.0,
981
+ mask_ratio: float = 0.75,
982
+ **kwargs: Any,
983
+ ):
984
+ """Init."""
985
+ super().__init__()
986
+
987
+ self.encoder = PrithviViT(
988
+ img_size=img_size,
989
+ num_frames=num_frames,
990
+ patch_size=patch_size,
991
+ in_chans=in_chans,
992
+ embed_dim=embed_dim,
993
+ depth=depth,
994
+ num_heads=num_heads,
995
+ mlp_ratio=mlp_ratio,
996
+ norm_layer=norm_layer,
997
+ coords_encoding=coords_encoding,
998
+ coords_scale_learn=coords_scale_learn,
999
+ drop_path=drop_path,
1000
+ )
1001
+
1002
+ self.decoder = MAEDecoder(
1003
+ patch_size=patch_size,
1004
+ grid_size=self.encoder.patch_embed.grid_size,
1005
+ in_chans=in_chans,
1006
+ encoder_embed_dim=embed_dim,
1007
+ decoder_embed_dim=decoder_embed_dim,
1008
+ depth=decoder_depth,
1009
+ num_heads=decoder_num_heads,
1010
+ mlp_ratio=mlp_ratio,
1011
+ norm_layer=norm_layer,
1012
+ coords_encoding=coords_encoding,
1013
+ coords_scale_learn=coords_scale_learn,
1014
+ )
1015
+
1016
+ self.mask_ratio = mask_ratio
1017
+ self.norm_pix_loss = norm_pix_loss
1018
+ self.out_channels = self.encoder.out_channels
1019
+
1020
+ def patchify(self, pixel_values: torch.Tensor) -> torch.Tensor:
1021
+ """Patchify.
1022
+
1023
+ Args:
1024
+ pixel_values: (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
1025
+ Pixel values.
1026
+
1027
+ Returns:
1028
+ torch.FloatTensor of shape
1029
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
1030
+ Patchified pixel values.
1031
+ """
1032
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
1033
+ num_channels = self.encoder.in_chans
1034
+
1035
+ # patchify
1036
+ patchified_pixel_values = rearrange(
1037
+ pixel_values,
1038
+ "b c (t s) (h p) (w q) -> b (t h w) (s p q c)",
1039
+ c=num_channels,
1040
+ s=patch_size_t,
1041
+ p=patch_size_h,
1042
+ q=patch_size_w,
1043
+ )
1044
+
1045
+ return patchified_pixel_values
1046
+
1047
+ def unpatchify(
1048
+ self,
1049
+ patchified_pixel_values: torch.Tensor,
1050
+ image_size: tuple[int, int] | None = None,
1051
+ ) -> torch.Tensor:
1052
+ """Unpatchify.
1053
+
1054
+ Args:
1055
+ patchified_pixel_values: (`torch.FloatTensor` of shape
1056
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
1057
+ Patchified pixel values.
1058
+ image_size: (`tuple[int, int]`, *optional*):
1059
+ Original image size.
1060
+
1061
+ Returns:
1062
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
1063
+ Pixel values.
1064
+ """
1065
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
1066
+ image_size = (
1067
+ to_2tuple(image_size) if image_size is not None else self.encoder.img_size
1068
+ )
1069
+ original_height, original_width = image_size
1070
+ num_patches_h = original_height // patch_size_h
1071
+ num_patches_w = original_width // patch_size_w
1072
+ num_channels = self.encoder.in_chans
1073
+
1074
+ pixel_values = rearrange(
1075
+ patchified_pixel_values,
1076
+ "b (t h w) (s p q c) -> b c (t s) (h p) (w q)",
1077
+ c=num_channels,
1078
+ h=num_patches_h,
1079
+ w=num_patches_w,
1080
+ s=patch_size_t,
1081
+ p=patch_size_h,
1082
+ q=patch_size_w,
1083
+ )
1084
+ return pixel_values
1085
+
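The patchify/unpatchify pair is just an einops rearrange and its mirror; a standalone round-trip sketch with 6 channels, one frame and 16x16 patches.

```python
x = torch.randn(2, 6, 1, 224, 224)                       # B, C, T, H, W
patches = rearrange(x, "b c (t s) (h p) (w q) -> b (t h w) (s p q c)", s=1, p=16, q=16)
print(patches.shape)                                     # torch.Size([2, 196, 1536])
x_back = rearrange(
    patches, "b (t h w) (s p q c) -> b c (t s) (h p) (w q)", c=6, h=14, w=14, s=1, p=16, q=16
)
print(torch.equal(x, x_back))                            # True
```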
1086
+ def forward_loss(
1087
+ self, pixel_values: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor
1088
+ ) -> torch.Tensor:
1089
+ """forward_loss.
1090
+
1091
+ Args:
1092
+ pixel_values: (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
1093
+ Pixel values.
1094
+ pred: (`torch.FloatTensor` of shape
1095
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
1096
+ Predicted pixel values.
1097
+ mask: (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1098
+ Tensor indicating which patches are masked (1) and which are not (0).
1099
+
1100
+ Returns:
1101
+ `torch.FloatTensor`: Pixel reconstruction loss.
1102
+ """
1103
+ target = self.patchify(pixel_values)
1104
+ if self.norm_pix_loss:
1105
+ mean = target.mean(dim=-1, keepdim=True)
1106
+ var = target.var(dim=-1, keepdim=True)
1107
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
1108
+
1109
+ loss = (pred - target) ** 2
1110
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
1111
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
1112
+ return loss
1113
+
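A toy computation showing that the reconstruction loss only averages over masked patches (mask == 1), matching the last two lines above; the tensors are made up.

```python
pred = torch.zeros(1, 4, 8)
target = torch.ones(1, 4, 8)
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])              # 2 of 4 patches were masked
per_patch = ((pred - target) ** 2).mean(dim=-1)          # (1, 4)
loss = (per_patch * mask).sum() / mask.sum()
print(loss)                                              # tensor(1.) -- unmasked patches ignored
```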
1114
+ def forward(
1115
+ self,
1116
+ pixel_values: torch.Tensor,
1117
+ temporal_coords: None | torch.Tensor = None,
1118
+ location_coords: None | torch.Tensor = None,
1119
+ mask_ratio: float | None = None,
1120
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1121
+ """Forward."""
1122
+ if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
1123
+ # add time dim
1124
+ pixel_values = pixel_values.unsqueeze(2)
1125
+
1126
+ mask_ratio = mask_ratio or self.mask_ratio
1127
+ latent, mask, ids_restore = self.encoder(
1128
+ pixel_values, temporal_coords, location_coords, mask_ratio
1129
+ )
1130
+ pred = self.decoder(
1131
+ latent,
1132
+ ids_restore,
1133
+ temporal_coords,
1134
+ location_coords,
1135
+ input_size=pixel_values.shape,
1136
+ )
1137
+ loss = self.forward_loss(pixel_values, pred, mask)
1138
+ return loss, pred, mask
1139
+
1140
+ def forward_features(
1141
+ self,
1142
+ x: torch.Tensor,
1143
+ temporal_coords: None | torch.Tensor = None,
1144
+ location_coords: None | torch.Tensor = None,
1145
+ ) -> list[torch.Tensor]:
1146
+ """forward_features."""
1147
+ return self.encoder.forward_features(x, temporal_coords, location_coords)