olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. olmoearth_pretrain_minimal/__init__.py +16 -0
  2. olmoearth_pretrain_minimal/model_loader.py +123 -0
  3. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
  4. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
  5. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
  6. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
  7. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
  8. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
  9. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
  10. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
  11. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
  12. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
  13. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
  14. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
  15. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
  16. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
  17. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
  18. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
  19. olmoearth_pretrain_minimal/test.py +51 -0
  20. olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
  21. olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
  22. olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
  23. olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
  24. olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2219 @@
1
+ """Model code for the OlmoEarth Pretrain model."""
2
+
3
+ import logging
4
+ import math
5
+ from dataclasses import dataclass
6
+ from enum import StrEnum
7
+ from typing import Any, NamedTuple
8
+
9
+ import torch
10
+ from einops import rearrange, reduce, repeat
11
+ from torch import Tensor, nn
12
+ from torch.distributed.fsdp import fully_shard
13
+
14
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
15
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import (
16
+ BASE_GSD,
17
+ Modality,
18
+ ModalitySpec,
19
+ get_modality_specs_from_names,
20
+ )
21
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import MaskedOlmoEarthSample, MaskValue
22
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.attention import Block
23
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.encodings import (
24
+ get_1d_sincos_pos_encoding,
25
+ get_2d_sincos_pos_encoding_with_resolution,
26
+ get_month_encoding_table,
27
+ )
28
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_patch_embed import (
29
+ FlexiPatchEmbed,
30
+ FlexiPatchReconstruction,
31
+ )
32
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.tokenization import TokenizationConfig
33
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.utils import get_cumulative_sequence_lengths
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ def get_modalities_to_process(
39
+ available_modalities: list[str], supported_modality_names: list[str]
40
+ ) -> list[str]:
41
+ """Get the modalities to process."""
42
+ modalities_to_process = set(supported_modality_names).intersection(
43
+ set(available_modalities)
44
+ )
45
+ return list(modalities_to_process)
46
+
47
+
48
+ def return_modalities_from_dict(
49
+ per_modality_input_tokens: dict[str, Tensor],
50
+ ) -> list[str]:
51
+ """Return the modalities from a dictionary of per modality input tokens."""
52
+ return [
53
+ key for key in per_modality_input_tokens.keys() if not key.endswith("_mask")
54
+ ]
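# Illustrative sketch (editorial, not part of the package): how the two helpers above
# cooperate. Token dicts carry "<modality>" and "<modality>_mask" keys, and only the
# modalities a model instance supports get processed. The values below are placeholders.
per_modality_tokens = {"sentinel2_l2a": ..., "sentinel2_l2a_mask": ..., "srtm": ..., "srtm_mask": ...}
available = return_modalities_from_dict(per_modality_tokens)        # ["sentinel2_l2a", "srtm"]
to_process = get_modalities_to_process(available, ["sentinel2_l2a", "worldcover"])
# -> ["sentinel2_l2a"]; note the order is unspecified because a set intersection is used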
55
+
56
+
57
+ class PoolingType(StrEnum):
58
+ """Strategy for pooling the tokens."""
59
+
60
+ MAX = "max"
61
+ MEAN = "mean"
62
+
63
+
64
+ class TokensAndMasks(NamedTuple):
65
+ """Output to compute the loss on.
66
+
67
+ Args:
68
+ sentinel2_l2a: sentinel 2 L2A data of shape (B, P_H, P_W, T, Band_Sets, D)
69
+ sentinel2_l2a_mask: sentinel 2 L2A mask indicating which tokens are masked/unmasked (B, P_H, P_W, T, Band_Sets)
70
+ sentinel1: sentinel 1 data of shape (B, P_H, P_W, T, Band_Sets, D)
71
+ sentinel1_mask: sentinel 1 mask indicating which tokens are masked/unmasked (B, P_H, P_W, T, Band_Sets)
72
+ worldcover: worldcover data of shape (B, P_H, P_W, T, Band_Sets, D)
73
+ worldcover_mask: worldcover mask indicating which tokens are masked/unmasked (B, P_H, P_W, T, Band_Sets)
74
+ latlon: lat lon data containing geographical coordinates
75
+ latlon_mask: lat lon mask indicating which coordinates are masked/unmasked
76
+ openstreetmap_raster: openstreetmap raster data of shape (B, P_H, P_W, T, Band_Sets, D)
77
+ openstreetmap_raster_mask: openstreetmap raster mask indicating which tokens are masked/unmasked (B, P_H, P_W, T, Band_Sets)
78
+ """
79
+
80
+ sentinel2_l2a: Tensor | None = None
81
+ sentinel2_l2a_mask: Tensor | None = None
82
+ sentinel1: Tensor | None = None
83
+ sentinel1_mask: Tensor | None = None
84
+ worldcover: Tensor | None = None
85
+ worldcover_mask: Tensor | None = None
86
+ latlon: Tensor | None = None
87
+ latlon_mask: Tensor | None = None
88
+ openstreetmap_raster: Tensor | None = None
89
+ openstreetmap_raster_mask: Tensor | None = None
90
+ srtm: Tensor | None = None
91
+ srtm_mask: Tensor | None = None
92
+ landsat: Tensor | None = None
93
+ landsat_mask: Tensor | None = None
94
+ naip: Tensor | None = None
95
+ naip_mask: Tensor | None = None
96
+ naip_10: Tensor | None = None
97
+ naip_10_mask: Tensor | None = None
98
+ gse: Tensor | None = None
99
+ gse_mask: Tensor | None = None
100
+ cdl: Tensor | None = None
101
+ cdl_mask: Tensor | None = None
102
+ worldpop: Tensor | None = None
103
+ worldpop_mask: Tensor | None = None
104
+ worldcereal: Tensor | None = None
105
+ worldcereal_mask: Tensor | None = None
106
+ wri_canopy_height_map: Tensor | None = None
107
+ wri_canopy_height_map_mask: Tensor | None = None
108
+ era5_10: Tensor | None = None
109
+ era5_10_mask: Tensor | None = None
110
+
111
+ @property
112
+ def device(self) -> torch.device:
113
+ """Get the device of the tokens and masks."""
114
+ if self.sentinel2_l2a is not None:
115
+ return self.sentinel2_l2a.device
116
+ else:
117
+ # look for any other modality that is not None
118
+ for modality in self._fields:
119
+ if getattr(self, modality) is not None:
120
+ return getattr(self, modality).device
121
+ raise ValueError("No data to get device from")
122
+
123
+ # TODO: Many of our named tuples want this functionality, so we should probably create a utility base class for them and double subclass.
124
+ @classmethod
125
+ def get_masked_modality_name(cls, modality: str) -> str:
126
+ """Get the masked modality name."""
127
+ return f"{modality}_mask"
128
+
129
+ def as_dict(self, return_none: bool = True) -> dict[str, Any]:
130
+ """Convert the namedtuple to a dictionary.
131
+
132
+ Returns:
133
+ Dictionary representation of the namedtuple.
134
+ """
135
+ return_dict = {}
136
+ for field in self._fields:
137
+ val = getattr(self, field)
138
+ if return_none:
139
+ return_dict[field] = val
140
+ else:
141
+ if val is not None:
142
+ return_dict[field] = val
143
+ return return_dict
144
+
145
+ @property
146
+ def modalities(self) -> list[str]:
147
+ """Return all data fields."""
148
+ return [
149
+ x
150
+ for x in self._fields
151
+ if not x.endswith("mask") and getattr(self, x) is not None
152
+ ]
153
+
154
+ def get_shape_dict(self) -> dict[str, tuple]:
155
+ """Return a dictionary of the shapes of the fields."""
156
+ return {x: getattr(self, x).shape for x in self._fields}
157
+
158
+ @staticmethod
159
+ def _flatten(x: Tensor) -> Tensor:
160
+ return rearrange(x, "b ... d -> b (...) d")
161
+
162
+ def flatten_tokens_and_masks(
163
+ self, return_lists: bool = False
164
+ ) -> tuple[Tensor, Tensor]:
165
+ """Return the flattened tokens and masks.
166
+
167
+ Args:
168
+ return_lists: If True, return the original lists before concatenation.
169
+ If False, return concatenated tensors.
170
+
171
+ Tokens will have shape [B, T, D] and masks will have shape [B, T]
172
+ """
173
+ flattened_x, flattened_masks = [], []
174
+ for attr_name in self.modalities:
175
+ mask_attr_name = self.get_masked_modality_name(attr_name)
176
+ attr = getattr(self, attr_name)
177
+ masked_attr = getattr(self, mask_attr_name)
178
+ if attr is not None:
179
+ if masked_attr is None:
180
+ raise ValueError(
181
+ f"Can't have present {attr_name} but None {mask_attr_name}"
182
+ )
183
+ masked_attr = masked_attr.unsqueeze(dim=-1)
184
+ flattened_x.append(self._flatten(attr))
185
+ flattened_masks.append(self._flatten(masked_attr))
186
+
187
+ if return_lists:
188
+ # Remove the extra dimension from the masks
189
+ flattened_masks = [mask[:, :, 0] for mask in flattened_masks]
190
+ return flattened_x, flattened_masks
191
+
192
+ x = torch.cat(flattened_x, dim=1)
193
+ masks = torch.cat(flattened_masks, dim=1)[:, :, 0]
194
+ return x, masks
195
+
196
+ def pool_spatially_and_concat_modalities(self) -> Tensor:
197
+ """Pool the modalities across time to get spatial features and concatenate the features."""
198
+ spatial_stacked_features = []
199
+ for attr_name in self.modalities:
200
+ if Modality.get(attr_name).is_spatial:
201
+ mask_attr_name = self.get_masked_modality_name(attr_name)
202
+ masked_attr = getattr(self, mask_attr_name)
203
+ if masked_attr is None:
204
+ continue
205
+ if (masked_attr == MaskValue.ONLINE_ENCODER.value).all():
206
+ attr = getattr(self, attr_name)
207
+ # only mean in temporal dimension
208
+ pooled_attr = torch.mean(attr, dim=(-3))
209
+ spatial_stacked_features.append(pooled_attr)
210
+ if len(spatial_stacked_features) == 0:
211
+ raise ValueError("Missing unmasked spatial modalities for spatial pooling.")
212
+ # Concatenate along the band sets dimension instead of stacking
213
+ spatial_stacked_features = torch.cat(spatial_stacked_features, dim=-2)
214
+ return spatial_stacked_features
215
+
216
+ def pool_spatially(self, pooling_type: PoolingType) -> Tensor:
217
+ """Pool the modalities across time to get spatial features."""
218
+ spatial_average = []
219
+ for attr_name in self.modalities:
220
+ if Modality.get(attr_name).is_spatial:
221
+ mask_attr_name = self.get_masked_modality_name(attr_name)
222
+ masked_attr = getattr(self, mask_attr_name)
223
+ if masked_attr is None:
224
+ continue
225
+ if (masked_attr == MaskValue.ONLINE_ENCODER.value).all():
226
+ attr = getattr(self, attr_name)
227
+ # pool across time and bandset dimensions
228
+ if pooling_type == PoolingType.MEAN:
229
+ spatial_average.append(torch.mean(attr, dim=(-2, -3)))
230
+ else:
231
+ spatial_average.append(
232
+ torch.max(torch.max(attr, dim=-2).values, dim=-2).values
233
+ )
234
+ if len(spatial_average) == 0:
235
+ raise ValueError("Missing unmasked spatial modalities for spatial pooling.")
236
+ spatial_average_t = torch.stack(spatial_average, dim=-1)
237
+ if pooling_type == PoolingType.MEAN:
238
+ return spatial_average_t.mean(dim=-1)
239
+ else:
240
+ return spatial_average_t.max(dim=-1).values
241
+
242
+ def pool_instance_wise(self, pooling_type: PoolingType) -> Tensor:
243
+ """Pool all the tokens in the instance."""
244
+ x, mask = self.flatten_tokens_and_masks()
245
+ # 1s for online encoder, 0s elsewhere
246
+ mask = (mask == MaskValue.ONLINE_ENCODER.value).long()
247
+ x_for_pooling = x * mask.unsqueeze(-1)
248
+ if pooling_type == PoolingType.MAX:
249
+ x_for_pooling = x_for_pooling.masked_fill(
250
+ ~mask.bool().unsqueeze(-1), -float("inf")
251
+ )
252
+ return x_for_pooling.max(dim=1).values
253
+ elif pooling_type == PoolingType.MEAN:
254
+ num_encoded_tokens = torch.sum(mask, -1, keepdim=True)
255
+ logger.debug(f"num_encoded_tokens: {num_encoded_tokens}")
256
+ if (num_encoded_tokens == 0).any():
257
+ raise ValueError(
258
+ f"num_encoded_tokens is 0 for some samples {num_encoded_tokens}"
259
+ )
260
+ return x_for_pooling.sum(dim=1) / num_encoded_tokens
261
+ else:
262
+ raise ValueError(f"Invalid pooling type: {pooling_type}")
263
+
264
+ def pool_unmasked_tokens(
265
+ self,
266
+ pooling_type: PoolingType = PoolingType.MAX,
267
+ spatial_pooling: bool = False,
268
+ concat_features: bool = False,
269
+ ) -> Tensor:
270
+ """Pool the unmasked tokens.
271
+
272
+ Args:
273
+ pooling_type: Pooling type for the tokens
274
+ spatial_pooling: Whether to keep the spatial dimensions when pooling. If true,
275
+ this expects the masks within a spatial modality to be consistent (e.g. all
276
+ s2 tokens would have the same mask.)
277
+ concat_features: Whether to concatenate the features instead of averaging them, only enabled for spatial pooling as of now,
278
+ requires no masked out tokens
279
+ """
280
+ if concat_features and spatial_pooling:
281
+ return self.pool_spatially_and_concat_modalities()
282
+ if concat_features:
283
+ raise ValueError("concat_features is not supported for non-spatial pooling")
284
+ if not spatial_pooling:
285
+ return self.pool_instance_wise(pooling_type)
286
+ else:
287
+ return self.pool_spatially(pooling_type)
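# Illustrative sketch (editorial, not part of the package): the masked mean pooling that
# pool_instance_wise performs, written out with plain torch. The integer value of
# MaskValue.ONLINE_ENCODER is an assumption here (taken as 0 purely for illustration),
# and clamp(min=1) replaces the ValueError the real method raises for empty rows.
import torch

def masked_mean_pool(x: torch.Tensor, mask: torch.Tensor, online_value: int = 0) -> torch.Tensor:
    """x: [B, N, D] tokens, mask: [B, N] mask values; returns [B, D]."""
    keep = (mask == online_value).long()            # 1 where the online encoder saw the token
    summed = (x * keep.unsqueeze(-1)).sum(dim=1)    # zero out everything else, then sum
    counts = keep.sum(dim=-1, keepdim=True).clamp(min=1)
    return summed / counts

pooled = masked_mean_pool(torch.randn(2, 5, 8), torch.tensor([[0, 0, 1, 1, 2]] * 2))  # [2, 8]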
288
+
289
+
290
+ class ProjectAndAggregate(nn.Module):
291
+ """Module that applies a linear projection to tokens and masks."""
292
+
293
+ def __init__(
294
+ self,
295
+ embedding_size: int,
296
+ num_layers: int,
297
+ aggregate_then_project: bool = True,
298
+ ):
299
+ """Initialize the linear module.
300
+
301
+ embedding_size: The embedding size of the input TokensAndMasks
302
+ num_layers: The number of layers to use in the projection. If >1, then
303
+ a ReLU activation will be applied between layers
304
+ aggregate_then_project: If True, then we will average the tokens before applying
305
+ the projection. If False, we will apply the projection first.
306
+ """
307
+ super().__init__()
308
+ projections = [nn.Linear(embedding_size, embedding_size)]
309
+ for _ in range(1, num_layers):
310
+ projections.append(nn.ReLU())
311
+ projections.append(nn.Linear(embedding_size, embedding_size))
312
+ self.projection = nn.Sequential(*projections)
313
+ self.aggregate_then_project = aggregate_then_project
314
+
315
+ def apply_aggregate_then_project(
316
+ self, x: TokensAndMasks | torch.Tensor
317
+ ) -> torch.Tensor:
318
+ """Apply the aggregate operation to the input."""
319
+ if isinstance(x, TokensAndMasks):
320
+ pooled_for_contrastive = x.pool_unmasked_tokens(
321
+ PoolingType.MEAN, spatial_pooling=False
322
+ )
323
+ elif isinstance(x, torch.Tensor):
324
+ pooled_for_contrastive = reduce(x, "b ... d -> b d", "mean")
325
+ else:
326
+ raise ValueError(f"Invalid input type: {type(x)}")
327
+ return self.projection(pooled_for_contrastive)
328
+
329
+ def apply_project_then_aggregate(
330
+ self, x: TokensAndMasks | torch.Tensor
331
+ ) -> torch.Tensor:
332
+ """Apply the project operation to the input then aggregate."""
333
+ if isinstance(x, TokensAndMasks):
334
+ decoder_embedded_dict = x._asdict()
335
+ for modality in x.modalities:
336
+ x_modality = getattr(x, modality)
337
+ # Are these normalizations masked correctly?
338
+ x_modality = self.projection(x_modality)
339
+ masked_modality_name = x.get_masked_modality_name(modality)
340
+ decoder_embedded_dict[modality] = x_modality
341
+ decoder_embedded_dict[masked_modality_name] = getattr(
342
+ x, masked_modality_name
343
+ )
344
+ x_projected = TokensAndMasks(**decoder_embedded_dict)
345
+ projected_pooled = x_projected.pool_unmasked_tokens(
346
+ PoolingType.MEAN, spatial_pooling=False
347
+ )
348
+ elif isinstance(x, torch.Tensor):
349
+ x_projected = self.projection(x)
350
+ projected_pooled = reduce(x_projected, "b ... d -> b d", "mean")
351
+ else:
352
+ raise ValueError(f"Invalid input type: {type(x)}")
353
+ return projected_pooled
354
+
355
+ def forward(self, x: TokensAndMasks | torch.Tensor) -> torch.Tensor:
356
+ """Apply a (non)linear projection to an input TokensAndMasks.
357
+
358
+ This can be applied either before or after pooling the tokens.
359
+ """
360
+ return (
361
+ self.apply_aggregate_then_project(x)
362
+ if self.aggregate_then_project
363
+ else self.apply_project_then_aggregate(x)
364
+ )
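# Illustrative sketch (editorial, not part of the package): the two orderings that
# ProjectAndAggregate switches between, shown on a plain [B, N, D] tensor. With a purely
# linear projection the two paths agree; once a ReLU is inserted (num_layers > 1) they
# generally differ, which is why the ordering is configurable.
import torch
from torch import nn
from einops import reduce

proj = nn.Linear(8, 8)
x = torch.randn(2, 5, 8)                                   # [B, N, D]
agg_then_proj = proj(reduce(x, "b n d -> b d", "mean"))    # pool first, then project
proj_then_agg = reduce(proj(x), "b n d -> b d", "mean")    # project every token, then pool
same = torch.allclose(agg_then_proj, proj_then_agg, atol=1e-5)  # True: a linear map commutes with the mean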
365
+
366
+
367
+ class MultiModalPatchEmbeddings(nn.Module):
368
+ """Module that patchifies and encodes the input data for multiple modalities."""
369
+
370
+ def __init__(
371
+ self,
372
+ supported_modality_names: list[str],
373
+ max_patch_size: int,
374
+ embedding_size: int,
375
+ tokenization_config: TokenizationConfig | None = None,
376
+ ):
377
+ """Initialize the patch embeddings.
378
+
379
+ Args:
380
+ supported_modality_names: Which modalities from Modality this model
381
+ instantiation supports
382
+ max_patch_size: Maximum size of patches
383
+ embedding_size: Size of embeddings
384
+ tokenization_config: Optional config for custom band groupings
385
+ """
386
+ super().__init__()
387
+ self.max_patch_size = max_patch_size
388
+ self.embedding_size = embedding_size
389
+ self.supported_modality_names = supported_modality_names
390
+ self.tokenization_config = tokenization_config or TokenizationConfig()
391
+ # TODO: want to be able to remove certain bands and modalities
392
+ self.per_modality_embeddings = nn.ModuleDict({})
393
+
394
+ for modality in self.supported_modality_names:
395
+ self.per_modality_embeddings[modality] = (
396
+ self._get_patch_embedding_module_for_modality(modality)
397
+ )
398
+
399
+ # For every patch embedding module we want to create a unique buffer
400
+ # for selecting the correct band indices from the data tensor
401
+ for modality in self.supported_modality_names:
402
+ for idx, bandset_indices in enumerate(
403
+ self.tokenization_config.get_bandset_indices(modality)
404
+ ):
405
+ buffer_name = self._get_buffer_name(modality, idx)
406
+ bandset_indices_tensor = torch.tensor(bandset_indices, dtype=torch.long)
407
+ self.register_buffer(
408
+ buffer_name, bandset_indices_tensor, persistent=False
409
+ )
410
+
411
+ # The buffers registered above serve as per-modality index tensors for torch.index_select
412
+
413
+ @staticmethod
414
+ def _get_buffer_name(modality: str, idx: int) -> str:
415
+ """Get the buffer name."""
416
+ return f"{modality}__{idx}_buffer"
417
+
418
+ @staticmethod
419
+ def _get_embedding_module_name(modality: str, idx: int) -> str:
420
+ """Get the embedding module name.
421
+
422
+ Module Dicts require string keys
423
+ """
424
+ return f"{modality}__{idx}"
425
+
426
+ def _get_patch_embedding_module_for_modality(self, modality: str) -> nn.Module:
427
+ """Get the patch embedding module for a modality."""
428
+ modality_spec = Modality.get(modality)
429
+ # Get bandset indices from tokenization config (may be overridden)
430
+ bandset_indices = self.tokenization_config.get_bandset_indices(modality)
431
+
432
+ # Based on the modality name we choose how to embed the data.
433
+ # TODO: the forward pass will likely also need to know which embedding strategy was chosen.
434
+ # Static modality
435
+ if not modality_spec.is_spatial:
436
+ # static in space
437
+ return nn.ModuleDict(
438
+ {
439
+ self._get_embedding_module_name(modality, idx): nn.Linear(
440
+ len(channel_set_idxs), self.embedding_size
441
+ )
442
+ for idx, channel_set_idxs in enumerate(bandset_indices)
443
+ }
444
+ )
445
+ else:
446
+ return nn.ModuleDict(
447
+ {
448
+ self._get_embedding_module_name(modality, idx): FlexiPatchEmbed(
449
+ in_chans=len(channel_set_idxs),
450
+ embedding_size=self.embedding_size,
451
+ patch_size_at_16=self.max_patch_size,
452
+ modality_spec=modality_spec,
453
+ )
454
+ for idx, channel_set_idxs in enumerate(bandset_indices)
455
+ }
456
+ )
457
+
458
+ def apply_embedding_to_modality(
459
+ self,
460
+ modality: str,
461
+ input_data: MaskedOlmoEarthSample,
462
+ patch_size: int,
463
+ fast_pass: bool = False,
464
+ ) -> tuple[Tensor, Tensor]:
465
+ """Apply embedding to a modality."""
466
+ logger.debug(f"applying embedding to modality:{modality}")
467
+ masked_modality_name = input_data.get_masked_modality_name(modality)
468
+ modality_mask = getattr(input_data, masked_modality_name)
469
+ modality_data = getattr(input_data, modality)
470
+
471
+ modality_spec = Modality.get(modality)
472
+ num_band_sets = self.tokenization_config.get_num_bandsets(modality)
473
+
474
+ modality_tokens, modality_masks = [], []
475
+ for idx in range(num_band_sets):
476
+ modality_specific_kwargs = {}
477
+ if not modality_spec.is_spatial:
478
+ # static in time
479
+ token_mask = modality_mask[..., idx]
480
+ else:
481
+ token_mask = modality_mask[
482
+ :,
483
+ 0 :: patch_size * modality_spec.image_tile_size_factor,
484
+ 0 :: patch_size * modality_spec.image_tile_size_factor,
485
+ ...,
486
+ idx,
487
+ ]
488
+ modality_specific_kwargs = {"patch_size": patch_size}
489
+ # In the fast pass we want to avoid the sync that comes with checking for online encoder
490
+ if fast_pass or (token_mask == MaskValue.ONLINE_ENCODER.value).any():
491
+ buffer_name = self._get_buffer_name(modality, idx)
492
+ patchified_data = torch.index_select(
493
+ modality_data, -1, getattr(self, buffer_name)
494
+ )
495
+ embedding_module = self.per_modality_embeddings[modality][
496
+ self._get_embedding_module_name(modality, idx)
497
+ ]
498
+ patchified_data = embedding_module(
499
+ patchified_data, **modality_specific_kwargs
500
+ )
501
+ else:
502
+ mask_shape = token_mask.shape + (self.embedding_size,)
503
+ patchified_data = torch.zeros(
504
+ mask_shape, dtype=modality_data.dtype, device=token_mask.device
505
+ )
506
+
507
+ modality_tokens.append(patchified_data)
508
+ modality_masks.append(token_mask)
509
+ return torch.stack(modality_tokens, dim=-2), torch.stack(modality_masks, dim=-1)
510
+
511
+ @staticmethod
512
+ def is_any_data_seen_by_encoder(modality_mask: Tensor) -> bool:
513
+ """Check if any data is seen by the encoder."""
514
+ return (MaskValue.ONLINE_ENCODER.value == modality_mask).any()
515
+
516
+ def apply_compile(self) -> None:
517
+ """Apply torch.compile to the model."""
518
+ self.compile(dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)
519
+
520
+ def forward(
521
+ self,
522
+ input_data: MaskedOlmoEarthSample,
523
+ patch_size: int,
524
+ fast_pass: bool = False,
525
+ ) -> dict[str, Tensor]:
526
+ """Return flexibly patchified embeddings for each modality of the input data.
527
+
528
+ Given [B, H, W, (T), C] inputs, returns [B, H, W, (T), b_s, D] outputs.
529
+
530
+ We assume that the spatial masks are consistent for the given patch size,
531
+ so that if patch_size == 2 then one possible mask would be
532
+ [0, 0, 1, 1]
533
+ [0, 0, 1, 1]
534
+ [1, 1, 0, 0]
535
+ [1, 1, 0, 0]
536
+ for the H, W dimensions
537
+ """
538
+ output_dict = {}
539
+ modalities_to_process = get_modalities_to_process(
540
+ input_data.modalities, self.supported_modality_names
541
+ )
542
+ for modality in modalities_to_process:
543
+ modality_tokens, modality_masks = self.apply_embedding_to_modality(
544
+ modality, input_data, patch_size, fast_pass
545
+ )
546
+ output_dict[modality] = modality_tokens
547
+ modality_mask_name = input_data.get_masked_modality_name(modality)
548
+ output_dict[modality_mask_name] = modality_masks
549
+ return output_dict
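# Illustrative sketch (editorial, not part of the package): the "spatially consistent
# mask" assumption from the forward docstring, and the strided slicing that
# apply_embedding_to_modality uses to read one mask value per patch. Toy values only;
# the batch dimension and image_tile_size_factor scaling are omitted for brevity.
import torch

patch_size = 2
pixel_mask = torch.tensor(
    [[0, 0, 1, 1],
     [0, 0, 1, 1],
     [1, 1, 0, 0],
     [1, 1, 0, 0]]
)
# every pixel inside a patch shares the same value, so sampling the top-left pixel
# of each patch recovers the patch-level mask
patch_mask = pixel_mask[0::patch_size, 0::patch_size]
# tensor([[0, 1],
#         [1, 0]])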
550
+
551
+
552
+ class Reconstructor(nn.Module):
553
+ """Module that patchifies and encodes the input data."""
554
+
555
+ def __init__(
556
+ self,
557
+ decoder: nn.Module,
558
+ supported_modalities: list[ModalitySpec],
559
+ max_patch_size: int,
560
+ tokenization_config: TokenizationConfig | None = None,
561
+ ):
562
+ """Initialize the patch embeddings.
563
+
564
+ Args:
565
+ decoder: Predictor nn module to use on before reconstructor on input
566
+ supported_modalities: Which modalities from Modality this model
567
+ instantiation supports
568
+ max_patch_size: Maximum size of patches
569
+ tokenization_config: Optional config for custom band groupings
570
+ """
571
+ super().__init__()
572
+ self.max_patch_size = max_patch_size
573
+ self.embedding_size = decoder.output_embedding_size
574
+ self.supported_modalities = supported_modalities
575
+ self.tokenization_config = tokenization_config or TokenizationConfig()
576
+ self.decoder = decoder
577
+ # TODO: want to be able to remove certain bands and modalities
578
+ self.per_modality_reconstructions = nn.ModuleDict({})
579
+ for modality in self.supported_modalities:
580
+ self.per_modality_reconstructions[modality.name] = (
581
+ self._get_patch_reconstruction_module_for_modality(modality)
582
+ )
583
+
584
+ def apply_compile(self) -> None:
585
+ """Apply torch.compile to the model."""
586
+ self.decoder.apply_compile()
587
+
588
+ def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
589
+ """Apply FSDP to the model."""
590
+ self.decoder.apply_fsdp(**fsdp_kwargs)
591
+
592
+ @staticmethod
593
+ def _get_reconstruction_module_name(modality: str, idx: int) -> str:
594
+ """Get the reconstruction module name.
595
+
596
+ Module Dicts require string keys
597
+ """
598
+ return f"{modality}__{idx}"
599
+
600
+ def _get_patch_reconstruction_module_for_modality(
601
+ self, modality: ModalitySpec
602
+ ) -> nn.Module:
603
+ """Get the patch reconstruction module for a modality."""
604
+ # Get bandset indices from tokenization config (may be overridden)
605
+ bandset_indices = self.tokenization_config.get_bandset_indices(modality.name)
606
+
607
+ # Based on the modality name we choose how to reconstruct the data.
608
+ # TODO: the forward pass will likely also need to know which reconstruction strategy was chosen.
609
+ # Static modality
610
+ if modality.get_tile_resolution() == 0:
611
+ # static in space
612
+ return nn.ModuleDict(
613
+ {
614
+ self._get_reconstruction_module_name(modality.name, idx): nn.Linear(
615
+ self.embedding_size, len(channel_set_idxs)
616
+ )
617
+ for idx, channel_set_idxs in enumerate(bandset_indices)
618
+ }
619
+ )
620
+ else:
621
+ return nn.ModuleDict(
622
+ {
623
+ self._get_reconstruction_module_name(
624
+ modality.name, idx
625
+ ): FlexiPatchReconstruction(
626
+ out_chans=len(channel_set_idxs),
627
+ embedding_size=self.embedding_size,
628
+ max_patch_size=self.max_patch_size,
629
+ )
630
+ for idx, channel_set_idxs in enumerate(bandset_indices)
631
+ }
632
+ )
633
+
634
+ # TODO: Likely we want a single object that stores all the data-related configuration per modality, including channel groups, bands, patch size, etc.
635
+ def apply_reconstruction_to_modality(
636
+ self, modality: str, input_data: TokensAndMasks, patch_size: int
637
+ ) -> tuple[Tensor, Tensor]:
638
+ """Apply reconstruction to a modality."""
639
+ masked_modality_name = input_data.get_masked_modality_name(modality)
640
+ modality_mask = getattr(input_data, masked_modality_name)
641
+ modality_data = getattr(input_data, modality)
642
+
643
+ modality_spec = Modality.get(modality)
644
+ bandset_indices = self.tokenization_config.get_bandset_indices(modality)
645
+
646
+ # x: Input tensor with shape [b, h, w, (t), b_s, d]
647
+ modality_tokens, modality_masks = [], []
648
+ for idx, channel_set_indices in enumerate(bandset_indices):
649
+ data = modality_data[..., idx, :]
650
+ masks = modality_mask[..., idx]
651
+ r_model = self.per_modality_reconstructions[modality][
652
+ self._get_reconstruction_module_name(modality, idx)
653
+ ]
654
+ if modality_spec.get_tile_resolution() == 0:
655
+ data = r_model(data)
656
+ else:
657
+ data = r_model(data, patch_size=patch_size)
658
+ modality_tokens.append(data)
659
+ masks = repeat(
660
+ masks,
661
+ "b h w ... -> b (h p_h) (w p_w) ...",
662
+ p_h=patch_size,
663
+ p_w=patch_size,
664
+ )
665
+ modality_masks.append(masks)
666
+ modality_mask = repeat(
667
+ modality_mask,
668
+ "b h w ... -> b (h p_h) (w p_w) ...",
669
+ p_h=patch_size,
670
+ p_w=patch_size,
671
+ )
672
+ return torch.cat(modality_tokens, dim=-1), modality_mask
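# Illustrative sketch (editorial, not part of the package): the einops.repeat call above
# broadcasts each patch-level mask entry over a patch_size x patch_size block of pixels,
# i.e. the inverse of the strided subsampling done at embedding time. Toy values only.
import torch
from einops import repeat

patch_mask = torch.tensor([[[0, 1],
                            [1, 0]]])                      # [b=1, h=2, w=2]
pixel_mask = repeat(patch_mask, "b h w -> b (h p_h) (w p_w)", p_h=2, p_w=2)
# pixel_mask[0]:
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [1, 1, 0, 0],
#         [1, 1, 0, 0]])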
673
+
674
+ def forward(
675
+ self,
676
+ x: TokensAndMasks,
677
+ timestamps: Tensor,
678
+ patch_size: int,
679
+ input_res: int = BASE_GSD,
680
+ ) -> TokensAndMasks:
681
+ """Return flexibly patchified reconstruction for each modality of the input data.
682
+
683
+ Given [B, H, W, (T), b_s, D] inputs, returns [B, H, W, (T), C] outputs.
684
+ """
685
+ input_data = self.decoder(x, timestamps, patch_size, input_res)
686
+ output_dict = {}
687
+ modalities_to_process = get_modalities_to_process(
688
+ input_data.modalities, [m.name for m in self.supported_modalities]
689
+ )
690
+ for modality in modalities_to_process:
691
+ modality_tokens, modality_masks = self.apply_reconstruction_to_modality(
692
+ modality, input_data, patch_size
693
+ )
694
+ output_dict[modality] = modality_tokens
695
+ modality_mask_name = input_data.get_masked_modality_name(modality)
696
+ output_dict[modality_mask_name] = modality_masks
697
+ return TokensAndMasks(**output_dict)
698
+
699
+
700
+ @dataclass
701
+ class ReconstructorConfig(Config):
702
+ """Configuration for the Reconstructor."""
703
+
704
+ decoder_config: "Config"
705
+ supported_modality_names: list[str]
706
+ max_patch_size: int = 8
707
+ tokenization_config: TokenizationConfig | None = None
708
+
709
+ def validate(self) -> None:
710
+ """Validate the configuration."""
711
+ if len(self.supported_modalities) == 0:
712
+ raise ValueError("At least one modality must be added!")
713
+ else:
714
+ for modality in self.supported_modalities:
715
+ if modality not in Modality.values():
716
+ raise ValueError(f"Modality {modality} is not supported")
717
+ if self.tokenization_config is not None:
718
+ self.tokenization_config.validate()
719
+
720
+ @property
721
+ def supported_modalities(self) -> list[ModalitySpec]:
722
+ """Get the supported modalities."""
723
+ return get_modality_specs_from_names(self.supported_modality_names)
724
+
725
+ def build(self) -> "Reconstructor":
726
+ """Build the reconstructor."""
727
+ self.validate()
728
+ kwargs = self.as_dict(exclude_none=True, recurse=False)
729
+ kwargs.pop("supported_modality_names")
730
+ kwargs["supported_modalities"] = self.supported_modalities
731
+ kwargs.pop("decoder_config")
732
+ kwargs["decoder"] = self.decoder_config.build()
733
+ logger.info(f"Predictor kwargs: {kwargs}")
734
+ return Reconstructor(**kwargs)
735
+
736
+
737
+ class CompositeEncodings(nn.Module):
738
+ """Composite encodings for FlexiVit models."""
739
+
740
+ def __init__(
741
+ self,
742
+ embedding_size: int,
743
+ supported_modalities: list[ModalitySpec],
744
+ max_sequence_length: int,
745
+ learnable_channel_embeddings: bool = True,
746
+ random_channel_embeddings: bool = False,
747
+ tokenization_config: TokenizationConfig | None = None,
748
+ ):
749
+ """Initialize the composite encodings.
750
+
751
+ Args:
752
+ embedding_size: Size of token embeddings
753
+ supported_modalities: Which modalities from Modality this model
754
+ instantiation supports
755
+ max_sequence_length: Maximum sequence length
756
+ learnable_channel_embeddings: Whether to use learnable channel embeddings
757
+ random_channel_embeddings: Initialize channel embeddings randomly (zeros if False)
758
+ tokenization_config: Optional config for custom band groupings
759
+ """
760
+ super().__init__()
761
+ self.embedding_size = embedding_size
762
+ self.supported_modalities = supported_modalities
763
+ self.supported_modality_names = [
764
+ modality.name for modality in supported_modalities
765
+ ]
766
+ self.tokenization_config = tokenization_config or TokenizationConfig()
767
+ self.embedding_size = embedding_size
768
+ self.max_sequence_length = (
769
+ max_sequence_length # the max sequence length applies to the time dimension
770
+ )
771
+ # TODO: we need to be able to calculate the size of the param based on what types of embeddings it will get
772
+
773
+ # we have 4 embeddings types (pos_in_time, pos_in_space, month, channel) so each get
774
+ # 0.25 of the dimension
775
+ self.embedding_dim_per_embedding_type = int(embedding_size * 0.25)
776
+ # Position encodings for time dimension initialized to 1D sinusoidal encodings
777
+ self.pos_embed = nn.Parameter(
778
+ get_1d_sincos_pos_encoding(
779
+ torch.arange(max_sequence_length),
780
+ self.embedding_dim_per_embedding_type,
781
+ ),
782
+ requires_grad=False,
783
+ )
784
+ # Month encodings
785
+ month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type)
786
+ self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)
787
+ if not learnable_channel_embeddings and not random_channel_embeddings:
788
+ self.per_modality_channel_embeddings = nn.ParameterDict()
789
+ for modality in self.supported_modalities:
790
+ num_bandsets = self.tokenization_config.get_num_bandsets(modality.name)
791
+ shape = (num_bandsets, self.embedding_dim_per_embedding_type)
792
+ channel_embeddings = nn.Parameter(
793
+ torch.zeros(shape), requires_grad=False
794
+ )
795
+ self.per_modality_channel_embeddings[modality.name] = channel_embeddings
796
+ else:
797
+ # Channel embeddings
798
+ if learnable_channel_embeddings:
799
+ args = {"requires_grad": True}
800
+ else:
801
+ args = {"requires_grad": False}
802
+
803
+ self.per_modality_channel_embeddings = nn.ParameterDict()
804
+ for modality in self.supported_modalities:
805
+ num_bandsets = self.tokenization_config.get_num_bandsets(modality.name)
806
+ shape = (num_bandsets, self.embedding_dim_per_embedding_type)
807
+ if random_channel_embeddings:
808
+ channel_embeddings = nn.Parameter(torch.rand(shape), **args)
809
+ else:
810
+ channel_embeddings = nn.Parameter(torch.zeros(shape), **args)
811
+ self.per_modality_channel_embeddings[modality.name] = channel_embeddings
812
+
813
+ self.apply(self._init_weights)
814
+
815
+ def _init_weights(self, m: nn.Module) -> None:
816
+ if isinstance(m, nn.Linear):
817
+ # we use xavier_uniform following official JAX ViT:
818
+ torch.nn.init.xavier_uniform_(m.weight)
819
+ if isinstance(m, nn.Linear) and m.bias is not None:
820
+ # TODO: fix the dtype here
821
+ nn.init.constant_(m.bias, 0).to(torch.float32)
822
+
823
+ @staticmethod
824
+ def calculate_gsd_ratio(input_res: float, patch_size: int) -> float:
825
+ """Calculate the Ground Sample Distance ratio."""
826
+ return input_res * patch_size / BASE_GSD
827
+
828
+ def _apply_encodings_per_modality(
829
+ self,
830
+ modality_name: str,
831
+ modality_tokens: Tensor,
832
+ timestamps: Tensor | None = None,
833
+ patch_size: int | None = None,
834
+ input_res: int | None = None,
835
+ use_modality_encodings: bool = True,
836
+ use_temporal_encodings: bool = True,
837
+ ) -> Tensor:
838
+ """Apply the encodings to the patchified data based on modality type.
839
+
840
+ Args:
841
+ modality_name: Name of the modality being processed
842
+ modality_tokens: Token embeddings for the modality
843
+ timestamps: Optional timestamps for temporal encodings
844
+ patch_size: Optional patch size for spatial encodings
845
+ input_res: Optional input resolution for spatial encodings
846
+ use_modality_encodings: Whether to use modality encodings
847
+ use_temporal_encodings: Whether to use temporal encodings
848
+
849
+ Returns:
850
+ Tensor with encodings applied based on modality type
851
+ """
852
+ logger.debug(
853
+ f"use_modality_encodings: {use_modality_encodings}, use_temporal_encodings: {use_temporal_encodings}"
854
+ )
855
+ # TODO: Improve this implementation; it is quite bad.
856
+
857
+ modality = Modality.get(modality_name)
858
+ logger.debug(f"Applying encodings to modality {modality}")
859
+ if not use_modality_encodings and use_temporal_encodings:
860
+ b, h, w, t, _ = modality_tokens.shape
861
+ ein_string, ein_dict = (
862
+ "b h w t d",
863
+ {"b": b, "h": h, "w": w, "t": t},
864
+ )
865
+ elif not use_temporal_encodings and not use_modality_encodings:
866
+ b, h, w, _ = modality_tokens.shape
867
+ ein_string, ein_dict = (
868
+ "b h w d",
869
+ {"b": b, "h": h, "w": w},
870
+ )
871
+ elif not use_temporal_encodings and use_modality_encodings:
872
+ raise NotImplementedError("Not implemented")
873
+ else:
874
+ if modality_tokens.ndim == 3:
875
+ # modality_tokens = [B, Band_Sets, D]; static in space, static in time
876
+ b, b_s, _ = modality_tokens.shape
877
+ ein_string, ein_dict = "b b_s d", {"b": b, "b_s": b_s}
878
+ elif modality_tokens.ndim == 4:
879
+ b, t, b_s, _ = modality_tokens.shape
880
+ ein_string, ein_dict = "b t b_s d", {"b": b, "t": t, "b_s": b_s}
881
+ elif modality_tokens.ndim == 5:
882
+ b, h, w, b_s, _ = modality_tokens.shape
883
+ ein_string, ein_dict = (
884
+ "b h w b_s d",
885
+ {"b": b, "h": h, "w": w, "b_s": b_s},
886
+ )
887
+ elif modality_tokens.ndim == 6:
888
+ b, h, w, t, b_s, _ = modality_tokens.shape
889
+ ein_string, ein_dict = (
890
+ "b h w t b_s d",
891
+ {"b": b, "h": h, "w": w, "t": t, "b_s": b_s},
892
+ )
893
+ else:
894
+ raise ValueError(f"Unsupported tokens shape: {modality_tokens.shape}")
895
+
896
+ device = modality_tokens.device
897
+ modality_embed = torch.zeros(modality_tokens.shape, device=device)
898
+ n = self.embedding_dim_per_embedding_type
899
+ actual_bandsets = modality_tokens.shape[-2]
900
+
901
+ # Channel embeddings
902
+ if use_modality_encodings:
903
+ channel_embed = self.per_modality_channel_embeddings[modality.name]
904
+ if channel_embed.shape[0] != actual_bandsets:
905
+ raise ValueError(
906
+ f"Channel embeddings for {modality.name} expect "
907
+ f"{channel_embed.shape[0]} bandsets but tokens have "
908
+ f"{actual_bandsets}. Ensure tokenization_config is "
909
+ "consistently passed to the encoder/decoder and masking strategy."
910
+ )
911
+ channel_embed = repeat(
912
+ channel_embed, f"b_s d -> {ein_string}", **ein_dict
913
+ ).to(device)
914
+ modality_embed[..., :n] += channel_embed
915
+
916
+ if modality.is_multitemporal and use_temporal_encodings:
917
+ # Time position encodings
918
+ time_embed = repeat(self.pos_embed[:t], f"t d -> {ein_string}", **ein_dict)
919
+ modality_embed[..., n : n * 2] += time_embed.to(device)
920
+
921
+ # Month encodings
922
+ assert timestamps is not None
923
+ months = timestamps[:, :, 1]
924
+ month_embed = self.month_embed(months)
925
+ month_embed = repeat(month_embed, f"b t d -> {ein_string}", **ein_dict)
926
+ modality_embed[..., n * 2 : n * 3] += month_embed.to(device)
927
+ if modality.is_spatial:
928
+ # Spatial encodings
929
+ assert input_res is not None
930
+ assert patch_size is not None
931
+ gsd_ratio = self.calculate_gsd_ratio(input_res, patch_size)
932
+ spatial_embed = get_2d_sincos_pos_encoding_with_resolution(
933
+ grid_size=h,
934
+ res=torch.ones(b, device=device) * gsd_ratio,
935
+ encoding_dim=self.embedding_dim_per_embedding_type,
936
+ device=device,
937
+ )
938
+ spatial_embed = rearrange(spatial_embed, "b (h w) d -> b h w d", h=h, w=w)
939
+ spatial_embed = repeat(
940
+ spatial_embed, f"b h w d -> {ein_string}", **ein_dict
941
+ )
942
+ modality_embed[..., n * 3 : n * 4] += spatial_embed
943
+ return modality_tokens + modality_embed
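# Illustrative sketch (editorial, not part of the package): how the composite encoding
# splits the embedding dimension into four equal slices, one per encoding type, before
# adding them to the tokens. The slice order mirrors the method above; the band-set
# dimension and real sin-cos tables are replaced by toy constants.
import torch

embedding_size = 16
n = int(embedding_size * 0.25)                      # embedding_dim_per_embedding_type
tokens = torch.randn(2, 4, 4, 3, embedding_size)    # toy [b, h, w, t, d] tokens
composite = torch.zeros_like(tokens)
composite[..., 0 * n : 1 * n] += 1.0                # channel (band-set) embedding slice
composite[..., 1 * n : 2 * n] += 2.0                # time-position embedding slice
composite[..., 2 * n : 3 * n] += 3.0                # month embedding slice
composite[..., 3 * n : 4 * n] += 4.0                # 2D spatial (sin-cos) embedding slice
encoded = tokens + composite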
944
+
945
+ def forward(
946
+ self,
947
+ per_modality_input_tokens: dict[str, Tensor],
948
+ timestamps: Tensor,
949
+ patch_size: int,
950
+ input_res: int = BASE_GSD,
951
+ ) -> dict[str, Tensor]:
952
+ """Apply the encodings to the patchified data.
953
+
954
+ Args:
955
+ per_modality_input_tokens: Tokens only for each modality
956
+ timestamps: Timestamps of the data
957
+ patch_size: Size of patches
958
+ input_res: Resolution of the input data
959
+
960
+ Returns:
961
+ Tokens only for each modality
962
+ """
963
+ output_dict = {}
964
+ available_modalities = return_modalities_from_dict(per_modality_input_tokens)
965
+ modalities_to_process = get_modalities_to_process(
966
+ available_modalities, self.supported_modality_names
967
+ )
968
+ for modality_name in modalities_to_process:
969
+ output_dict[modality_name] = self._apply_encodings_per_modality(
970
+ modality_name,
971
+ per_modality_input_tokens[modality_name],
972
+ timestamps=timestamps,
973
+ patch_size=patch_size,
974
+ input_res=input_res,
975
+ )
976
+ return output_dict
977
+
978
+
979
+ class FlexiVitBase(nn.Module):
980
+ """FlexiVitBase is a base class for FlexiVit models."""
981
+
982
+ cross_attn: bool = False
983
+
984
+ def __init__(
985
+ self,
986
+ embedding_size: int,
987
+ max_sequence_length: int,
988
+ num_heads: int,
989
+ mlp_ratio: float,
990
+ depth: int,
991
+ drop_path: float,
992
+ supported_modalities: list[ModalitySpec],
993
+ learnable_channel_embeddings: bool = True,
994
+ random_channel_embeddings: bool = False,
995
+ use_flash_attn: bool = False,
996
+ qk_norm: bool = False,
997
+ tokenization_config: TokenizationConfig | None = None,
998
+ ) -> None:
999
+ """Initialize the FlexiVitBase class."""
1000
+ super().__init__()
1001
+
1002
+ self.embedding_size = embedding_size
1003
+ self.supported_modalities = supported_modalities
1004
+ self.supported_modality_names = [x.name for x in supported_modalities]
1005
+ logger.info(f"modalities being used by model: {self.supported_modality_names}")
1006
+
1007
+ self.max_sequence_length = max_sequence_length
1008
+ self._base_tokenization_config = tokenization_config or TokenizationConfig()
1009
+
1010
+ self.use_flash_attn = use_flash_attn
1011
+ self.learnable_channel_embeddings = learnable_channel_embeddings
1012
+ self.random_channel_embeddings = random_channel_embeddings
1013
+ self.blocks = nn.ModuleList(
1014
+ [
1015
+ Block(
1016
+ embedding_size,
1017
+ num_heads,
1018
+ mlp_ratio,
1019
+ qkv_bias=True,
1020
+ qk_norm=qk_norm,
1021
+ norm_layer=nn.LayerNorm, # TODO: This should be configurable
1022
+ cross_attn=self.cross_attn,
1023
+ drop_path=drop_path,
1024
+ use_flash_attn=self.use_flash_attn,
1025
+ )
1026
+ for _ in range(depth)
1027
+ ]
1028
+ )
1029
+
1030
+ self.composite_encodings = CompositeEncodings(
1031
+ embedding_size,
1032
+ self.supported_modalities,
1033
+ max_sequence_length,
1034
+ learnable_channel_embeddings,
1035
+ random_channel_embeddings,
1036
+ tokenization_config=self._base_tokenization_config,
1037
+ )
1038
+ self.apply(self._init_weights)
1039
+
1040
+ def _init_weights(self, m: nn.Module) -> None:
1041
+ if isinstance(m, nn.Linear):
1042
+ # we use xavier_uniform following official JAX ViT:
1043
+ torch.nn.init.xavier_uniform_(m.weight)
1044
+ if isinstance(m, nn.Linear) and m.bias is not None:
1045
+ nn.init.constant_(m.bias, 0)
1046
+
1047
+ @staticmethod
1048
+ def grab_modality_specific_dims(modality_data: Tensor) -> tuple[int, ...]:
1049
+ """Grab the modality specific dimensions from the modality data.
1050
+
1051
+ Assumes [B, ..., C, D]
1052
+
1053
+ Every modality will have a batch dimension, a channel dimension and embedding dimension.
1054
+
1055
+ Args:
1056
+ modality_data: Modality data
1057
+
1058
+ Returns:
1059
+ Modality specific dimensions
1060
+ """
1061
+ return modality_data.shape[1:-2] if modality_data.ndim > 3 else ()
1062
+
1063
+ # is naming here confusing if one of these channels can be missing?
1064
+ def collapse_and_combine_hwtc(self, x: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
1065
+ """Collapse the tokens and masks, respectively, into two tensors."""
1066
+ tokens, masks = [], []
1067
+ available_modalities = return_modalities_from_dict(x)
1068
+ modalities_to_process = get_modalities_to_process(
1069
+ available_modalities, self.supported_modality_names
1070
+ )
1071
+ for modality in modalities_to_process:
1072
+ masked_modality_name = MaskedOlmoEarthSample.get_masked_modality_name(
1073
+ modality
1074
+ )
1075
+ x_modality = x[modality]
1076
+ x_modality_mask = x[masked_modality_name]
1077
+ tokens.append(rearrange(x_modality, "b ... d -> b (...) d"))
1078
+ masks.append(rearrange(x_modality_mask, "b ... -> b (...)"))
1079
+ tokens = torch.cat(tokens, dim=1)
1080
+ masks = torch.cat(masks, dim=1)
1081
+
1082
+ return tokens, masks
1083
+
1084
+ @staticmethod
1085
+ def _construct_einops_pattern(
1086
+ spatial_dims: tuple[int, ...],
1087
+ ) -> tuple[str, dict[str, int]]:
1088
+ """Given a tuple of spatial dimensions (e.g. [B, H, W, T, ...]).
1089
+
1090
+ build (1) an einops rearrange pattern of the form:
1091
+ "d -> (dim0) (dim1) (dim2)... d"
1092
+ and (2) a dictionary mapping dim0..dimN to the actual sizes.
1093
+
1094
+ This allows reshaping a single-dimensional tensor [D] into
1095
+ [B, H, W, T, ..., D] using einops.
1096
+ """
1097
+ dim_dict = {f"dim{i}": size for i, size in enumerate(spatial_dims)}
1098
+ # e.g., "d -> (dim0) (dim1) (dim2) (dim3) d"
1099
+ pattern_input = (
1100
+ "d -> " + " ".join(f"(dim{i})" for i in range(len(spatial_dims))) + " d"
1101
+ )
1102
+ return pattern_input, dim_dict
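# Illustrative sketch (editorial, not part of the package): using the generated pattern
# with einops.repeat to broadcast a single [D] vector over arbitrary leading dimensions.
# The pattern and sizes below are hard-coded to what the helper builds for (2, 4, 4, 3).
import torch
from einops import repeat

pattern = "d -> (dim0) (dim1) (dim2) (dim3) d"   # _construct_einops_pattern((2, 4, 4, 3)) output
sizes = {"dim0": 2, "dim1": 4, "dim2": 4, "dim3": 3}
vec = torch.randn(16)
expanded = repeat(vec, pattern, **sizes)         # shape [2, 4, 4, 3, 16]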
1103
+
1104
+ def split_tokens_masks_and_dims(
1105
+ self, x: dict[str, Tensor]
1106
+ ) -> tuple[dict[str, Tensor], dict[str, Tensor], dict[str, tuple]]:
1107
+ """Split the tokens, masks, and dimensions out into separate dicts."""
1108
+ tokens_only_dict = {}
1109
+ original_masks_dict = {}
1110
+ modalities_to_dims_dict = {}
1111
+ available_modalities = return_modalities_from_dict(x)
1112
+ modalities_to_process = get_modalities_to_process(
1113
+ available_modalities, self.supported_modality_names
1114
+ )
1115
+ for modality in modalities_to_process:
1116
+ x_modality = x[modality]
1117
+ tokens_only_dict[modality] = x_modality
1118
+ modalities_to_dims_dict[modality] = x_modality.shape
1119
+ masked_modality_name = MaskedOlmoEarthSample.get_masked_modality_name(
1120
+ modality
1121
+ )
1122
+ original_masks_dict[masked_modality_name] = x[masked_modality_name]
1123
+ return tokens_only_dict, original_masks_dict, modalities_to_dims_dict
1124
+
1125
+ @staticmethod
1126
+ def split_and_expand_per_modality(
1127
+ x: Tensor, modalities_to_dims_dict: dict
1128
+ ) -> dict[str, Tensor]:
1129
+ """Split and expand the tokens per modality.
1130
+
1131
+ Args:
1132
+ x: Tokens to split and expand (b n d)
1133
+ modalities_to_dims_dict: Dictionary mapping modalities to their dimensions
1134
+ Returns:
1135
+ tokens_only_dict: mapping modalities to their tokens
1136
+ """
1137
+ tokens_only_dict = {}
1138
+ tokens_reshaped = 0
1139
+ for modality, dims in modalities_to_dims_dict.items():
1140
+ # Skip batch (first) and embedding (last) dimensions
1141
+ middle_dims = dims[1:-1]
1142
+ num_tokens_for_modality = math.prod(middle_dims)
1143
+
1144
+ # Extract tokens for this modality (b n d)
1145
+ modality_tokens = x[
1146
+ :, tokens_reshaped : tokens_reshaped + num_tokens_for_modality
1147
+ ]
1148
+
1149
+ # TODO: see if there is a general and clean einops way to do this
1150
+ # Reshape to original dimensions (e.g., for 4D spatial dims: b d1 d2 d3 d4 e)
1151
+ x_modality = modality_tokens.view(x.shape[0], *middle_dims, x.shape[-1])
1152
+
1153
+ tokens_reshaped += num_tokens_for_modality
1154
+ tokens_only_dict[modality] = x_modality
1155
+
1156
+ return tokens_only_dict
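# Illustrative sketch (editorial, not part of the package): carving a flat [B, N, D]
# token sequence back into per-modality tensors from the recorded shapes, which is what
# split_and_expand_per_modality does after the transformer blocks. Shapes are toy values.
import math
import torch

b, d = 2, 16
dims = {"sentinel2_l2a": (b, 4, 4, 3, 2, d), "latlon": (b, 1, d)}     # toy shape bookkeeping
n_total = sum(math.prod(shape[1:-1]) for shape in dims.values())      # 96 + 1 tokens
x = torch.randn(b, n_total, d)

offset, per_modality = 0, {}
for name, shape in dims.items():
    n = math.prod(shape[1:-1])                   # tokens contributed by this modality
    per_modality[name] = x[:, offset : offset + n].view(b, *shape[1:-1], d)
    offset += n
# per_modality["sentinel2_l2a"].shape == (2, 4, 4, 3, 2, 16)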
1157
+
1158
+ @staticmethod
1159
+ def pack_tokens(tokens: Tensor, mask: Tensor) -> Tensor:
1160
+ """Pack the Batch and sequence length dimensions of tokens and mask into a single tensor.
1161
+
1162
+ Args:
1163
+ tokens: Tokens to pack
1164
+ mask: Mask to pack
1165
+
1166
+ Returns:
1167
+ Packed tokens enabling varlen flash attention
1168
+ """
1169
+ tokens_packed = torch.flatten(tokens, end_dim=1)
1170
+ mask = torch.flatten(mask)
1171
+ tokens = tokens_packed[mask]
1172
+ return tokens
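# Illustrative sketch (editorial, not part of the package): packing a padded batch into a
# single "varlen" sequence. Flash-attention style kernels then consume the packed tokens
# plus cumulative sequence lengths instead of a padded [B, T, D] tensor; the package's own
# get_cumulative_sequence_lengths helper presumably produces something like cu_seqlens below.
import torch

tokens = torch.randn(2, 4, 8)                               # [B, T, D], padded
keep = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])           # [B, T] validity mask
packed = tokens.flatten(end_dim=1)[keep.flatten()]          # [5, 8] selected tokens only
seq_lens = keep.sum(dim=1)                                  # tensor([3, 2])
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0))   # tensor([0, 3, 5])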
1173
+
1174
+ @staticmethod
1175
+ def unpack_tokens(tokens: Tensor, mask: Tensor, og_shape: tuple) -> Tensor:
1176
+ """Unpack the Batch and sequence length dimensions of tokens and mask into a single tensor.
1177
+
1178
+ Args:
1179
+ tokens: Tokens to unpack
1180
+ mask: Mask to unpack
1181
+ og_shape: Original shape of the tokens
1182
+ """
1183
+ tokens_new = tokens.new_zeros(og_shape[0] * og_shape[1], og_shape[2])
1184
+ mask = torch.flatten(mask)
1185
+ tokens_new[mask] = tokens
1186
+ tokens = tokens_new.reshape(og_shape[0], og_shape[1], -1)
1187
+ return tokens
1188
+
1189
+ def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
1190
+ """Apply FSDP to the model."""
1191
+ for block in self.blocks:
1192
+ block.apply_fsdp(**fsdp_kwargs)
1193
+
1194
+ def apply_compile(self) -> None:
1195
+ """Apply torch.compile to the model."""
1196
+ for block in self.blocks:
1197
+ block.apply_compile()
1198
+
1199
+
1200
+ class Encoder(FlexiVitBase):
1201
+ """Encoder module that processes masked input samples into token representations."""
1202
+
1203
+ cross_attn: bool = False
1204
+
1205
+ def __init__(
1206
+ self,
1207
+ embedding_size: int,
1208
+ max_patch_size: int,
1209
+ min_patch_size: int,
1210
+ num_heads: int,
1211
+ mlp_ratio: float,
1212
+ depth: int,
1213
+ drop_path: float,
1214
+ supported_modalities: list[ModalitySpec],
1215
+ max_sequence_length: int,
1216
+ num_register_tokens: int = 0,
1217
+ learnable_channel_embeddings: bool = True,
1218
+ random_channel_embeddings: bool = False,
1219
+ num_projection_layers: int = 1,
1220
+ aggregate_then_project: bool = True,
1221
+ use_flash_attn: bool = False,
1222
+ frozen_patch_embeddings: bool = False,
1223
+ qk_norm: bool = False,
1224
+ log_token_norm_stats: bool = False,
1225
+ tokenization_config: TokenizationConfig | None = None,
1226
+ ):
1227
+ """Initialize the encoder.
1228
+
1229
+ Args:
1230
+ embedding_size: Size of token embeddings
1231
+ max_patch_size: Maximum patch size for patchification
1232
+ min_patch_size: Minimum patch size for patchification
1233
+ num_heads: Number of attention heads
1234
+ mlp_ratio: Ratio for MLP hidden dimension
1235
+ depth: Number of transformer layers
1236
+ drop_path: Drop path rate
1237
+ supported_modalities: list documenting modalities used in a given model instantiation
1238
+ max_sequence_length: Maximum sequence length
1239
+ num_register_tokens: Number of register tokens to use
1240
+ learnable_channel_embeddings: Whether to use learnable channel embeddings
1241
+ random_channel_embeddings: Initialize channel embeddings randomly (zeros if False)
1242
+ num_projection_layers: The number of layers to use in the projection. If >1, then
1243
+ a ReLU activation will be applied between layers
1244
+ aggregate_then_project: If True, then we will average the tokens before applying
1245
+ the projection. If False, we will apply the projection first.
1246
+ use_flash_attn: Whether to use flash attention
1247
+ frozen_patch_embeddings: If True, we freeze the embedding layer, as recommended in
1248
+ https://arxiv.org/pdf/2104.02057, Section 4.2
1249
+ qk_norm: Whether to apply normalization to Q and K in attention
1250
+ log_token_norm_stats: Whether to log the token norm stats
1251
+ tokenization_config: Optional config for custom band groupings
1252
+ """
1253
+ self.tokenization_config = tokenization_config or TokenizationConfig()
1254
+ super().__init__(
1255
+ embedding_size=embedding_size,
1256
+ depth=depth,
1257
+ mlp_ratio=mlp_ratio,
1258
+ num_heads=num_heads,
1259
+ max_sequence_length=max_sequence_length,
1260
+ learnable_channel_embeddings=learnable_channel_embeddings,
1261
+ drop_path=drop_path,
1262
+ supported_modalities=supported_modalities,
1263
+ use_flash_attn=use_flash_attn,
1264
+ random_channel_embeddings=random_channel_embeddings,
1265
+ qk_norm=qk_norm,
1266
+ tokenization_config=self.tokenization_config,
1267
+ )
1268
+ self.num_register_tokens = num_register_tokens
1269
+ self.has_register_tokens = num_register_tokens > 0
1270
+ self.log_token_norm_stats = log_token_norm_stats
1271
+ if self.has_register_tokens:
1272
+ self.register_tokens = nn.Parameter(
1273
+ torch.zeros(num_register_tokens, embedding_size)
1274
+ )
1275
+ self.min_patch_size = min_patch_size
1276
+ self.max_patch_size = max_patch_size
1277
+ self.embedding_size = embedding_size
1278
+ self.patch_embeddings = MultiModalPatchEmbeddings(
1279
+ self.supported_modality_names,
1280
+ self.max_patch_size,
1281
+ self.embedding_size,
1282
+ tokenization_config=self.tokenization_config,
1283
+ )
1284
+ self.project_and_aggregate = ProjectAndAggregate(
1285
+ embedding_size=self.embedding_size,
1286
+ num_layers=num_projection_layers,
1287
+ aggregate_then_project=aggregate_then_project,
1288
+ )
1289
+ self.norm = nn.LayerNorm(self.embedding_size)
1290
+ self.apply(self._init_weights)
1291
+
1292
+ if frozen_patch_embeddings:
1293
+ for p in self.patch_embeddings.parameters():
1294
+ p.requires_grad = False
1295
+ if self.has_register_tokens:
1296
+ self._init_register_tokens()
1297
+
1298
+ def _init_register_tokens(self) -> None:
1299
+ """Initialize the register tokens."""
1300
+ nn.init.xavier_uniform_(self.register_tokens)
1301
+
1302
+ def create_token_exit_ids(
1303
+ self, x: dict[str, Tensor], token_exit_cfg: dict[str, int]
1304
+ ) -> dict[str, Tensor]:
1305
+ """Create the token exit ids for # of layers of attention for each band group.
1306
+
1307
+ Assumes modality channel groups are in the second to last dimension of the tokens.
1308
+ """
1309
+ exit_ids_per_modality_dict = {}
1310
+ available_modalities = return_modalities_from_dict(x)
1311
+ modalities_to_process = get_modalities_to_process(
1312
+ available_modalities, self.supported_modality_names
1313
+ )
1314
+ for modality in modalities_to_process:
1315
+ num_exit_layers = token_exit_cfg[modality]
1316
+ exit_seq_modality = torch.full_like(x[modality], fill_value=num_exit_layers)
1317
+ exit_ids_per_modality_dict[modality] = exit_seq_modality
1318
+ return exit_ids_per_modality_dict
1319
+
1320
+ @staticmethod
1321
+ def remove_masked_tokens(
1322
+ x: Tensor, mask: Tensor
1323
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
1324
+ """Remove masked tokens from the tokens and masks.
1325
+
1326
+ Implementation from https://stackoverflow.com/a/68621610/2332296
1327
+
1328
+ On Input:
1329
+ 0 means this token should be removed
1330
+ 1 means this token should be kept
1331
+
1332
+ Args:
1333
+ x: Tokens to remove masked tokens from
1334
+ mask: Mask to remove masked tokens from
1335
+
1336
+ Returns:
1337
+ tokens: [B, T, D]
1338
+ indices: [B, T]
1339
+ updated_mask: [B, T]
1340
+ seqlens: [B]
1341
+ max_length: [1]
1342
+ where T is the max number of unmasked tokens for an instance
1343
+ """
1344
+ sorted_mask, indices = torch.sort(mask, dim=1, descending=True, stable=True)
1345
+ # Now all the places where we want to keep the token are at the front of the tensor
1346
+ x = x.gather(1, indices[:, :, None].expand_as(x))
1347
+ # Now all tokens that should be kept are first in the tensor
1348
+
1349
+ # set masked values to 0 (not really necessary since we'll ignore them anyway)
1350
+ x = x * sorted_mask.unsqueeze(-1)
1351
+
1352
+ # cut off to the length of the longest sequence
1353
+ seq_lengths = sorted_mask.sum(-1)
1354
+ max_length = seq_lengths.max()
1355
+ x = x[:, :max_length]
1356
+ # New mask chopped to the longest sequence
1357
+ updated_mask = sorted_mask[:, :max_length]
1358
+
1359
+ return x, indices, updated_mask, seq_lengths, max_length
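# Illustrative sketch (editorial, not part of the package): the stable-sort trick above on
# a toy batch. Sorting the keep-mask in descending order moves kept tokens to the front of
# every row, after which the batch is truncated to the longest kept length. Zeroing of the
# trailing padded positions is omitted here for brevity.
import torch

x = torch.arange(12.0).reshape(2, 6, 1)                     # [B=2, T=6, D=1]
keep = torch.tensor([[1, 0, 1, 0, 0, 1],
                     [0, 1, 0, 0, 1, 0]])
sorted_keep, idx = torch.sort(keep, dim=1, descending=True, stable=True)
front = x.gather(1, idx[:, :, None].expand_as(x))           # kept tokens first in every row
max_len = sorted_keep.sum(-1).max()                         # 3 for this toy batch
front, sorted_keep = front[:, :max_len], sorted_keep[:, :max_len]
# front[0, :, 0] == tensor([0., 2., 5.]); front[1, :2, 0] == tensor([7., 10.])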
1360
+
1361
+ @staticmethod
1362
+ def add_removed_tokens(
1363
+ x: Tensor, indices: Tensor, mask: Tensor
1364
+ ) -> tuple[Tensor, Tensor]:
1365
+ """Add removed tokens to the tokens and masks.
1366
+
1367
+ Args:
1368
+ x: Tokens to add removed tokens to
1369
+ indices: Original indices of the masked tokens
1370
+ mask: Mask to add removed tokens to
1371
+
1372
+ Returns:
1373
+ tokens: Tokens with removed tokens added
1374
+ mask: Mask with removed tokens added
1375
+ """
1376
+ assert x.shape[1] > 0, (
1377
+ "x must have at least one token we should not mask all tokens"
1378
+ )
1379
+ masked_tokens = repeat(
1380
+ torch.zeros_like(x[0, 0, :]), "d -> b t d", b=x.shape[0], t=indices.shape[1]
1381
+ )
1382
+ full_mask = torch.cat(
1383
+ (
1384
+ mask,
1385
+ torch.zeros(
1386
+ (x.shape[0], indices.shape[1] - x.shape[1]),
1387
+ device=x.device,
1388
+ dtype=mask.dtype,
1389
+ ),
1390
+ ),
1391
+ dim=-1,
1392
+ )
1393
+ # can't set value on leaf variable
1394
+ out = masked_tokens.clone()
1395
+ # put tokens in full masked tensor (at the first N positions in every row)
1396
+ out[full_mask] = x[mask]
1397
+ # then move them to their original positions
1398
+ out = out.scatter(1, indices[:, :, None].expand_as(out), out)
1399
+ full_mask = full_mask.scatter(1, indices.expand_as(full_mask), full_mask)
1400
+ # Values that were masked out are not returned but the values that are still there are returned to the original positions
1401
+ return out, full_mask
1402
+
1403
+ def create_exit_seqs(
1404
+ self,
1405
+ tokens_only_dict: dict[str, Tensor],
1406
+ mask_only_dict: dict[str, Tensor],
1407
+ token_exit_cfg: dict[str, int] | None,
1408
+ ) -> Tensor | None:
1409
+ """Create the exit sequences and tokens."""
1410
+ # Check that tokens_only_dict doesn't contain any mask keys
1411
+ assert all(not key.endswith("_mask") for key in tokens_only_dict), (
1412
+ "tokens_only_dict should not contain mask keys"
1413
+ )
1414
+ if token_exit_cfg:
1415
+ exit_ids_per_modality = self.create_token_exit_ids(
1416
+ tokens_only_dict, token_exit_cfg
1417
+ )
1418
+ exit_ids_per_modality.update(mask_only_dict)
1419
+ # Exit ids seqs tells us which layer to exit each token
1420
+ exit_ids_seq, _ = self.collapse_and_combine_hwtc(exit_ids_per_modality)
1421
+ else:
1422
+ exit_ids_seq = None
1423
+ return exit_ids_seq
1424
+
1425
+ def _maybe_get_attn_mask(
1426
+ self,
1427
+ new_mask: Tensor,
1428
+ fast_pass: bool,
1429
+ ) -> Tensor | None:
1430
+ """Get the attention mask or None if we should pass None to the transformer."""
1431
+ if fast_pass or not self.training:
1432
+ return None
1433
+ else:
1434
+ return new_mask
1435
+
1436
+ def add_register_tokens_and_masks(
1437
+ self,
1438
+ tokens: Tensor,
1439
+ attn_mask: Tensor | None,
1440
+ processed_register_tokens: Tensor | None = None,
1441
+ ) -> tuple[Tensor, Tensor | None]:
1442
+ """Concatenate register tokens to the tokens."""
1443
+ batch_size = tokens.shape[0]
1444
+ # Expand register tokens to match batch size: [num_register_tokens, embedding_size] -> [batch_size, num_register_tokens, embedding_size]
1445
+ if processed_register_tokens is None:
1446
+ reg_tokens = self.register_tokens.unsqueeze(0).expand(batch_size, -1, -1)
1447
+ else:
1448
+ reg_tokens = processed_register_tokens
1449
+ # Concatenate register tokens at the beginning: [batch_size, seq_len, embedding_size] -> [batch_size, num_register_tokens + seq_len, embedding_size]
1450
+ tokens = torch.cat([reg_tokens, tokens], dim=1)
1451
+ if attn_mask is not None:
1452
+ # Create mask for register tokens (all True - they should participate in attention)
1453
+ reg_mask = torch.ones(
1454
+ batch_size,
1455
+ self.num_register_tokens,
1456
+ dtype=attn_mask.dtype,
1457
+ device=attn_mask.device,
1458
+ )
1459
+ attn_mask = torch.cat([reg_mask, attn_mask], dim=1)
1460
+ else:
1461
+ reg_mask = None
1462
+ return tokens, attn_mask
1463
+
1464
+ def pop_register_tokens(self, tokens: Tensor) -> tuple[Tensor, Tensor]:
1465
+ """Pop the register tokens from the tokens."""
1466
+ register_tokens = tokens[:, : self.num_register_tokens, :]
1467
+ tokens = tokens[:, self.num_register_tokens :, :]
1468
+ return tokens, register_tokens
1469
+
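# --- Editor-added illustrative sketch; not part of the packaged module. ---
# How learnable register tokens are prepended before the transformer blocks and
# stripped off afterwards (add_register_tokens_and_masks / pop_register_tokens above).
# The module below is a stand-in with hypothetical sizes.
import torch
from torch import nn

class _RegisterTokenSketch(nn.Module):
    def __init__(self, num_register_tokens: int = 4, dim: int = 8) -> None:
        super().__init__()
        self.num_register_tokens = num_register_tokens
        self.register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))

    def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        b = tokens.shape[0]
        reg = self.register_tokens.unsqueeze(0).expand(b, -1, -1)
        tokens = torch.cat([reg, tokens], dim=1)           # [B, R + T, D]
        # ... transformer blocks would run on the concatenated sequence here ...
        reg_out = tokens[:, : self.num_register_tokens]
        tokens = tokens[:, self.num_register_tokens :]
        return tokens, reg_out

# e.g. _RegisterTokenSketch()(torch.randn(2, 5, 8)) returns [2, 5, 8] and [2, 4, 8] tensors.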
1470
+ def get_token_norm_stats(
1471
+ self, tokens: Tensor, register_tokens: Tensor
1472
+ ) -> dict[str, float]:
1473
+ """Get the token norm stats."""
1474
+ # Compute norms for register tokens: [batch_size, num_register_tokens]
1475
+ register_tokens_norms = torch.norm(register_tokens, dim=2)
1476
+ reg_norms_flat = register_tokens_norms.flatten()
1477
+ reg_stats = {
1478
+ "register_mean": reg_norms_flat.mean().item(),
1479
+ "register_min": reg_norms_flat.min().item(),
1480
+ "register_max": reg_norms_flat.max().item(),
1481
+ }
1482
+
1483
+ # Compute norms for non-register tokens: [batch_size, seq_len]
1484
+ nonreg_tokens_norms = torch.norm(tokens, dim=2)
1485
+ nonreg_norms_flat = nonreg_tokens_norms.flatten()
1486
+ percentiles = [25.0, 75.0, 90.0, 95.0, 99.0]
1487
+ nonreg_percentiles = torch.quantile(
1488
+ nonreg_norms_flat.float(),
1489
+ torch.tensor(
1490
+ [p / 100.0 for p in percentiles], device=nonreg_norms_flat.device
1491
+ ),
1492
+ ).tolist()
1493
+ nonreg_stats = {
1494
+ "nonregister_mean": nonreg_norms_flat.mean().item(),
1495
+ "nonregister_min": nonreg_norms_flat.min().item(),
1496
+ "nonregister_max": nonreg_norms_flat.max().item(),
1497
+ "nonregister_std": nonreg_norms_flat.std().item(),
1498
+ "nonregister_25th": nonreg_percentiles[0],
1499
+ "nonregister_75th": nonreg_percentiles[1],
1500
+ "nonregister_90th": nonreg_percentiles[2],
1501
+ "nonregister_95th": nonreg_percentiles[3],
1502
+ "nonregister_99th": nonreg_percentiles[4],
1503
+ }
1504
+
1505
+ token_norm_stats = {**reg_stats, **nonreg_stats}
1506
+ return token_norm_stats
1507
+
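# --- Editor-added illustrative sketch; not part of the packaged module. ---
# The summary statistics computed in get_token_norm_stats above: per-token L2 norms,
# flattened and reduced with torch.quantile. The percentile choices mirror the method.
import torch

def _sketch_norm_percentiles(tokens: torch.Tensor) -> dict[str, float]:
    norms = torch.norm(tokens, dim=2).flatten().float()     # one norm per token
    percentiles = [25.0, 75.0, 90.0, 95.0, 99.0]
    qs = torch.quantile(norms, torch.tensor([p / 100.0 for p in percentiles]))
    stats = {"mean": norms.mean().item(), "std": norms.std().item()}
    stats.update({f"p{int(p)}": q.item() for p, q in zip(percentiles, qs)})
    return stats

# e.g. _sketch_norm_percentiles(torch.randn(2, 16, 32))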
1508
+ def _maybe_remove_masked_tokens(
1509
+ self,
1510
+ tokens: Tensor,
1511
+ mask: Tensor,
1512
+ fast_pass: bool,
1513
+ ) -> tuple[Tensor, Tensor | None, Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
1514
+ """Remove masked tokens from the tokens and masks."""
1515
+ if fast_pass and not self.use_flash_attn:
1516
+ # This is the inference fast pass
1517
+ indices = None
1518
+ new_mask = None
1519
+ seq_lengths = None
1520
+ max_seqlen = None
1521
+ bool_mask = None
1522
+ else:
1523
+ bool_mask = mask == MaskValue.ONLINE_ENCODER.value
1524
+ tokens, indices, new_mask, seq_lengths, max_seqlen = (
1525
+ self.remove_masked_tokens(tokens, bool_mask)
1526
+ )
1527
+ return tokens, indices, new_mask, seq_lengths, max_seqlen, bool_mask
1528
+
1529
+ def _maybe_add_removed_tokens(
1530
+ self,
1531
+ tokens: Tensor,
1532
+ indices: Tensor,
1533
+ mask: Tensor,
1534
+ fast_pass: bool,
1535
+ ) -> Tensor:
1536
+ """Add removed tokens to the tokens and masks."""
1537
+ if not fast_pass:
1538
+ tokens, _ = self.add_removed_tokens(tokens, indices, mask)
1539
+ return tokens
1540
+
1541
+ def apply_attn(
1542
+ self,
1543
+ x: dict[str, Tensor],
1544
+ timestamps: Tensor,
1545
+ patch_size: int,
1546
+ input_res: int,
1547
+ token_exit_cfg: dict[str, int] | None = None,
1548
+ fast_pass: bool = False,
1549
+ ) -> tuple[dict[str, Tensor], dict[str, Any] | None]:
1550
+ """Apply the attention to the tokens and masks."""
1551
+ tokens_only_dict, original_masks_dict, modalities_to_dims_dict = (
1552
+ self.split_tokens_masks_and_dims(x)
1553
+ )
1554
+ # already a no-op here, but we could remove it entirely
1555
+ exit_ids_seq = self.create_exit_seqs(
1556
+ tokens_only_dict, original_masks_dict, token_exit_cfg
1557
+ )
1558
+ # exited tokens are just the linear projection
1559
+ exited_tokens, _ = self.collapse_and_combine_hwtc(x)
1560
+
1561
+ tokens_dict = self.composite_encodings.forward(
1562
+ tokens_only_dict,
1563
+ timestamps,
1564
+ patch_size,
1565
+ input_res,
1566
+ )
1567
+ tokens_dict.update(original_masks_dict)
1568
+ tokens, mask = self.collapse_and_combine_hwtc(tokens_dict)
1569
+
1570
+ tokens, indices, new_mask, seq_lengths, max_seqlen, bool_mask = (
1571
+ self._maybe_remove_masked_tokens(tokens, mask, fast_pass)
1572
+ )
1573
+
1574
+ if exit_ids_seq is not None:
1575
+ exit_ids_seq, _, _, _, _ = self.remove_masked_tokens(
1576
+ exit_ids_seq, bool_mask
1577
+ )
1578
+ # still linear projections
1579
+ exited_tokens, _, _, _, _ = self.remove_masked_tokens(
1580
+ exited_tokens, bool_mask
1581
+ )
1582
+
1583
+ # Pack x tokens
1584
+ if self.use_flash_attn:
1585
+ cu_seqlens = get_cumulative_sequence_lengths(seq_lengths)
1586
+ og_shape = tokens.shape
1587
+ tokens = self.pack_tokens(tokens, new_mask)
1588
+ else:
1589
+ cu_seqlens = None
1590
+
1591
+ attn_mask = self._maybe_get_attn_mask(
1592
+ new_mask,
1593
+ fast_pass=fast_pass,
1594
+ )
1595
+
1596
+ if self.has_register_tokens:
1597
+ tokens, attn_mask = self.add_register_tokens_and_masks(tokens, attn_mask)
1598
+
1599
+ # Apply attn with varying encoder depths
1600
+ for i_blk, blk in enumerate(self.blocks):
1601
+ # Skip the zeroth block: layer-0 exits keep the exited tokens, which don't have the shared encodings added, since otherwise predicting those shared encodings would be a trivial solution
1602
+ if (exit_ids_seq is not None) and (i_blk > 0):
1603
+ # this should only ever be called by the target encoder,
1604
+ # in a torch.no_grad context
1605
+ assert exited_tokens is not None
1606
+ # If a token should exit, then we update the exit token with the current token at the same position
1607
+ exited_tokens = torch.where(
1608
+ condition=(exit_ids_seq == i_blk),
1609
+ input=tokens,
1610
+ other=exited_tokens,
1611
+ )
1612
+ # we take the inverse of the mask because a value
1613
+ # of True indicates the value *should* take part in
1614
+ # attention
1615
+ # WARNING: THIS MAY CHANGE DEPENDING ON THE ATTENTION IMPLEMENTATION
1616
+
1617
+ tokens = blk(
1618
+ x=tokens,
1619
+ cu_seqlens=cu_seqlens,
1620
+ max_seqlen=max_seqlen,
1621
+ # we will have to specify k and q lens for cross attention
1622
+ attn_mask=attn_mask,
1623
+ )
1624
+
1625
+ if self.has_register_tokens:
1626
+ tokens, register_tokens = self.pop_register_tokens(tokens)
1627
+ token_norm_stats = (
1628
+ self.get_token_norm_stats(tokens, register_tokens)
1629
+ if self.log_token_norm_stats
1630
+ else None
1631
+ )
1632
+ else:
1633
+ token_norm_stats = None
1634
+
1635
+ if self.use_flash_attn:
1636
+ tokens = self.unpack_tokens(tokens, new_mask, og_shape)
1637
+
1638
+ if exit_ids_seq is not None:
1639
+ # this should only ever be called by the target encoder,
1640
+ # in a torch.no_grad context
1641
+ assert exited_tokens is not None
1642
+ # full depth
1643
+ # IMPORTANT: write this back into tokens
1644
+ tokens = torch.where(
1645
+ condition=(exit_ids_seq == (i_blk + 1)), # i_blk + 1 == depth, i.e. full-depth exit
1646
+ input=tokens,
1647
+ other=exited_tokens,
1648
+ )
1649
+ # we apply the norm before we add the removed tokens,
1650
+ # so that the norm is only computed against "real" tokens
1651
+ tokens = self.norm(tokens)
1652
+ # we don't care about the mask returned by add_removed_tokens, since we will
1653
+ # just use the original, unclipped mask here
1654
+ tokens = self._maybe_add_removed_tokens(tokens, indices, new_mask, fast_pass)
1655
+
1656
+ tokens_per_modality_dict = self.split_and_expand_per_modality(
1657
+ tokens, modalities_to_dims_dict
1658
+ )
1659
+ # merge original masks and the processed tokens
1660
+ tokens_per_modality_dict.update(original_masks_dict)
1661
+ return tokens_per_modality_dict, token_norm_stats
1662
+
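# --- Editor-added illustrative sketch; not part of the packaged module. ---
# The per-token early-exit mechanism in apply_attn above: exit_ids stores, for every
# token, the block index at which its representation should be frozen, and torch.where
# copies the current tokens into the "exited" buffer wherever that index matches.
# Blocks, shapes, and the helper name are hypothetical stand-ins.
import torch
from torch import nn

def _sketch_token_exit(
    tokens: torch.Tensor, exit_ids: torch.Tensor, blocks: nn.ModuleList
) -> torch.Tensor:
    exited = tokens.clone()                                 # layer-0 (projection-only) fallback
    for i, blk in enumerate(blocks):
        if i > 0:
            exited = torch.where(exit_ids == i, tokens, exited)
        tokens = blk(tokens)
    # Tokens configured to exit at full depth take the output of the final block.
    return torch.where(exit_ids == len(blocks), tokens, exited)

# e.g. blocks = nn.ModuleList(nn.Identity() for _ in range(3))
#      _sketch_token_exit(torch.randn(1, 4, 8), torch.full((1, 4, 8), 3), blocks)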
1663
+ def forward(
1664
+ self,
1665
+ x: MaskedOlmoEarthSample,
1666
+ patch_size: int,
1667
+ input_res: int = BASE_GSD,
1668
+ token_exit_cfg: dict | None = None,
1669
+ fast_pass: bool = False,
1670
+ ) -> dict[str, Any]:
1671
+ """Process masked input samples into token representations.
1672
+
1673
+ Args:
1674
+ x: Masked input sample containing the data to be encoded
1675
+ patch_size: Size of patches to divide the input into
1676
+ input_res: Resolution of the input data
1677
+ token_exit_cfg: Configuration for token exit
1678
+ fast_pass: Whether to always pass None as the mask to the transformer; this enables torch-based flash attention and skips mask construction and sorting
1679
+
1680
+ Returns:
1681
+ A dict with "tokens_and_masks" (the encoded representations and their masks) and, optionally, "token_norm_stats" and "project_aggregated"
1682
+ """
1683
+ if fast_pass and token_exit_cfg is not None:
1684
+ raise ValueError("token_exit_cfg cannot be set when fast_pass is True")
1685
+
1686
+ patchified_tokens_and_masks = self.patch_embeddings.forward(
1687
+ x, patch_size, fast_pass=fast_pass
1688
+ )
1689
+ if token_exit_cfg is None or any(
1690
+ [exit_depth > 0 for exit_depth in token_exit_cfg.values()]
1691
+ ):
1692
+ patchified_tokens_and_masks, token_norm_stats = self.apply_attn(
1693
+ x=patchified_tokens_and_masks,
1694
+ timestamps=x.timestamps,
1695
+ patch_size=patch_size,
1696
+ input_res=input_res,
1697
+ token_exit_cfg=token_exit_cfg,
1698
+ fast_pass=fast_pass,
1699
+ )
1700
+ else:
1701
+ token_norm_stats = {}
1702
+ output = TokensAndMasks(**patchified_tokens_and_masks)
1703
+ output_dict: dict[str, Any] = {
1704
+ "tokens_and_masks": output,
1705
+ }
1706
+ if token_norm_stats:
1707
+ output_dict["token_norm_stats"] = token_norm_stats
1708
+
1709
+ if not fast_pass:
1710
+ output_dict["project_aggregated"] = self.project_and_aggregate(output)
1711
+ return output_dict
1712
+
1713
+ def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
1714
+ """Apply FSDP to the model."""
1715
+ super().apply_fsdp(**fsdp_kwargs)
1716
+ # Don't Shard the small layers
1717
+ # fully_shard(self.patch_embeddings, **fsdp_kwargs)
1718
+ # register_fsdp_forward_method(self.patch_embeddings, "forward")
1719
+ # fully_shard(self.project_and_aggregate, **fsdp_kwargs)
1720
+ # register_fsdp_forward_method(self.project_and_aggregate, "forward")
1721
+ fully_shard(self, **fsdp_kwargs)
1722
+
1723
+ def apply_compile(self) -> None:
1724
+ """Apply torch.compile to the model."""
1725
+ # self.compile(mode="max-autotune", dynamic=False, fullgraph=True)
1726
+ logger.info("Compiling blocks")
1727
+ # torch.compile(self.blocks, dynamic=False, mode="max-autotune", fullgraph=True)
1728
+ # individual block compile is still a lot slower
1729
+ for block in self.blocks:
1730
+ block.apply_compile()
1731
+ # torch.compile(self.patch_embeddings, dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)
1732
+
1733
+
1734
+ class PredictorBase(FlexiVitBase):
1735
+ """Predictor module that generates predictions from encoded tokens."""
1736
+
1737
+ cross_attn = True
1738
+
1739
+ def __init__(
1740
+ self,
1741
+ supported_modalities: list[ModalitySpec],
1742
+ encoder_embedding_size: int = 128,
1743
+ decoder_embedding_size: int = 128,
1744
+ depth: int = 2,
1745
+ mlp_ratio: float = 2.0,
1746
+ num_heads: int = 8,
1747
+ max_sequence_length: int = 24,
1748
+ drop_path: float = 0.0,
1749
+ learnable_channel_embeddings: bool = True,
1750
+ random_channel_embeddings: bool = False,
1751
+ output_embedding_size: int | None = None,
1752
+ use_flash_attn: bool = False,
1753
+ qk_norm: bool = False,
1754
+ tokenization_config: TokenizationConfig | None = None,
1755
+ ):
1756
+ """Initialize the predictor.
1757
+
1758
+ Args:
1759
+ supported_modalities: modalities this model instantiation supports
1760
+ encoder_embedding_size: Size of encoder embeddings
1761
+ decoder_embedding_size: Size of decoder embeddings
1762
+ depth: Number of transformer layers
1763
+ mlp_ratio: Ratio for MLP hidden dimension
1764
+ num_heads: Number of attention heads
1765
+ max_sequence_length: Maximum sequence length
1766
+ drop_path: Drop path rate
1767
+ learnable_channel_embeddings: Whether to use learnable channel embeddings
1768
+ random_channel_embeddings: Whether to randomly initialize channel embeddings
1769
+ output_embedding_size: Size of output embeddings
1770
+ use_flash_attn: Whether to use flash attention
1771
+ qk_norm: Whether to apply normalization to Q and K in attention
1772
+ tokenization_config: Optional config for custom band groupings
1773
+ """
1774
+ self.tokenization_config = tokenization_config or TokenizationConfig()
1775
+ super().__init__(
1776
+ embedding_size=decoder_embedding_size,
1777
+ depth=depth,
1778
+ mlp_ratio=mlp_ratio,
1779
+ num_heads=num_heads,
1780
+ max_sequence_length=max_sequence_length,
1781
+ drop_path=drop_path,
1782
+ learnable_channel_embeddings=learnable_channel_embeddings,
1783
+ random_channel_embeddings=random_channel_embeddings,
1784
+ supported_modalities=supported_modalities,
1785
+ use_flash_attn=use_flash_attn,
1786
+ qk_norm=qk_norm,
1787
+ tokenization_config=self.tokenization_config,
1788
+ )
1789
+ self.learnable_channel_embeddings = learnable_channel_embeddings
1790
+ self.random_channel_embeddings = random_channel_embeddings
1791
+ self.encoder_embedding_size = encoder_embedding_size
1792
+ self.encoder_to_decoder_embed = nn.Linear(
1793
+ encoder_embedding_size, decoder_embedding_size, bias=True
1794
+ )
1795
+ if output_embedding_size is None:
1796
+ output_embedding_size = encoder_embedding_size
1797
+ self.output_embedding_size = output_embedding_size
1798
+ self.to_output_embed = nn.Linear(
1799
+ decoder_embedding_size, output_embedding_size, bias=True
1800
+ )
1801
+ # THIS is the learnable mask token
1802
+ self.mask_token = nn.Parameter(torch.zeros(decoder_embedding_size))
1803
+
1804
+ self.input_norm = nn.LayerNorm(encoder_embedding_size)
1805
+ self.norm = nn.LayerNorm(decoder_embedding_size)
1806
+ self.apply(self._init_weights)
1807
+
1808
+ def add_masks(self, x: dict[str, Tensor]) -> dict[str, Tensor]:
1809
+ """Replace tokens that should be decoded (MaskValue.DECODER_ONLY) with the learnable mask token.
1810
+
1811
+ This is done in a dimension-agnostic way using einops. We assume the final dimension of each token tensor
1812
+ is the embedding dimension matching self.mask_token's size.
1813
+ """
1814
+ output_dict = {}
1815
+ available_modalities = return_modalities_from_dict(x)
1816
+ modalities_to_process = get_modalities_to_process(
1817
+ available_modalities, self.supported_modality_names
1818
+ )
1819
+ for modality in modalities_to_process:
1820
+ x_modality = x[modality]
1821
+ mask_name = MaskedOlmoEarthSample.get_masked_modality_name(modality)
1822
+ mask_modality = x[mask_name]
1823
+ # A boolean mask: True where tokens must be replaced by the mask token
1824
+ kept_mask = mask_modality == MaskValue.DECODER.value
1825
+
1826
+ # Build the einops pattern and dimension dict
1827
+ spatial_dims = x_modality.shape[
1828
+ :-1
1829
+ ] # all dimensions except the last (embedding)
1830
+ pattern_input, dim_dict = self._construct_einops_pattern(spatial_dims)
1831
+
1832
+ mask_token_broadcasted = repeat(self.mask_token, pattern_input, **dim_dict)
1833
+
1834
+ # Where kept_mask is True, use the broadcasted mask token
1835
+ x_modality = torch.where(
1836
+ kept_mask.unsqueeze(-1).bool(), mask_token_broadcasted, x_modality
1837
+ )
1838
+
1839
+ output_dict[modality] = x_modality
1840
+
1841
+ return output_dict
1842
+
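# --- Editor-added illustrative sketch; not part of the packaged module. ---
# The substitution performed by add_masks above: wherever the mask says "decode this
# token", the token embedding is swapped for a single learnable vector broadcast over
# the leading dimensions. Names and sizes below are hypothetical.
import torch
from torch import nn

def _sketch_add_masks(
    x: torch.Tensor, decode_mask: torch.Tensor, mask_token: nn.Parameter
) -> torch.Tensor:
    # x: [..., D], decode_mask: [...] boolean, mask_token: [D]
    broadcast = mask_token.expand_as(x)                     # a view, no data copy
    return torch.where(decode_mask.unsqueeze(-1), broadcast, x)

# e.g. _sketch_add_masks(torch.randn(2, 5, 8), torch.rand(2, 5) > 0.5, nn.Parameter(torch.zeros(8)))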
1843
+ # TODO: GIVE more explicit function names
1844
+ @staticmethod
1845
+ def split_x_y(tokens: Tensor, mask: Tensor) -> tuple[Tensor, ...]:
1846
+ """Splits tokens into three groups based on mask values.
1847
+
1848
+ This function:
1849
+ 1. Sorts tokens according to the mask and gathers them in order.
1850
+ 2. Chooses tokens to be decoded (x) based on the mask value DECODER.
1851
+ 3. Chooses tokens to be used as context (y) based on the mask value ONLINE_ENCODER.
1852
+ 4. Identifies missing tokens (z) based on the mask value MISSING.
1853
+ 5. Returns boolean masks for x, y, and z along with indices to revert to the original ordering.
1854
+
1855
+ Args:
1856
+ tokens: Tokens to split of shape [B, T, D].
1857
+ mask: Mask of shape [B, T].
1858
+
1859
+ Returns:
1860
+ tokens_to_decode: Tokens to be decoded of shape [B, X_len, D].
1861
+ unmasked_tokens: Tokens to be used as context of shape [B, Y_len, D].
1862
+ tokens_to_decode_mask: Binary mask for x tokens of shape [B, X_len].
1863
+ unmasked_tokens_mask: Binary mask for y tokens of shape [B, Y_len].
1864
+ indices: Indices for restoring the original token ordering of shape [B, T].
1865
+ seqlens_tokens_to_decode: Sequence lengths of tokens to decode of shape [B].
1866
+ seqlens_unmasked_tokens: Sequence lengths of unmasked tokens of shape [B].
1867
+ max_length_of_decoded_tokens: Maximum length of decoded tokens of shape [1].
1868
+ max_length_of_unmasked_tokens: Maximum length of unmasked tokens of shape [1].
1869
+ """
1870
+ # Set Missing Masks to Target Encoder ONLY so that we can have all unused tokens in the middle
1871
+ org_mask_dtype = mask.dtype
1872
+ missing_mask = mask == MaskValue.MISSING.value
1873
+ mask[missing_mask] = MaskValue.TARGET_ENCODER_ONLY.value
1874
+
1875
+ # Sort tokens by mask value (descending order)
1876
+ sorted_mask, indices = torch.sort(
1877
+ mask.int(), dim=1, descending=True, stable=True
1878
+ )
1879
+ tokens = tokens.gather(1, indices[:, :, None].expand_as(tokens))
1880
+
1881
+ # Create binary masks for Encoder and Decoder
1882
+ binarized_decoder_mask = sorted_mask == MaskValue.DECODER.value
1883
+ binarized_online_encoder_mask = sorted_mask == MaskValue.ONLINE_ENCODER.value
1884
+
1885
+ seqlens_unmasked_tokens = binarized_online_encoder_mask.sum(dim=-1)
1886
+ max_length_of_unmasked_tokens = seqlens_unmasked_tokens.max()
1887
+ seqlens_tokens_to_decode = binarized_decoder_mask.sum(dim=-1)
1888
+ max_length_of_decoded_tokens = seqlens_tokens_to_decode.max()
1889
+
1890
+ # the y mask is going to be used to determine which of the y values to take. True values
1891
+ # take part in the attention (we don't take the inverse here, unlike in the decoder)
1892
+ tokens_to_decode = tokens[:, :max_length_of_decoded_tokens]
1893
+ tokens_to_decode_mask = binarized_decoder_mask[
1894
+ :, :max_length_of_decoded_tokens
1895
+ ].to(org_mask_dtype)
1896
+
1897
+ unmasked_tokens = tokens[:, -max_length_of_unmasked_tokens:]
1898
+ # the x_mask is just going to be used in the reconstruction, to know which
1899
+ # x tokens to add back into the token list. TODO is this even necessary? it could
1900
+ # get padded with noise tokens since we don't care about reconstruction at all
1901
+ # for a whole bunch of tokens
1902
+ unmasked_tokens_mask = binarized_online_encoder_mask[
1903
+ :, -max_length_of_unmasked_tokens:
1904
+ ].to(org_mask_dtype)
1905
+
1906
+ return (
1907
+ tokens_to_decode,
1908
+ unmasked_tokens,
1909
+ tokens_to_decode_mask,
1910
+ unmasked_tokens_mask,
1911
+ indices,
1912
+ seqlens_tokens_to_decode,
1913
+ seqlens_unmasked_tokens,
1914
+ max_length_of_decoded_tokens,
1915
+ max_length_of_unmasked_tokens,
1916
+ )
1917
+
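# --- Editor-added illustrative sketch; not part of the packaged module. ---
# The grouping trick in split_x_y above: one stable sort on the integer mask values
# pushes tokens-to-decode to the front of each row and context tokens to the back, so
# both groups can be sliced out contiguously. The mask values below are stand-ins and
# assume, as in the module, that the decoder value sorts above the encoder value.
import torch

def _sketch_split_x_y(
    tokens: torch.Tensor, mask: torch.Tensor, decoder_val: int = 2, encoder_val: int = 0
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    sorted_mask, idx = torch.sort(mask.int(), dim=1, descending=True, stable=True)
    tokens = tokens.gather(1, idx[:, :, None].expand_as(tokens))
    dec_len = int((sorted_mask == decoder_val).sum(dim=-1).max())
    enc_len = int((sorted_mask == encoder_val).sum(dim=-1).max())
    # Front slice: tokens to decode; back slice: unmasked context tokens.
    return tokens[:, :dec_len], tokens[:, -enc_len:], idx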
1918
+ @staticmethod
1919
+ def combine_x_y(
1920
+ tokens_to_decode: Tensor,
1921
+ unmasked_tokens: Tensor,
1922
+ tokens_to_decode_mask: Tensor,
1923
+ unmasked_tokens_mask: Tensor,
1924
+ indices: Tensor,
1925
+ ) -> Tensor:
1926
+ """Reintegrate the separated token sequences into their original order.
1927
+
1928
+ The token masks zero out positions which are not used/needed,
1929
+ and the final scatter step re-applies the original ordering tracked in 'indices'.
1930
+
1931
+ Args:
1932
+ tokens_to_decode: Query tokens (to be decoded) of shape [B, X_len, D].
1933
+ unmasked_tokens: Context (key/value) tokens of shape [B, Y_len, D].
1934
+ tokens_to_decode_mask: Binary mask for tokens to decode of shape [B, X_len].
1935
+ unmasked_tokens_mask: Binary mask for unmasked tokens of shape [B, Y_len].
1936
+ indices: Indices for restoring the original token ordering of shape [B, T].
1937
+
1938
+ Returns:
1939
+ A merged tokens tensor of shape [B, T, D] with all tokens in their
1940
+ original positions.
1941
+ """
1942
+ # Get dimensions
1943
+ B, T = indices.shape[0], indices.shape[1]
1944
+ D = tokens_to_decode.shape[-1]
1945
+ tokens = torch.zeros(
1946
+ (B, T, D), dtype=tokens_to_decode.dtype, device=tokens_to_decode.device
1947
+ )
1948
+ tokens[:, -unmasked_tokens.shape[1] :] = (
1949
+ unmasked_tokens * unmasked_tokens_mask.unsqueeze(-1)
1950
+ )
1951
+ tokens[:, : tokens_to_decode.shape[1]] += (
1952
+ tokens_to_decode * tokens_to_decode_mask.unsqueeze(-1)
1953
+ )
1954
+ tokens = tokens.scatter(1, indices[:, :, None].expand_as(tokens), tokens)
1955
+ return tokens
1956
+
1957
+ def is_any_data_to_be_decoded(self, modality_mask: Tensor) -> bool:
1958
+ """Check if any data is to be decoded for a given modality."""
1959
+ return (MaskValue.DECODER.value == modality_mask).any()
1960
+
1961
+ def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
1962
+ """Apply FSDP to the model."""
1963
+ super().apply_fsdp(**fsdp_kwargs)
1964
+ fully_shard(self, **fsdp_kwargs)
1965
+
1966
+
1967
+ class Predictor(PredictorBase):
1968
+ """Predictor module that generates predictions from encoded tokens."""
1969
+
1970
+ cross_attn = True
1971
+
1972
+ def apply_attn(
1973
+ self,
1974
+ x: dict[str, Tensor],
1975
+ timestamps: Tensor,
1976
+ patch_size: int,
1977
+ input_res: int,
1978
+ ) -> dict[str, Tensor]:
1979
+ """Apply attention to the tokens."""
1980
+ tokens_only_dict, original_masks_dict, modalities_to_dims_dict = (
1981
+ self.split_tokens_masks_and_dims(x)
1982
+ )
1983
+ tokens_dict = self.composite_encodings(
1984
+ tokens_only_dict, timestamps, patch_size, input_res
1985
+ )
1986
+ tokens_dict.update(original_masks_dict)
1987
+ all_tokens, mask = self.collapse_and_combine_hwtc(tokens_dict)
1988
+ # X contains the tokens to decode, Y contains the tokens to attend to for context
1989
+ (
1990
+ tokens_to_decode,
1991
+ unmasked_tokens,
1992
+ tokens_to_decode_mask,
1993
+ unmasked_tokens_mask,
1994
+ indices,
1995
+ seqlens_tokens_to_decode,
1996
+ seqlens_unmasked_tokens,
1997
+ max_length_of_tokens_to_decode,
1998
+ max_length_of_unmasked_tokens,
1999
+ ) = self.split_x_y(all_tokens, mask)
2000
+ # Pack x tokens
2001
+ if self.use_flash_attn:
2002
+ og_shape_tokens_to_decode = tokens_to_decode.shape
2003
+ tokens_to_decode = self.pack_tokens(
2004
+ tokens_to_decode, tokens_to_decode_mask.bool()
2005
+ )
2006
+ og_shape_unmasked_tokens = unmasked_tokens.shape
2007
+ unmasked_tokens = self.pack_tokens(
2008
+ unmasked_tokens, unmasked_tokens_mask.bool()
2009
+ )
2010
+ cu_seqlens_tokens_to_decode = get_cumulative_sequence_lengths(
2011
+ seqlens_tokens_to_decode
2012
+ )
2013
+ cu_seqlens_unmasked_tokens = get_cumulative_sequence_lengths(
2014
+ seqlens_unmasked_tokens
2015
+ )
2016
+ else:
2017
+ cu_seqlens_tokens_to_decode = None
2018
+ cu_seqlens_unmasked_tokens = None
2019
+
2020
+ for blk in self.blocks:
2021
+ # note that we are not taking the inverse of the mask, since split_x_y gives us
2022
+ # true values for values we want to take part in attention
2023
+ tokens_to_decode = blk(
2024
+ x=tokens_to_decode,
2025
+ y=unmasked_tokens,
2026
+ attn_mask=(
2027
+ unmasked_tokens_mask.bool() if not self.use_flash_attn else None
2028
+ ), # attn_mask is only used on the non-flash-attention path, though this should not be left in
2029
+ cu_seqlens_q=cu_seqlens_tokens_to_decode,
2030
+ cu_seqlens_k=cu_seqlens_unmasked_tokens,
2031
+ max_seqlen_q=max_length_of_tokens_to_decode,
2032
+ max_seqlen_k=max_length_of_unmasked_tokens,
2033
+ )
2034
+
2035
+ if self.use_flash_attn:
2036
+ tokens_to_decode = self.unpack_tokens(
2037
+ tokens_to_decode,
2038
+ tokens_to_decode_mask.bool(),
2039
+ og_shape_tokens_to_decode,
2040
+ )
2041
+ unmasked_tokens = self.unpack_tokens(
2042
+ unmasked_tokens, unmasked_tokens_mask.bool(), og_shape_unmasked_tokens
2043
+ )
2044
+
2045
+ x = self.combine_x_y(
2046
+ tokens_to_decode=tokens_to_decode,
2047
+ unmasked_tokens=unmasked_tokens,
2048
+ tokens_to_decode_mask=tokens_to_decode_mask,
2049
+ unmasked_tokens_mask=unmasked_tokens_mask,
2050
+ indices=indices,
2051
+ )
2052
+ tokens_per_modality_dict = self.split_and_expand_per_modality(
2053
+ x, modalities_to_dims_dict
2054
+ )
2055
+ tokens_per_modality_dict.update(original_masks_dict)
2056
+ return tokens_per_modality_dict
2057
+
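# --- Editor-added illustrative sketch; not part of the packaged module. ---
# The variable-length packing used on the flash-attention path above: per-row sequence
# lengths become cumulative offsets with a leading zero (the cu_seqlens convention of
# flash-attention's varlen kernels), and the valid tokens of all rows are packed into a
# single [total_tokens, D] tensor. This is a stand-alone stand-in, not the package's
# pack_tokens / get_cumulative_sequence_lengths implementation.
import torch

def _sketch_pack_for_varlen(
    tokens: torch.Tensor, keep_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, int]:
    # tokens: [B, T, D]; keep_mask: [B, T] boolean, True = real token
    seqlens = keep_mask.sum(dim=-1)                          # [B]
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    packed = tokens[keep_mask]                               # [sum(seqlens), D]
    return packed, cu_seqlens, int(seqlens.max())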
2058
+ def forward(
2059
+ self,
2060
+ x: TokensAndMasks,
2061
+ timestamps: Tensor,
2062
+ patch_size: int,
2063
+ input_res: int = BASE_GSD,
2064
+ ) -> TokensAndMasks:
2065
+ """Generate predictions from encoded token representations.
2066
+
2067
+ Args:
2068
+ x: TokensAndMasks containing the encoded tokens to make predictions from
2069
+ timestamps: Timestamps of the tokens
2070
+ patch_size: Patch size of the tokens
2071
+ input_res: Input resolution of the tokens
2072
+
2073
+ Returns:
2074
+ TokensAndMasks containing the predicted tokens and their masks
2075
+ """
2076
+ decoder_embedded_dict = x.as_dict(return_none=False)
2077
+ # Apply Input Norms and encoder to decoder embeds to each modality
2078
+ available_modalities = x.modalities
2079
+ modalities_to_process = get_modalities_to_process(
2080
+ available_modalities, self.supported_modality_names
2081
+ )
2082
+ for modality in modalities_to_process:
2083
+ x_modality = getattr(x, modality)
2084
+ # Although we do not account for missing tokens, both the projection and the norm act on the token dimension, so there is no mixing with real tokens
2085
+ x_modality = self.input_norm(x_modality)
2086
+ x_modality = self.encoder_to_decoder_embed(x_modality)
2087
+ masked_modality_name = x.get_masked_modality_name(modality)
2088
+ decoder_embedded_dict[modality] = x_modality
2089
+ decoder_embedded_dict[masked_modality_name] = getattr(
2090
+ x, masked_modality_name
2091
+ )
2092
+
2093
+ tokens_only_dict = self.add_masks(decoder_embedded_dict)
2094
+ decoder_embedded_dict.update(tokens_only_dict)
2095
+ tokens_and_masks = self.apply_attn(
2096
+ decoder_embedded_dict, timestamps, patch_size, input_res
2097
+ )
2098
+ # TODO: Factor this out into a more readable function
2099
+ output_dict = {}
2100
+ available_modalities = return_modalities_from_dict(tokens_and_masks)
2101
+ modalities_to_process = get_modalities_to_process(
2102
+ available_modalities, self.supported_modality_names
2103
+ )
2104
+ for modality in modalities_to_process:
2105
+ masked_modality_name = MaskedOlmoEarthSample.get_masked_modality_name(
2106
+ modality
2107
+ )
2108
+ modality_mask = tokens_and_masks[masked_modality_name]
2109
+ # patchify masked data
2110
+ per_modality_output_tokens = []
2111
+ modality_data = tokens_and_masks[modality]
2112
+
2113
+ num_band_sets = self.tokenization_config.get_num_bandsets(modality)
2114
+ for idx in range(num_band_sets):
2115
+ per_channel_modality_data = modality_data[..., idx, :]
2116
+ output_data = self.to_output_embed(self.norm(per_channel_modality_data))
2117
+ per_modality_output_tokens.append(output_data)
2118
+ output_dict[modality] = torch.stack(per_modality_output_tokens, dim=-2)
2119
+ output_dict[masked_modality_name] = modality_mask
2120
+ return TokensAndMasks(**output_dict)
2121
+
2122
+
2123
+ @dataclass
2124
+ class EncoderConfig(Config):
2125
+ """Configuration for the Encoder."""
2126
+
2127
+ supported_modality_names: list[str]
2128
+
2129
+ embedding_size: int = 16
2130
+ # This is the base patch size for the patch embedder
2131
+ max_patch_size: int = 8
2132
+ min_patch_size: int = 1
2133
+ num_heads: int = 2
2134
+ mlp_ratio: float = 1.0
2135
+ depth: int = 2
2136
+ drop_path: float = 0.1
2137
+ max_sequence_length: int = 12
2138
+ num_register_tokens: int = 0
2139
+ learnable_channel_embeddings: bool = True
2140
+ random_channel_embeddings: bool = False
2141
+ num_projection_layers: int = 1
2142
+ aggregate_then_project: bool = True
2143
+ use_flash_attn: bool = False
2144
+ frozen_patch_embeddings: bool = False
2145
+ qk_norm: bool = False
2146
+ log_token_norm_stats: bool = False
2147
+ tokenization_config: TokenizationConfig | None = None
2148
+
2149
+ def validate(self) -> None:
2150
+ """Validate the configuration."""
2151
+ if len(self.supported_modalities) == 0:
2152
+ raise ValueError("At least one modality must be added!")
2153
+ else:
2154
+ for modality in self.supported_modalities:
2155
+ if modality not in Modality.values():
2156
+ raise ValueError(f"Modality {modality} is not supported")
2157
+ if self.tokenization_config is not None:
2158
+ self.tokenization_config.validate()
2159
+
2160
+ @property
2161
+ def supported_modalities(self) -> list[ModalitySpec]:
2162
+ """Get the supported modalities."""
2163
+ return get_modality_specs_from_names(self.supported_modality_names)
2164
+
2165
+ def build(self) -> "Encoder":
2166
+ """Build the encoder."""
2167
+ self.validate()
2168
+ kwargs = self.as_dict(exclude_none=True, recurse=False)
2169
+ # supported_modality_names is replaced by supported_modalities
2170
+ kwargs.pop("supported_modality_names")
2171
+ kwargs["supported_modalities"] = self.supported_modalities
2172
+ logger.info(f"Encoder kwargs: {kwargs}")
2173
+ return Encoder(**kwargs)
2174
+
2175
+
2176
+ @dataclass
2177
+ class PredictorConfig(Config):
2178
+ """Configuration for the Predictor."""
2179
+
2180
+ supported_modality_names: list[str]
2181
+ encoder_embedding_size: int = 16
2182
+ decoder_embedding_size: int = 16
2183
+ depth: int = 2
2184
+ mlp_ratio: float = 1.0
2185
+ num_heads: int = 2
2186
+ max_sequence_length: int = 12
2187
+ drop_path: float = 0.0
2188
+ learnable_channel_embeddings: bool = True
2189
+ random_channel_embeddings: bool = False
2190
+ output_embedding_size: int | None = None
2191
+ use_flash_attn: bool = False
2192
+ qk_norm: bool = False
2193
+ tokenization_config: TokenizationConfig | None = None
2194
+
2195
+ def validate(self) -> None:
2196
+ """Validate the configuration."""
2197
+ if len(self.supported_modalities) == 0:
2198
+ raise ValueError("At least one modality must be added!")
2199
+ else:
2200
+ for modality in self.supported_modalities:
2201
+ if modality not in Modality.values():
2202
+ raise ValueError(f"Modality {modality} is not supported")
2203
+ if self.tokenization_config is not None:
2204
+ self.tokenization_config.validate()
2205
+
2206
+ @property
2207
+ def supported_modalities(self) -> list[ModalitySpec]:
2208
+ """Get the supported modalities."""
2209
+ return get_modality_specs_from_names(self.supported_modality_names)
2210
+
2211
+ def build(self) -> "PredictorBase":
2212
+ """Build the predictor."""
2213
+ self.validate()
2214
+ kwargs = self.as_dict(exclude_none=True, recurse=False)
2215
+ # supported_modality_names is replaced by supported_modalities
2216
+ kwargs.pop("supported_modality_names")
2217
+ kwargs["supported_modalities"] = self.supported_modalities
2218
+ logger.info(f"Predictor kwargs: {kwargs}")
2219
+ return Predictor(**kwargs)
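# --- Editor-added illustrative sketch; not part of the packaged module. ---
# Typical use of the config dataclasses above: build() validates the config, swaps
# supported_modality_names for resolved ModalitySpec objects, and constructs the model.
# Which modality names are valid depends on the installed package, so this sketch picks
# one from Modality.values() rather than hard-coding a guess; whether the tiny default
# sizes build successfully also depends on the installed package.
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import (
    EncoderConfig,
    PredictorConfig,
)
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality

def _sketch_build_encoder_and_predictor() -> None:
    names = [list(Modality.values())[0]]     # any supported modality name(s)
    encoder = EncoderConfig(supported_modality_names=names).build()
    predictor = PredictorConfig(supported_modality_names=names).build()
    print(type(encoder).__name__, type(predictor).__name__)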