olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- olmoearth_pretrain_minimal/__init__.py +16 -0
- olmoearth_pretrain_minimal/model_loader.py +123 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
- olmoearth_pretrain_minimal/test.py +51 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2219 @@
|
|
|
1
|
+
"""Model code for the OlmoEarth Pretrain model."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import math
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from enum import StrEnum
|
|
7
|
+
from typing import Any, NamedTuple
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from einops import rearrange, reduce, repeat
|
|
11
|
+
from torch import Tensor, nn
|
|
12
|
+
from torch.distributed.fsdp import fully_shard
|
|
13
|
+
|
|
14
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
|
|
15
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import (
|
|
16
|
+
BASE_GSD,
|
|
17
|
+
Modality,
|
|
18
|
+
ModalitySpec,
|
|
19
|
+
get_modality_specs_from_names,
|
|
20
|
+
)
|
|
21
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import MaskedOlmoEarthSample, MaskValue
|
|
22
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.attention import Block
|
|
23
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.encodings import (
|
|
24
|
+
get_1d_sincos_pos_encoding,
|
|
25
|
+
get_2d_sincos_pos_encoding_with_resolution,
|
|
26
|
+
get_month_encoding_table,
|
|
27
|
+
)
|
|
28
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_patch_embed import (
|
|
29
|
+
FlexiPatchEmbed,
|
|
30
|
+
FlexiPatchReconstruction,
|
|
31
|
+
)
|
|
32
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.tokenization import TokenizationConfig
|
|
33
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.utils import get_cumulative_sequence_lengths
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_modalities_to_process(
    available_modalities: list[str], supported_modality_names: list[str]
) -> list[str]:
    """Get the modalities to process.

    Returns the modalities that are both supported and available, in the order
    in which they appear in ``supported_modality_names`` (duplicates removed).

    The order is deliberately deterministic: iterating a ``set`` of strings
    depends on per-process hash randomization. The previous set-based result
    could therefore visit modalities in a different order on different runs or
    ranks.

    Args:
        available_modalities: Names of the modalities present in the input.
        supported_modality_names: Names of the modalities the caller supports.

    Returns:
        Supported modality names that are also available.
    """
    available = set(available_modalities)
    return [
        name
        for name in dict.fromkeys(supported_modality_names)
        if name in available
    ]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def return_modalities_from_dict(
    per_modality_input_tokens: dict[str, Tensor],
) -> list[str]:
    """List the modality keys of a per-modality token dictionary.

    Every key that does not end in ``"_mask"`` is treated as a modality name.
    Names are returned in the dictionary's iteration order.
    """
    modality_names = []
    for name in per_modality_input_tokens:
        if name.endswith("_mask"):
            # Companion mask entry, not a modality.
            continue
        modality_names.append(name)
    return modality_names
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class PoolingType(StrEnum):
|
|
58
|
+
"""Strategy for pooling the tokens."""
|
|
59
|
+
|
|
60
|
+
MAX = "max"
|
|
61
|
+
MEAN = "mean"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class TokensAndMasks(NamedTuple):
    """Per-modality tokens and masks; the output the loss is computed on.

    Every modality field ``<name>`` has a companion ``<name>_mask`` field. All
    fields default to ``None``, and a modality counts as present when its token
    field is not ``None``. Mask entries are compared against ``MaskValue``
    members (e.g. ``MaskValue.ONLINE_ENCODER``).

    Spatial modalities such as ``sentinel2_l2a``, ``sentinel1``, ``worldcover``
    and ``openstreetmap_raster`` hold tokens of shape
    ``(B, P_H, P_W, T, Band_Sets, D)``, with masks of shape
    ``(B, P_H, P_W, T, Band_Sets)``. ``latlon`` holds geographical-coordinate
    tokens, and ``latlon_mask`` marks which of them are masked/unmasked.

    NOTE(review): these shapes come from the original docstring and from the
    pooling code (which reduces over dims -3/-2). Confirm them for the remaining
    modalities (srtm, landsat, naip, naip_10, gse, cdl, worldpop, worldcereal,
    wri_canopy_height_map, era5_10).
    """

    sentinel2_l2a: Tensor | None = None
    sentinel2_l2a_mask: Tensor | None = None
    sentinel1: Tensor | None = None
    sentinel1_mask: Tensor | None = None
    worldcover: Tensor | None = None
    worldcover_mask: Tensor | None = None
    latlon: Tensor | None = None
    latlon_mask: Tensor | None = None
    openstreetmap_raster: Tensor | None = None
    openstreetmap_raster_mask: Tensor | None = None
    srtm: Tensor | None = None
    srtm_mask: Tensor | None = None
    landsat: Tensor | None = None
    landsat_mask: Tensor | None = None
    naip: Tensor | None = None
    naip_mask: Tensor | None = None
    naip_10: Tensor | None = None
    naip_10_mask: Tensor | None = None
    gse: Tensor | None = None
    gse_mask: Tensor | None = None
    cdl: Tensor | None = None
    cdl_mask: Tensor | None = None
    worldpop: Tensor | None = None
    worldpop_mask: Tensor | None = None
    worldcereal: Tensor | None = None
    worldcereal_mask: Tensor | None = None
    wri_canopy_height_map: Tensor | None = None
    wri_canopy_height_map_mask: Tensor | None = None
    era5_10: Tensor | None = None
    era5_10_mask: Tensor | None = None

    @property
    def device(self) -> torch.device:
        """Return the device of the first field that is set.

        ``sentinel2_l2a`` is checked first, then every field in declaration order.

        Raises:
            ValueError: If every field is ``None``.
        """
        if self.sentinel2_l2a is not None:
            return self.sentinel2_l2a.device
        for value in self:
            if value is not None:
                return value.device
        raise ValueError("No data to get device from")

    # TODO: It seems like we want a lot of our named tuples to have this functionality so we should probably create a utility base class for the named tuples and double subclass
    @classmethod
    def get_masked_modality_name(cls, modality: str) -> str:
        """Return the name of the mask field for ``modality`` (``<modality>_mask``)."""
        return f"{modality}_mask"

    def as_dict(self, return_none: bool = True) -> dict[str, Any]:
        """Convert the namedtuple to a dictionary.

        Args:
            return_none: If False, fields whose value is ``None`` are omitted.

        Returns:
            Mapping from field name to value, in field declaration order.
        """
        return {
            name: value
            for name, value in zip(self._fields, self)
            if return_none or value is not None
        }

    @property
    def modalities(self) -> list[str]:
        """Return the names of the non-mask fields that are set."""
        return [
            x
            for x in self._fields
            if not x.endswith("mask") and getattr(self, x) is not None
        ]

    def get_shape_dict(self) -> dict[str, tuple]:
        """Return the shapes of all fields that are set.

        ``None`` fields are skipped. Previously ``.shape`` was read from every
        field, so this raised ``AttributeError`` for any partially populated
        instance.
        """
        return {
            name: value.shape
            for name, value in zip(self._fields, self)
            if value is not None
        }

    @staticmethod
    def _flatten(x: Tensor) -> Tensor:
        """Merge all dims between the first and the last: ``b ... d -> b (...) d``."""
        return rearrange(x, "b ... d -> b (...) d")

    def flatten_tokens_and_masks(
        self, return_lists: bool = False
    ) -> tuple[Tensor, Tensor] | tuple[list[Tensor], list[Tensor]]:
        """Return the flattened tokens and masks.

        Args:
            return_lists: If True, return the per-modality lists before
                concatenation. If False, return concatenated tensors.

        Tokens have shape ``[B, T, D]`` and masks have shape ``[B, T]``, where
        ``T`` is the total number of tokens across the present modalities.

        Raises:
            ValueError: If a modality is present but its mask is ``None``.
        """
        flattened_x, flattened_masks = [], []
        for attr_name in self.modalities:
            mask_attr_name = self.get_masked_modality_name(attr_name)
            masked_attr = getattr(self, mask_attr_name)
            if masked_attr is None:
                raise ValueError(
                    f"Can't have present {attr_name} but None {mask_attr_name}"
                )
            flattened_x.append(self._flatten(getattr(self, attr_name)))
            # Add a trailing singleton dim so the mask can reuse the
            # "b ... d" flattening; it is dropped again below.
            flattened_masks.append(self._flatten(masked_attr.unsqueeze(dim=-1)))

        if return_lists:
            # Remove the extra dimension from the masks
            flattened_masks = [mask[:, :, 0] for mask in flattened_masks]
            return flattened_x, flattened_masks

        x = torch.cat(flattened_x, dim=1)
        masks = torch.cat(flattened_masks, dim=1)[:, :, 0]
        return x, masks

    def _fully_visible_spatial_tokens(self) -> list[Tensor]:
        """Return tokens of spatial modalities that the online encoder sees entirely.

        A spatial modality is skipped if its mask is ``None`` or if any of its
        mask entries differs from ``MaskValue.ONLINE_ENCODER``.
        """
        visible = []
        for attr_name in self.modalities:
            if not Modality.get(attr_name).is_spatial:
                continue
            masked_attr = getattr(self, self.get_masked_modality_name(attr_name))
            if masked_attr is None:
                continue
            if (masked_attr == MaskValue.ONLINE_ENCODER.value).all():
                visible.append(getattr(self, attr_name))
        return visible

    def pool_spatially_and_concat_modalities(self) -> Tensor:
        """Pool fully visible spatial modalities over time and concatenate them.

        Each modality is averaged over the time axis (dim -3). The results are
        concatenated along the band-set axis (dim -2).

        Raises:
            ValueError: If no spatial modality is fully visible to the encoder.
        """
        spatial_tokens = self._fully_visible_spatial_tokens()
        if len(spatial_tokens) == 0:
            raise ValueError("Missing unmasked spatial modalities for spatial pooling.")
        # only mean in temporal dimension, then concatenate along band sets
        return torch.cat([torch.mean(t, dim=(-3)) for t in spatial_tokens], dim=-2)

    def pool_spatially(self, pooling_type: PoolingType) -> Tensor:
        """Pool fully visible spatial modalities over time and band sets.

        Each modality is reduced over the time and band-set axes (dims -3 and
        -2), and the per-modality results are then reduced across modalities.
        ``PoolingType.MEAN`` averages; any other value takes the maximum.

        Raises:
            ValueError: If no spatial modality is fully visible to the encoder.
        """
        spatial_tokens = self._fully_visible_spatial_tokens()
        if len(spatial_tokens) == 0:
            raise ValueError("Missing unmasked spatial modalities for spatial pooling.")
        if pooling_type == PoolingType.MEAN:
            per_modality = [torch.mean(t, dim=(-2, -3)) for t in spatial_tokens]
        else:
            # Max over band sets (-2), then over time (now at -2).
            per_modality = [
                torch.max(torch.max(t, dim=-2).values, dim=-2).values
                for t in spatial_tokens
            ]
        spatial_average_t = torch.stack(per_modality, dim=-1)
        if pooling_type == PoolingType.MEAN:
            return spatial_average_t.mean(dim=-1)
        return spatial_average_t.max(dim=-1).values

    def pool_instance_wise(self, pooling_type: PoolingType) -> Tensor:
        """Pool all tokens of each instance that the online encoder saw.

        Hidden tokens are excluded with ``masked_fill``, not multiplication, so
        NaN/inf values in hidden tokens cannot leak into the result.

        Returns:
            Tensor of shape ``[B, D]``.

        Raises:
            ValueError: For MEAN pooling when some sample has no visible tokens,
                or for an unknown ``pooling_type``.
        """
        x, mask = self.flatten_tokens_and_masks()
        # True for tokens seen by the online encoder.
        visible = mask == MaskValue.ONLINE_ENCODER.value
        hidden = ~visible.unsqueeze(-1)
        if pooling_type == PoolingType.MAX:
            return x.masked_fill(hidden, -float("inf")).max(dim=1).values
        elif pooling_type == PoolingType.MEAN:
            num_encoded_tokens = torch.sum(visible.long(), -1, keepdim=True)
            # Lazy %-formatting: the tensor is only stringified (which forces a
            # device-to-host copy) when debug logging is actually enabled.
            logger.debug("num_encoded_tokens: %s", num_encoded_tokens)
            if (num_encoded_tokens == 0).any():
                raise ValueError(
                    f"num_encoded_tokens is 0 for some samples {num_encoded_tokens}"
                )
            return x.masked_fill(hidden, 0).sum(dim=1) / num_encoded_tokens
        else:
            raise ValueError(f"Invalid pooling type: {pooling_type}")

    def pool_unmasked_tokens(
        self,
        pooling_type: PoolingType = PoolingType.MAX,
        spatial_pooling: bool = False,
        concat_features: bool = False,
    ) -> Tensor:
        """Pool the unmasked tokens.

        Args:
            pooling_type: Pooling type for the tokens.
            spatial_pooling: Whether to keep the spatial dimensions when pooling.
                If True, the masks within a spatial modality must be consistent
                (e.g. all s2 tokens have the same mask).
            concat_features: Concatenate the per-modality features instead of
                reducing them. Only supported together with spatial pooling, and
                requires no masked-out tokens.

        Raises:
            ValueError: If ``concat_features`` is set without ``spatial_pooling``.
        """
        if concat_features and spatial_pooling:
            return self.pool_spatially_and_concat_modalities()
        if concat_features:
            raise ValueError("concat_features is not supported for non-spatial pooling")
        if not spatial_pooling:
            return self.pool_instance_wise(pooling_type)
        return self.pool_spatially(pooling_type)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class ProjectAndAggregate(nn.Module):
    """Projection head that mean-pools and projects tokens, in either order."""

    def __init__(
        self,
        embedding_size: int,
        num_layers: int,
        aggregate_then_project: bool = True,
    ):
        """Build the projection head.

        Args:
            embedding_size: Width of the input tokens. Every Linear layer maps
                ``embedding_size -> embedding_size``.
            num_layers: Number of Linear layers, with a ReLU between consecutive
                layers. Values below 1 still yield a single Linear layer.
            aggregate_then_project: If True, mean-pool the tokens first and
                project the pooled vector. If False, project every token and
                pool afterwards.
        """
        super().__init__()
        layers: list[nn.Module] = [nn.Linear(embedding_size, embedding_size)]
        for _ in range(num_layers - 1):
            layers.extend((nn.ReLU(), nn.Linear(embedding_size, embedding_size)))
        self.projection = nn.Sequential(*layers)
        self.aggregate_then_project = aggregate_then_project

    def apply_aggregate_then_project(
        self, x: TokensAndMasks | torch.Tensor
    ) -> torch.Tensor:
        """Mean-pool ``x``, then project the pooled embedding.

        For a ``TokensAndMasks``, only tokens the online encoder saw are averaged
        (instance-wise, no spatial pooling). For a plain tensor, all dims between
        the first and the last are averaged.

        Raises:
            ValueError: If ``x`` is neither a ``TokensAndMasks`` nor a tensor.
        """
        if isinstance(x, TokensAndMasks):
            pooled = x.pool_unmasked_tokens(PoolingType.MEAN, spatial_pooling=False)
        elif isinstance(x, torch.Tensor):
            pooled = reduce(x, "b ... d -> b d", "mean")
        else:
            raise ValueError(f"Invalid input type: {type(x)}")
        return self.projection(pooled)

    def apply_project_then_aggregate(
        self, x: TokensAndMasks | torch.Tensor
    ) -> torch.Tensor:
        """Project every token of ``x``, then mean-pool the projections.

        For a ``TokensAndMasks``, the masks are carried over unchanged. Masked
        tokens are projected as well, but the masked mean pooling ignores them.

        Raises:
            ValueError: If ``x`` is neither a ``TokensAndMasks`` nor a tensor.
        """
        if isinstance(x, TokensAndMasks):
            projected = {name: self.projection(getattr(x, name)) for name in x.modalities}
            x_projected = TokensAndMasks(**{**x._asdict(), **projected})
            return x_projected.pool_unmasked_tokens(
                PoolingType.MEAN, spatial_pooling=False
            )
        if isinstance(x, torch.Tensor):
            return reduce(self.projection(x), "b ... d -> b d", "mean")
        raise ValueError(f"Invalid input type: {type(x)}")

    def forward(self, x: TokensAndMasks | torch.Tensor) -> torch.Tensor:
        """Pool and project ``x``, in the order chosen by ``aggregate_then_project``."""
        if self.aggregate_then_project:
            return self.apply_aggregate_then_project(x)
        return self.apply_project_then_aggregate(x)
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
class MultiModalPatchEmbeddings(nn.Module):
    """Patchify and embed each modality, using one embedding module per band set."""

    def __init__(
        self,
        supported_modality_names: list[str],
        max_patch_size: int,
        embedding_size: int,
        tokenization_config: TokenizationConfig | None = None,
    ):
        """Create the per-modality embedding modules and band-index buffers.

        Args:
            supported_modality_names: Modalities (from Modality) that this model
                instantiation supports.
            max_patch_size: Maximum patch size.
            embedding_size: Output embedding width.
            tokenization_config: Optional config for custom band groupings. A
                default ``TokenizationConfig()`` is used when omitted.
        """
        super().__init__()
        self.max_patch_size = max_patch_size
        self.embedding_size = embedding_size
        self.supported_modality_names = supported_modality_names
        self.tokenization_config = tokenization_config or TokenizationConfig()
        # TODO: want to be able to remove certain bands and modalities
        self.per_modality_embeddings = nn.ModuleDict({})
        for modality_name in self.supported_modality_names:
            self.per_modality_embeddings[modality_name] = (
                self._get_patch_embedding_module_for_modality(modality_name)
            )

        # One non-persistent LongTensor buffer per (modality, band set). It holds
        # the channel indices that forward() selects from the data tensor, and
        # as a buffer it follows the module across .to(device) calls.
        for modality_name in self.supported_modality_names:
            bandsets = self.tokenization_config.get_bandset_indices(modality_name)
            for bandset_idx, channel_idxs in enumerate(bandsets):
                self.register_buffer(
                    self._get_buffer_name(modality_name, bandset_idx),
                    torch.tensor(channel_idxs, dtype=torch.long),
                    persistent=False,
                )

    @staticmethod
    def _get_buffer_name(modality: str, idx: int) -> str:
        """Name of the channel-index buffer for band set ``idx`` of ``modality``."""
        return f"{modality}__{idx}_buffer"

    @staticmethod
    def _get_embedding_module_name(modality: str, idx: int) -> str:
        """Key of the embedding module for band set ``idx`` (ModuleDict keys must be str)."""
        return f"{modality}__{idx}"

    def _get_patch_embedding_module_for_modality(self, modality: str) -> nn.Module:
        """Build a ModuleDict with one embedding module per band set of ``modality``.

        Modalities that are not spatial get an ``nn.Linear`` per band set. Spatial
        modalities get a ``FlexiPatchEmbed``. Band sets come from the
        tokenization config, which may override the defaults.
        """
        modality_spec = Modality.get(modality)
        bandset_indices = self.tokenization_config.get_bandset_indices(modality)

        def build(channel_idxs) -> nn.Module:
            if modality_spec.is_spatial:
                return FlexiPatchEmbed(
                    in_chans=len(channel_idxs),
                    embedding_size=self.embedding_size,
                    patch_size_at_16=self.max_patch_size,
                    modality_spec=modality_spec,
                )
            # static in space: a plain linear projection of the channels
            return nn.Linear(len(channel_idxs), self.embedding_size)

        return nn.ModuleDict(
            {
                self._get_embedding_module_name(modality, idx): build(channel_idxs)
                for idx, channel_idxs in enumerate(bandset_indices)
            }
        )

    def apply_embedding_to_modality(
        self,
        modality: str,
        input_data: MaskedOlmoEarthSample,
        patch_size: int,
        fast_pass: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Embed every band set of one modality.

        A band set with no token visible to the online encoder is not run
        through its embedding module. Zeros of the embedding width are emitted
        instead, unless ``fast_pass`` is set. In that case every band set is
        embedded and the visibility check is skipped: evaluating ``.any()`` in an
        ``if`` needs a host-side bool, which is a device sync on GPU.

        Returns:
            ``(tokens, masks)``: per-band-set tokens stacked on dim -2 and
            per-band-set masks stacked on dim -1.
        """
        logger.debug(f"applying embedding to modality:{modality}")
        modality_mask = getattr(input_data, input_data.get_masked_modality_name(modality))
        modality_data = getattr(input_data, modality)

        modality_spec = Modality.get(modality)
        num_band_sets = self.tokenization_config.get_num_bandsets(modality)

        spatial = modality_spec.is_spatial
        if spatial:
            # The mask is given per pixel; keep one entry per (tile-scaled) patch.
            stride = patch_size * modality_spec.image_tile_size_factor
            embed_kwargs = {"patch_size": patch_size}
        else:
            embed_kwargs = {}

        tokens_per_bandset, masks_per_bandset = [], []
        for idx in range(num_band_sets):
            if spatial:
                token_mask = modality_mask[:, 0::stride, 0::stride, ..., idx]
            else:
                # static in space: one mask entry per band set
                token_mask = modality_mask[..., idx]

            if fast_pass or (token_mask == MaskValue.ONLINE_ENCODER.value).any():
                channel_index = getattr(self, self._get_buffer_name(modality, idx))
                selected = torch.index_select(modality_data, -1, channel_index)
                embedder = self.per_modality_embeddings[modality][
                    self._get_embedding_module_name(modality, idx)
                ]
                bandset_tokens = embedder(selected, **embed_kwargs)
            else:
                bandset_tokens = torch.zeros(
                    token_mask.shape + (self.embedding_size,),
                    dtype=modality_data.dtype,
                    device=token_mask.device,
                )

            tokens_per_bandset.append(bandset_tokens)
            masks_per_bandset.append(token_mask)
        return torch.stack(tokens_per_bandset, dim=-2), torch.stack(masks_per_bandset, dim=-1)

    @staticmethod
    def is_any_data_seen_by_encoder(modality_mask: Tensor) -> bool:
        """Check whether any entry of ``modality_mask`` is ``MaskValue.ONLINE_ENCODER``.

        NOTE(review): despite the ``bool`` annotation, this returns the 0-dim
        tensor produced by ``.any()``.
        """
        return (modality_mask == MaskValue.ONLINE_ENCODER.value).any()

    def apply_compile(self) -> None:
        """Compile this module in place with ``torch.compile`` settings (static shapes, full graph)."""
        self.compile(dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)

    def forward(
        self,
        input_data: MaskedOlmoEarthSample,
        patch_size: int,
        fast_pass: bool = False,
    ) -> dict[str, Tensor]:
        """Return flexibly patchified embeddings for each modality of the input data.

        Given ``[B, H, W, (T), C]`` inputs, returns ``[B, H, W, (T), b_s, D]``
        outputs. The result is a dict holding each processed modality's tokens
        and its ``<modality>_mask`` entry.

        The spatial masks are assumed to be consistent for the given patch size.
        For example, with patch_size == 2 one possible mask is
            [0, 0, 1, 1]
            [0, 0, 1, 1]
            [1, 1, 0, 0]
            [1, 1, 0, 0]
        for the H, W dimensions.
        """
        embedded: dict[str, Tensor] = {}
        for modality in get_modalities_to_process(
            input_data.modalities, self.supported_modality_names
        ):
            tokens, masks = self.apply_embedding_to_modality(
                modality, input_data, patch_size, fast_pass
            )
            embedded[modality] = tokens
            embedded[input_data.get_masked_modality_name(modality)] = masks
        return embedded
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
class Reconstructor(nn.Module):
|
|
553
|
+
"""Module that patchifies and encodes the input data."""
|
|
554
|
+
|
|
555
|
+
def __init__(
|
|
556
|
+
self,
|
|
557
|
+
decoder: nn.Module,
|
|
558
|
+
supported_modalities: list[ModalitySpec],
|
|
559
|
+
max_patch_size: int,
|
|
560
|
+
tokenization_config: TokenizationConfig | None = None,
|
|
561
|
+
):
|
|
562
|
+
"""Initialize the patch embeddings.
|
|
563
|
+
|
|
564
|
+
Args:
|
|
565
|
+
decoder: Predictor nn module to use on before reconstructor on input
|
|
566
|
+
supported_modalities: Which modalities from Modality this model
|
|
567
|
+
instantiation supports
|
|
568
|
+
max_patch_size: Maximum size of patches
|
|
569
|
+
tokenization_config: Optional config for custom band groupings
|
|
570
|
+
"""
|
|
571
|
+
super().__init__()
|
|
572
|
+
self.max_patch_size = max_patch_size
|
|
573
|
+
self.embedding_size = decoder.output_embedding_size
|
|
574
|
+
self.supported_modalities = supported_modalities
|
|
575
|
+
self.tokenization_config = tokenization_config or TokenizationConfig()
|
|
576
|
+
self.decoder = decoder
|
|
577
|
+
# TODO: want to be able to remove certain bands and modalities
|
|
578
|
+
self.per_modality_reconstructions = nn.ModuleDict({})
|
|
579
|
+
for modality in self.supported_modalities:
|
|
580
|
+
self.per_modality_reconstructions[modality.name] = (
|
|
581
|
+
self._get_patch_reconstruction_module_for_modality(modality)
|
|
582
|
+
)
|
|
583
|
+
|
|
584
|
+
def apply_compile(self) -> None:
|
|
585
|
+
"""Apply torch.compile to the model."""
|
|
586
|
+
self.decoder.apply_compile()
|
|
587
|
+
|
|
588
|
+
def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
|
|
589
|
+
"""Apply FSDP to the model."""
|
|
590
|
+
self.decoder.apply_fsdp(**fsdp_kwargs)
|
|
591
|
+
|
|
592
|
+
@staticmethod
|
|
593
|
+
def _get_reconstruction_module_name(modality: str, idx: int) -> str:
|
|
594
|
+
"""Get the reconstruction module name.
|
|
595
|
+
|
|
596
|
+
Module Dicts require string keys
|
|
597
|
+
"""
|
|
598
|
+
return f"{modality}__{idx}"
|
|
599
|
+
|
|
600
|
+
def _get_patch_reconstruction_module_for_modality(
|
|
601
|
+
self, modality: ModalitySpec
|
|
602
|
+
) -> nn.Module:
|
|
603
|
+
"""Get the patch reconstruction module for a modality."""
|
|
604
|
+
# Get bandset indices from tokenization config (may be overridden)
|
|
605
|
+
bandset_indices = self.tokenization_config.get_bandset_indices(modality.name)
|
|
606
|
+
|
|
607
|
+
# Based on the modality name we choose the way to embed the data
|
|
608
|
+
# I likely will need to know about what the embedding strategy is in the forward as well
|
|
609
|
+
# Static modality
|
|
610
|
+
if modality.get_tile_resolution() == 0:
|
|
611
|
+
# static in space
|
|
612
|
+
return nn.ModuleDict(
|
|
613
|
+
{
|
|
614
|
+
self._get_reconstruction_module_name(modality.name, idx): nn.Linear(
|
|
615
|
+
self.embedding_size, len(channel_set_idxs)
|
|
616
|
+
)
|
|
617
|
+
for idx, channel_set_idxs in enumerate(bandset_indices)
|
|
618
|
+
}
|
|
619
|
+
)
|
|
620
|
+
else:
|
|
621
|
+
return nn.ModuleDict(
|
|
622
|
+
{
|
|
623
|
+
self._get_reconstruction_module_name(
|
|
624
|
+
modality.name, idx
|
|
625
|
+
): FlexiPatchReconstruction(
|
|
626
|
+
out_chans=len(channel_set_idxs),
|
|
627
|
+
embedding_size=self.embedding_size,
|
|
628
|
+
max_patch_size=self.max_patch_size,
|
|
629
|
+
)
|
|
630
|
+
for idx, channel_set_idxs in enumerate(bandset_indices)
|
|
631
|
+
}
|
|
632
|
+
)
|
|
633
|
+
|
|
634
|
+
# TODO: Likely we want a single object that stores all the data related configuration etc per modality including channel grous bands patch size etc
|
|
635
|
+
def apply_reconstruction_to_modality(
|
|
636
|
+
self, modality: str, input_data: TokensAndMasks, patch_size: int
|
|
637
|
+
) -> tuple[Tensor, Tensor]:
|
|
638
|
+
"""Apply reconstruction to a modality."""
|
|
639
|
+
masked_modality_name = input_data.get_masked_modality_name(modality)
|
|
640
|
+
modality_mask = getattr(input_data, masked_modality_name)
|
|
641
|
+
modality_data = getattr(input_data, modality)
|
|
642
|
+
|
|
643
|
+
modality_spec = Modality.get(modality)
|
|
644
|
+
bandset_indices = self.tokenization_config.get_bandset_indices(modality)
|
|
645
|
+
|
|
646
|
+
# x: Input tensor with shape [b, h, w, (t), b_s, d]
|
|
647
|
+
modality_tokens, modality_masks = [], []
|
|
648
|
+
for idx, channel_set_indices in enumerate(bandset_indices):
|
|
649
|
+
data = modality_data[..., idx, :]
|
|
650
|
+
masks = modality_mask[..., idx]
|
|
651
|
+
r_model = self.per_modality_reconstructions[modality][
|
|
652
|
+
self._get_reconstruction_module_name(modality, idx)
|
|
653
|
+
]
|
|
654
|
+
if modality_spec.get_tile_resolution() == 0:
|
|
655
|
+
data = r_model(data)
|
|
656
|
+
else:
|
|
657
|
+
data = r_model(data, patch_size=patch_size)
|
|
658
|
+
modality_tokens.append(data)
|
|
659
|
+
masks = repeat(
|
|
660
|
+
masks,
|
|
661
|
+
"b h w ... -> b (h p_h) (w p_w) ...",
|
|
662
|
+
p_h=patch_size,
|
|
663
|
+
p_w=patch_size,
|
|
664
|
+
)
|
|
665
|
+
modality_masks.append(masks)
|
|
666
|
+
modality_mask = repeat(
|
|
667
|
+
modality_mask,
|
|
668
|
+
"b h w ... -> b (h p_h) (w p_w) ...",
|
|
669
|
+
p_h=patch_size,
|
|
670
|
+
p_w=patch_size,
|
|
671
|
+
)
|
|
672
|
+
return torch.cat(modality_tokens, dim=-1), modality_mask
|
|
673
|
+
|
|
674
|
+
def forward(
    self,
    x: TokensAndMasks,
    timestamps: Tensor,
    patch_size: int,
    input_res: int = BASE_GSD,
) -> TokensAndMasks:
    """Return flexibly patchified reconstruction for each modality of the input data.

    Given a [B, H, W, (T), b_s, D] inputs, returns a [B, H, W, (T), C] output.

    The input is first run through ``self.decoder``. Every modality selected by
    ``get_modalities_to_process`` (from the decoded modalities and this
    module's supported modality names) is then reconstructed with
    ``apply_reconstruction_to_modality``.

    Args:
        x: Encoded tokens and masks.
        timestamps: Timestamps forwarded to the decoder.
        patch_size: Patch size forwarded to the decoder and the reconstruction.
        input_res: Input resolution forwarded to the decoder.

    Returns:
        A ``TokensAndMasks`` holding, per processed modality, the
        reconstruction and its upsampled mask.
    """
    decoded = self.decoder(x, timestamps, patch_size, input_res)
    available = decoded.modalities
    supported_names = [spec.name for spec in self.supported_modalities]

    outputs: dict[str, Tensor] = {}
    for name in get_modalities_to_process(available, supported_names):
        reconstruction, upsampled_mask = self.apply_reconstruction_to_modality(
            name, decoded, patch_size
        )
        outputs[name] = reconstruction
        outputs[decoded.get_masked_modality_name(name)] = upsampled_mask
    return TokensAndMasks(**outputs)
|
|
698
|
+
|
|
699
|
+
|
|
700
|
+
@dataclass
class ReconstructorConfig(Config):
    """Configuration for the Reconstructor.

    Attributes are passed (after light renaming in ``build``) as keyword
    arguments to ``Reconstructor``.
    """

    # Config whose ``build()`` result is passed to Reconstructor as ``decoder``.
    decoder_config: "Config"
    # Names of the modalities; resolved to ModalitySpec via ``supported_modalities``.
    supported_modality_names: list[str]
    max_patch_size: int = 8
    # Optional custom band groupings; validated when present.
    tokenization_config: TokenizationConfig | None = None

    def validate(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: If no modality is configured, or a configured modality
                is not among ``Modality.values()``.
        """
        # Resolve the property once; it rebuilds the spec list on every access.
        supported_modalities = self.supported_modalities
        if len(supported_modalities) == 0:
            raise ValueError("At least one modality must be added!")
        known_modalities = Modality.values()
        for modality in supported_modalities:
            if modality not in known_modalities:
                raise ValueError(f"Modality {modality} is not supported")
        if self.tokenization_config is not None:
            self.tokenization_config.validate()

    @property
    def supported_modalities(self) -> list[ModalitySpec]:
        """Get the supported modalities."""
        return get_modality_specs_from_names(self.supported_modality_names)

    def build(self) -> "Reconstructor":
        """Build the reconstructor.

        Replaces ``supported_modality_names`` with resolved
        ``supported_modalities`` and ``decoder_config`` with the built
        ``decoder`` before constructing ``Reconstructor``.
        """
        self.validate()
        kwargs = self.as_dict(exclude_none=True, recurse=False)
        kwargs.pop("supported_modality_names")
        kwargs["supported_modalities"] = self.supported_modalities
        kwargs.pop("decoder_config")
        kwargs["decoder"] = self.decoder_config.build()
        logger.info(f"Reconstructor kwargs: {kwargs}")
        return Reconstructor(**kwargs)
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
class CompositeEncodings(nn.Module):
    """Composite encodings for FlexiVit models.

    Adds up to four encodings to the tokens, each written into its own slice
    of the last (embedding) dimension, with
    ``n = int(embedding_size * 0.25)``:

    * ``[0, n)``: per-modality, per-bandset channel embeddings,
    * ``[n, 2n)``: 1D sin-cos position in time,
    * ``[2n, 3n)``: month encoding looked up from ``timestamps[:, :, 1]``,
    * ``[3n, 4n)``: 2D sin-cos spatial position scaled by the GSD ratio.

    If ``embedding_size`` is not a multiple of 4, the trailing
    ``embedding_size - 4 * n`` dimensions receive no encoding.
    """

    def __init__(
        self,
        embedding_size: int,
        supported_modalities: list[ModalitySpec],
        max_sequence_length: int,
        learnable_channel_embeddings: bool = True,
        random_channel_embeddings: bool = False,
        tokenization_config: TokenizationConfig | None = None,
    ):
        """Initialize the composite encodings.

        Args:
            embedding_size: Size of token embeddings
            supported_modalities: Which modalities from Modality this model
                instantiation supports
            max_sequence_length: Maximum sequence length (of the time dimension)
            learnable_channel_embeddings: Whether to use learnable channel embeddings
            random_channel_embeddings: Initialize channel embeddings randomly (zeros if False)
            tokenization_config: Optional config for custom band groupings
        """
        super().__init__()
        self.embedding_size = embedding_size
        self.supported_modalities = supported_modalities
        self.supported_modality_names = [
            modality.name for modality in supported_modalities
        ]
        self.tokenization_config = tokenization_config or TokenizationConfig()
        # This max sequence length is a time dim thing
        self.max_sequence_length = max_sequence_length
        # TODO: we need to be able to calculate the size of the param based on what types of embeddings it will get

        # we have 4 embeddings types (pos_in_time, pos_in_space, month, channel) so each get
        # 0.25 of the dimension
        self.embedding_dim_per_embedding_type = int(embedding_size * 0.25)
        # Position encodings for time dimension initialized to 1D sinusoidal encodings
        self.pos_embed = nn.Parameter(
            get_1d_sincos_pos_encoding(
                torch.arange(max_sequence_length),
                self.embedding_dim_per_embedding_type,
            ),
            requires_grad=False,
        )
        # Month encodings
        month_tab = get_month_encoding_table(self.embedding_dim_per_embedding_type)
        self.month_embed = nn.Embedding.from_pretrained(month_tab, freeze=True)

        # Channel embeddings: one row per bandset of each modality. They start
        # from uniform random values when random_channel_embeddings is set
        # (zeros otherwise) and are trainable only when
        # learnable_channel_embeddings is set. The former "neither" branch
        # (zeros, frozen) is exactly this general case with both flags off.
        requires_grad = bool(learnable_channel_embeddings)
        self.per_modality_channel_embeddings = nn.ParameterDict()
        for modality in self.supported_modalities:
            num_bandsets = self.tokenization_config.get_num_bandsets(modality.name)
            shape = (num_bandsets, self.embedding_dim_per_embedding_type)
            if random_channel_embeddings:
                initial = torch.rand(shape)
            else:
                initial = torch.zeros(shape)
            self.per_modality_channel_embeddings[modality.name] = nn.Parameter(
                initial, requires_grad=requires_grad
            )

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Xavier-uniform init for Linear weights and zero init for their biases."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def calculate_gsd_ratio(input_res: float, patch_size: int) -> float:
        """Calculate the Ground Sample Distance ratio."""
        return input_res * patch_size / BASE_GSD

    def _apply_encodings_per_modality(
        self,
        modality_name: str,
        modality_tokens: Tensor,
        timestamps: Tensor | None = None,
        patch_size: int | None = None,
        input_res: int | None = None,
        use_modality_encodings: bool = True,
        use_temporal_encodings: bool = True,
    ) -> Tensor:
        """Apply the encodings to the patchified data based on modality type.

        Args:
            modality_name: Name of the modality being processed
            modality_tokens: Token embeddings for the modality
            timestamps: Optional timestamps for temporal encodings; month
                indices are read from ``timestamps[:, :, 1]``
            patch_size: Optional patch size for spatial encodings
            input_res: Optional input resolution for spatial encodings
            use_modality_encodings: Whether to use modality encodings
            use_temporal_encodings: Whether to use temporal encodings

        Returns:
            Tensor with encodings applied based on modality type

        Raises:
            NotImplementedError: If temporal encodings are disabled while
                modality encodings are enabled.
            ValueError: If the token rank is unsupported, or the channel
                embeddings' bandset count does not match the tokens.
        """
        logger.debug(
            f"use_modality_encodings: {use_modality_encodings}, use_temporal_encodings: {use_temporal_encodings}"
        )
        # TODO: Improve this implementation it is quite bad

        modality = Modality.get(modality_name)
        logger.debug(f"Applying encodings to modality {modality}")
        # Choose the einops layout of modality_tokens. Note that h, w and t are
        # only bound for layouts that actually contain those axes.
        if not use_modality_encodings and use_temporal_encodings:
            b, h, w, t, _ = modality_tokens.shape
            ein_string, ein_dict = (
                "b h w t d",
                {"b": b, "h": h, "w": w, "t": t},
            )
        elif not use_temporal_encodings and not use_modality_encodings:
            b, h, w, _ = modality_tokens.shape
            ein_string, ein_dict = (
                "b h w d",
                {"b": b, "h": h, "w": w},
            )
        elif not use_temporal_encodings and use_modality_encodings:
            raise NotImplementedError("Not implemented")
        else:
            if modality_tokens.ndim == 3:
                # modality_tokens = [B, Band_Sets, D]; static in space, static in time
                b, b_s, _ = modality_tokens.shape
                ein_string, ein_dict = "b b_s d", {"b": b, "b_s": b_s}
            elif modality_tokens.ndim == 4:
                b, t, b_s, _ = modality_tokens.shape
                ein_string, ein_dict = "b t b_s d", {"b": b, "t": t, "b_s": b_s}
            elif modality_tokens.ndim == 5:
                b, h, w, b_s, _ = modality_tokens.shape
                ein_string, ein_dict = (
                    "b h w b_s d",
                    {"b": b, "h": h, "w": w, "b_s": b_s},
                )
            elif modality_tokens.ndim == 6:
                b, h, w, t, b_s, _ = modality_tokens.shape
                ein_string, ein_dict = (
                    "b h w t b_s d",
                    {"b": b, "h": h, "w": w, "t": t, "b_s": b_s},
                )
            else:
                raise ValueError(f"Unsupported tokens shape: {modality_tokens.shape}")

        device = modality_tokens.device
        # Accumulator for all encodings; allocated with torch's default dtype,
        # not modality_tokens' dtype.
        modality_embed = torch.zeros(modality_tokens.shape, device=device)
        n = self.embedding_dim_per_embedding_type
        actual_bandsets = modality_tokens.shape[-2]

        # Channel embeddings
        if use_modality_encodings:
            channel_embed = self.per_modality_channel_embeddings[modality.name]
            if channel_embed.shape[0] != actual_bandsets:
                raise ValueError(
                    f"Channel embeddings for {modality.name} expect "
                    f"{channel_embed.shape[0]} bandsets but tokens have "
                    f"{actual_bandsets}. Ensure tokenization_config is "
                    "consistently passed to the encoder/decoder and masking strategy."
                )
            channel_embed = repeat(
                channel_embed, f"b_s d -> {ein_string}", **ein_dict
            ).to(device)
            modality_embed[..., :n] += channel_embed

        if modality.is_multitemporal and use_temporal_encodings:
            # Time position encodings
            time_embed = repeat(self.pos_embed[:t], f"t d -> {ein_string}", **ein_dict)
            modality_embed[..., n : n * 2] += time_embed.to(device)

            # Month encodings
            assert timestamps is not None
            months = timestamps[:, :, 1]
            month_embed = self.month_embed(months)
            month_embed = repeat(month_embed, f"b t d -> {ein_string}", **ein_dict)
            modality_embed[..., n * 2 : n * 3] += month_embed.to(device)
        if modality.is_spatial:
            # Spatial encodings
            assert input_res is not None
            assert patch_size is not None
            gsd_ratio = self.calculate_gsd_ratio(input_res, patch_size)
            spatial_embed = get_2d_sincos_pos_encoding_with_resolution(
                grid_size=h,
                res=torch.ones(b, device=device) * gsd_ratio,
                encoding_dim=self.embedding_dim_per_embedding_type,
                device=device,
            )
            spatial_embed = rearrange(spatial_embed, "b (h w) d -> b h w d", h=h, w=w)
            spatial_embed = repeat(
                spatial_embed, f"b h w d -> {ein_string}", **ein_dict
            )
            modality_embed[..., n * 3 : n * 4] += spatial_embed
        return modality_tokens + modality_embed

    def forward(
        self,
        per_modality_input_tokens: dict[str, Tensor],
        timestamps: Tensor,
        patch_size: int,
        input_res: int = BASE_GSD,
    ) -> dict[str, Tensor]:
        """Apply the encodings to the patchified data.

        Args:
            per_modality_input_tokens: Tokens only for each modality
            timestamps: Timestamps of the data
            patch_size: Size of patches
            input_res: Resolution of the input data

        Returns:
            Tokens only for each modality selected by
            ``get_modalities_to_process``, with encodings added
        """
        output_dict = {}
        available_modalities = return_modalities_from_dict(per_modality_input_tokens)
        modalities_to_process = get_modalities_to_process(
            available_modalities, self.supported_modality_names
        )
        for modality_name in modalities_to_process:
            output_dict[modality_name] = self._apply_encodings_per_modality(
                modality_name,
                per_modality_input_tokens[modality_name],
                timestamps=timestamps,
                patch_size=patch_size,
                input_res=input_res,
            )
        return output_dict
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
class FlexiVitBase(nn.Module):
    """FlexiVitBase is a base class for FlexiVit models.

    It owns the stack of transformer ``blocks`` and the ``composite_encodings``
    module. It also provides helpers to flatten per-modality token/mask dicts
    into one sequence and to split such a sequence back per modality.
    Subclasses choose the block flavour through the ``cross_attn`` class
    attribute.
    """

    cross_attn: bool = False

    def __init__(
        self,
        embedding_size: int,
        max_sequence_length: int,
        num_heads: int,
        mlp_ratio: float,
        depth: int,
        drop_path: float,
        supported_modalities: list[ModalitySpec],
        learnable_channel_embeddings: bool = True,
        random_channel_embeddings: bool = False,
        use_flash_attn: bool = False,
        qk_norm: bool = False,
        tokenization_config: TokenizationConfig | None = None,
    ) -> None:
        """Initialize the FlexiVitBase class.

        Args:
            embedding_size: Width of every token embedding.
            max_sequence_length: Maximum time-sequence length for the encodings.
            num_heads: Attention heads per block.
            mlp_ratio: MLP hidden-size ratio per block.
            depth: Number of transformer blocks.
            drop_path: Drop-path rate passed to every block.
            supported_modalities: Modalities this instance accepts.
            learnable_channel_embeddings: Whether channel embeddings are trainable.
            random_channel_embeddings: Initialize channel embeddings randomly (zeros if False).
            use_flash_attn: Passed to every block.
            qk_norm: Passed to every block.
            tokenization_config: Optional custom band groupings; a default
                ``TokenizationConfig()`` is used when omitted.
        """
        super().__init__()

        self.embedding_size = embedding_size
        self.supported_modalities = supported_modalities
        self.supported_modality_names = [spec.name for spec in supported_modalities]
        logger.info(f"modalities being used by model: {self.supported_modality_names}")

        self.max_sequence_length = max_sequence_length
        self._base_tokenization_config = tokenization_config or TokenizationConfig()

        self.use_flash_attn = use_flash_attn
        self.learnable_channel_embeddings = learnable_channel_embeddings
        self.random_channel_embeddings = random_channel_embeddings

        # Blocks are created in order; their Linear layers are re-initialized
        # by _init_weights at the end of __init__.
        block_list = []
        for _ in range(depth):
            block_list.append(
                Block(
                    embedding_size,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    qk_norm=qk_norm,
                    norm_layer=nn.LayerNorm,  # TODO: This should be configurable
                    cross_attn=self.cross_attn,
                    drop_path=drop_path,
                    use_flash_attn=self.use_flash_attn,
                )
            )
        self.blocks = nn.ModuleList(block_list)

        self.composite_encodings = CompositeEncodings(
            embedding_size=embedding_size,
            supported_modalities=self.supported_modalities,
            max_sequence_length=max_sequence_length,
            learnable_channel_embeddings=learnable_channel_embeddings,
            random_channel_embeddings=random_channel_embeddings,
            tokenization_config=self._base_tokenization_config,
        )
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Xavier-uniform init for Linear weights and zero init for their biases."""
        if not isinstance(m, nn.Linear):
            return
        # we use xavier_uniform following official JAX ViT:
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    @staticmethod
    def grab_modality_specific_dims(modality_data: Tensor) -> tuple[int, ...]:
        """Return the dims between batch and the trailing (channel, embedding) pair.

        Assumes a layout of [B, ..., C, D]. Tensors of rank 3 or lower have no
        such middle dims, so an empty tuple is returned for them.

        Args:
            modality_data: Modality data

        Returns:
            ``modality_data.shape[1:-2]`` for rank > 3, otherwise ``()``.
        """
        if modality_data.ndim <= 3:
            return ()
        return modality_data.shape[1:-2]

    def collapse_and_combine_hwtc(self, x: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Flatten every processed modality and concatenate them along dim 1.

        Tokens are rearranged ``b ... d -> b (...) d`` and masks
        ``b ... -> b (...)``. Modalities are visited in the order returned by
        ``get_modalities_to_process``.

        Returns:
            ``(tokens, masks)``: the concatenated tokens and masks.
        """
        selected = get_modalities_to_process(
            return_modalities_from_dict(x), self.supported_modality_names
        )
        flat_tokens, flat_masks = [], []
        for name in selected:
            mask_key = MaskedOlmoEarthSample.get_masked_modality_name(name)
            flat_tokens.append(rearrange(x[name], "b ... d -> b (...) d"))
            flat_masks.append(rearrange(x[mask_key], "b ... -> b (...)"))
        return torch.cat(flat_tokens, dim=1), torch.cat(flat_masks, dim=1)

    @staticmethod
    def _construct_einops_pattern(
        spatial_dims: tuple[int, ...],
    ) -> tuple[str, dict[str, int]]:
        """Build an einops pattern that broadcasts a [D] tensor to ``spatial_dims + (D,)``.

        For ``spatial_dims = (s0, s1, ...)`` the pattern is
        ``"d -> (dim0) (dim1) ... d"`` and the dict maps ``dim{i}`` to ``s{i}``.

        Returns:
            ``(pattern, sizes)``.
        """
        sizes: dict[str, int] = {}
        groups: list[str] = []
        for position, extent in enumerate(spatial_dims):
            axis = f"dim{position}"
            sizes[axis] = extent
            groups.append(f"({axis})")
        return "d -> " + " ".join(groups) + " d", sizes

    def split_tokens_masks_and_dims(
        self, x: dict[str, Tensor]
    ) -> tuple[dict[str, Tensor], dict[str, Tensor], dict[str, tuple]]:
        """Split the tokens, masks, and token shapes out into separate dicts.

        Returns:
            ``(tokens_by_modality, masks_by_mask_name, shape_by_modality)``.
        """
        selected = get_modalities_to_process(
            return_modalities_from_dict(x), self.supported_modality_names
        )
        tokens_only_dict = {}
        original_masks_dict = {}
        modalities_to_dims_dict = {}
        for name in selected:
            modality_tokens = x[name]
            tokens_only_dict[name] = modality_tokens
            modalities_to_dims_dict[name] = modality_tokens.shape
            mask_key = MaskedOlmoEarthSample.get_masked_modality_name(name)
            original_masks_dict[mask_key] = x[mask_key]
        return tokens_only_dict, original_masks_dict, modalities_to_dims_dict

    @staticmethod
    def split_and_expand_per_modality(
        x: Tensor, modalities_to_dims_dict: dict
    ) -> dict[str, Tensor]:
        """Split a flat ``(b, n, d)`` sequence back into per-modality tensors.

        Modalities are consumed in the dict's iteration order. Each one takes
        the next ``prod(shape[1:-1])`` tokens, which are viewed as
        ``(b, *shape[1:-1], d)``.

        Args:
            x: Tokens to split and expand (b n d)
            modalities_to_dims_dict: Dictionary mapping modalities to their
                original shapes

        Returns:
            Mapping from modality name to its reshaped tokens.
        """
        batch, width = x.shape[0], x.shape[-1]
        cursor = 0
        per_modality: dict[str, Tensor] = {}
        for name, full_shape in modalities_to_dims_dict.items():
            inner_shape = full_shape[1:-1]
            span = math.prod(inner_shape)
            chunk = x[:, cursor : cursor + span]
            per_modality[name] = chunk.view(batch, *inner_shape, width)
            cursor += span
        return per_modality

    @staticmethod
    def pack_tokens(tokens: Tensor, mask: Tensor) -> Tensor:
        """Merge batch and sequence dims and keep only the positions where ``mask`` is set.

        Args:
            tokens: Tokens to pack, indexed as (batch, seq, ...)
            mask: Mask selecting which flattened (batch * seq) rows to keep

        Returns:
            Packed tokens enabling varlen flash attention
        """
        keep = mask.flatten()
        return tokens.flatten(end_dim=1)[keep]

    @staticmethod
    def unpack_tokens(tokens: Tensor, mask: Tensor, og_shape: tuple) -> Tensor:
        """Invert ``pack_tokens``: scatter packed rows back into a zero-filled batch.

        Positions where ``mask`` is not set are filled with zeros.

        Args:
            tokens: Packed tokens to unpack
            mask: The mask that was used to pack them
            og_shape: Original (batch, seq, dim) shape of the tokens

        Returns:
            Tensor of shape ``(og_shape[0], og_shape[1], og_shape[2])``.
        """
        batch, seq_len, width = og_shape[0], og_shape[1], og_shape[2]
        canvas = tokens.new_zeros(batch * seq_len, width)
        canvas[mask.flatten()] = tokens
        return canvas.reshape(batch, seq_len, -1)

    def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
        """Apply FSDP to each transformer block."""
        for blk in self.blocks:
            blk.apply_fsdp(**fsdp_kwargs)

    def apply_compile(self) -> None:
        """Apply torch.compile to each transformer block."""
        for blk in self.blocks:
            blk.apply_compile()
|
|
1198
|
+
|
|
1199
|
+
|
|
1200
|
+
class Encoder(FlexiVitBase):
|
|
1201
|
+
"""Encoder module that processes masked input samples into token representations."""
|
|
1202
|
+
|
|
1203
|
+
cross_attn: bool = False
|
|
1204
|
+
|
|
1205
|
+
def __init__(
    self,
    embedding_size: int,
    max_patch_size: int,
    min_patch_size: int,
    num_heads: int,
    mlp_ratio: float,
    depth: int,
    drop_path: float,
    supported_modalities: list[ModalitySpec],
    max_sequence_length: int,
    num_register_tokens: int = 0,
    learnable_channel_embeddings: bool = True,
    random_channel_embeddings: bool = False,
    num_projection_layers: int = 1,
    aggregate_then_project: bool = True,
    use_flash_attn: bool = False,
    frozen_patch_embeddings: bool = False,
    qk_norm: bool = False,
    log_token_norm_stats: bool = False,
    tokenization_config: TokenizationConfig | None = None,
):
    """Initialize the encoder.

    Args:
        embedding_size: Size of token embeddings
        max_patch_size: Maximum patch size for patchification
        min_patch_size: Minimum patch size for patchification
        num_heads: Number of attention heads
        mlp_ratio: Ratio for MLP hidden dimension
        depth: Number of transformer layers
        drop_path: Drop path rate
        supported_modalities: Modalities accepted by this model instantiation
        max_sequence_length: Maximum sequence length
        num_register_tokens: Number of register tokens to use (0 disables them)
        learnable_channel_embeddings: Whether to use learnable channel embeddings
        random_channel_embeddings: Initialize channel embeddings randomly (zeros if False)
        num_projection_layers: Number of layers in the projection; if >1, a
            ReLU activation is applied between layers
        aggregate_then_project: If True, average the tokens before applying
            the projection; if False, project first
        use_flash_attn: Whether to use flash attention
        frozen_patch_embeddings: If True, freeze the patch embedding layer, as
            recommended in https://arxiv.org/pdf/2104.02057, Section 4.2
        qk_norm: Whether to apply normalization to Q and K in attention
        log_token_norm_stats: Whether to log the token norm stats
        tokenization_config: Optional config for custom band groupings
    """
    # Resolved before the base initializer so the base class and the patch
    # embeddings built below share the same config instance.
    self.tokenization_config = tokenization_config or TokenizationConfig()
    super().__init__(
        embedding_size=embedding_size,
        max_sequence_length=max_sequence_length,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        depth=depth,
        drop_path=drop_path,
        supported_modalities=supported_modalities,
        learnable_channel_embeddings=learnable_channel_embeddings,
        random_channel_embeddings=random_channel_embeddings,
        use_flash_attn=use_flash_attn,
        qk_norm=qk_norm,
        tokenization_config=self.tokenization_config,
    )

    # Plain (non-module) settings.
    self.num_register_tokens = num_register_tokens
    self.has_register_tokens = num_register_tokens > 0
    self.log_token_norm_stats = log_token_norm_stats
    self.min_patch_size = min_patch_size
    self.max_patch_size = max_patch_size
    self.embedding_size = embedding_size

    if self.has_register_tokens:
        # Zero placeholder; real values come from _init_register_tokens below.
        self.register_tokens = nn.Parameter(
            torch.zeros(num_register_tokens, embedding_size)
        )

    self.patch_embeddings = MultiModalPatchEmbeddings(
        self.supported_modality_names,
        self.max_patch_size,
        self.embedding_size,
        tokenization_config=self.tokenization_config,
    )
    self.project_and_aggregate = ProjectAndAggregate(
        embedding_size=self.embedding_size,
        num_layers=num_projection_layers,
        aggregate_then_project=aggregate_then_project,
    )
    self.norm = nn.LayerNorm(self.embedding_size)
    self.apply(self._init_weights)

    if frozen_patch_embeddings:
        for param in self.patch_embeddings.parameters():
            param.requires_grad = False
    if self.has_register_tokens:
        self._init_register_tokens()
|
|
1297
|
+
|
|
1298
|
+
def _init_register_tokens(self) -> None:
    """Fill ``self.register_tokens`` in place with Xavier-uniform values."""
    registers = self.register_tokens
    nn.init.xavier_uniform_(registers)
|
|
1301
|
+
|
|
1302
|
+
def create_token_exit_ids(
    self, x: dict[str, Tensor], token_exit_cfg: dict[str, int]
) -> dict[str, Tensor]:
    """Create the token exit ids for # of layers of attention for each band group.

    Assumes modality channel groups are in the second to last dimension of the tokens.

    For every processed modality, returns a tensor shaped like (and with the
    dtype of) that modality's tokens, filled with ``token_exit_cfg[modality]``.

    Raises:
        KeyError: If a processed modality has no entry in ``token_exit_cfg``.
    """
    selected = get_modalities_to_process(
        return_modalities_from_dict(x), self.supported_modality_names
    )
    return {
        name: torch.full_like(x[name], fill_value=token_exit_cfg[name])
        for name in selected
    }
|
|
1319
|
+
|
|
1320
|
+
@staticmethod
def remove_masked_tokens(
    x: Tensor, mask: Tensor
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Remove masked tokens from the tokens and masks.

    Implementation from https://stackoverflow.com/a/68621610/2332296

    On input, a mask value of 0 means the token should be removed and 1 means
    it should be kept.

    Args:
        x: Tokens to remove masked tokens from
        mask: Mask to remove masked tokens from

    Returns:
        tokens: [B, T, D]
        indices: [B, N], the sort permutation (needed by add_removed_tokens)
        updated_mask: [B, T]
        seqlens: [B]
        max_length: [1]
        where T is the max number of unmasked tokens for an instance
    """
    # A stable descending sort moves every kept position to the front of its
    # row while preserving the relative order of the kept tokens.
    sorted_mask, indices = torch.sort(mask, dim=1, descending=True, stable=True)
    gather_index = indices.unsqueeze(-1).expand_as(x)
    # Reorder tokens to match the sorted mask and zero out removed ones
    # (zeroing is not strictly required; those positions are ignored anyway).
    reordered = torch.gather(x, 1, gather_index) * sorted_mask.unsqueeze(-1)

    # Trim to the longest kept sequence in the batch.
    seq_lengths = sorted_mask.sum(-1)
    max_length = seq_lengths.max()
    trimmed_tokens = reordered[:, :max_length]
    trimmed_mask = sorted_mask[:, :max_length]
    return trimmed_tokens, indices, trimmed_mask, seq_lengths, max_length
|
|
1360
|
+
|
|
1361
|
+
@staticmethod
def add_removed_tokens(
    x: Tensor, indices: Tensor, mask: Tensor
) -> tuple[Tensor, Tensor]:
    """Add removed tokens back to the tokens and masks (inverse of remove_masked_tokens).

    Args:
        x: Tokens to add removed tokens to
        indices: Sort permutation returned by remove_masked_tokens
        mask: Mask to add removed tokens to

    Returns:
        tokens: Tokens in their original positions, zeros where removed
        mask: Mask in original positions, 0 where tokens were removed
    """
    assert x.shape[1] > 0, (
        "x must have at least one token we should not mask all tokens"
    )
    batch, kept, total = x.shape[0], x.shape[1], indices.shape[1]

    # Pad the mask with zeros out to the full (pre-removal) length.
    padding = torch.zeros(
        (batch, total - kept),
        device=x.device,
        dtype=mask.dtype,
    )
    full_mask = torch.cat((mask, padding), dim=-1)

    # Zero canvas of the full length; clone so it is writable (repeat yields a view).
    canvas = repeat(
        torch.zeros_like(x[0, 0, :]), "d -> b t d", b=batch, t=total
    ).clone()
    # Kept tokens go into the first positions of every row...
    canvas[full_mask] = x[mask]
    # ...and are then scattered back to where they came from.
    canvas = canvas.scatter(1, indices.unsqueeze(-1).expand_as(canvas), canvas)
    full_mask = full_mask.scatter(1, indices.expand_as(full_mask), full_mask)
    return canvas, full_mask
|
|
1402
|
+
|
|
1403
|
+
def create_exit_seqs(
|
|
1404
|
+
self,
|
|
1405
|
+
tokens_only_dict: dict[str, Tensor],
|
|
1406
|
+
mask_only_dict: dict[str, Tensor],
|
|
1407
|
+
token_exit_cfg: dict[str, int] | None,
|
|
1408
|
+
) -> tuple[Tensor | None]:
|
|
1409
|
+
"""Create the exit sequences and tokens."""
|
|
1410
|
+
# Check that tokens_only_dict doesn't contain any mask keys
|
|
1411
|
+
assert all(not key.endswith("_mask") for key in tokens_only_dict), (
|
|
1412
|
+
"tokens_only_dict should not contain mask keys"
|
|
1413
|
+
)
|
|
1414
|
+
if token_exit_cfg:
|
|
1415
|
+
exit_ids_per_modality = self.create_token_exit_ids(
|
|
1416
|
+
tokens_only_dict, token_exit_cfg
|
|
1417
|
+
)
|
|
1418
|
+
exit_ids_per_modality.update(mask_only_dict)
|
|
1419
|
+
# Exit ids seqs tells us which layer to exit each token
|
|
1420
|
+
exit_ids_seq, _ = self.collapse_and_combine_hwtc(exit_ids_per_modality)
|
|
1421
|
+
else:
|
|
1422
|
+
exit_ids_seq = None
|
|
1423
|
+
return exit_ids_seq
|
|
1424
|
+
|
|
1425
|
+
def _maybe_get_attn_mask(
|
|
1426
|
+
self,
|
|
1427
|
+
new_mask: Tensor,
|
|
1428
|
+
fast_pass: bool,
|
|
1429
|
+
) -> Tensor | None:
|
|
1430
|
+
"""Get the attention mask or None if we should pass None to the transformer."""
|
|
1431
|
+
if fast_pass or not self.training:
|
|
1432
|
+
return None
|
|
1433
|
+
else:
|
|
1434
|
+
return new_mask
|
|
1435
|
+
|
|
1436
|
+
def add_register_tokens_and_masks(
|
|
1437
|
+
self,
|
|
1438
|
+
tokens: Tensor,
|
|
1439
|
+
attn_mask: Tensor | None,
|
|
1440
|
+
processed_register_tokens: Tensor | None = None,
|
|
1441
|
+
) -> tuple[Tensor, Tensor | None]:
|
|
1442
|
+
"""Concatenate register tokens to the tokens."""
|
|
1443
|
+
batch_size = tokens.shape[0]
|
|
1444
|
+
# Expand register tokens to match batch size: [num_register_tokens, embedding_size] -> [batch_size, num_register_tokens, embedding_size]
|
|
1445
|
+
if processed_register_tokens is None:
|
|
1446
|
+
reg_tokens = self.register_tokens.unsqueeze(0).expand(batch_size, -1, -1)
|
|
1447
|
+
else:
|
|
1448
|
+
reg_tokens = processed_register_tokens
|
|
1449
|
+
# Concatenate register tokens at the beginning: [batch_size, seq_len, embedding_size] -> [batch_size, num_register_tokens + seq_len, embedding_size]
|
|
1450
|
+
tokens = torch.cat([reg_tokens, tokens], dim=1)
|
|
1451
|
+
if attn_mask is not None:
|
|
1452
|
+
# Create mask for register tokens (all True - they should participate in attention)
|
|
1453
|
+
reg_mask = torch.ones(
|
|
1454
|
+
batch_size,
|
|
1455
|
+
self.num_register_tokens,
|
|
1456
|
+
dtype=attn_mask.dtype,
|
|
1457
|
+
device=attn_mask.device,
|
|
1458
|
+
)
|
|
1459
|
+
attn_mask = torch.cat([reg_mask, attn_mask], dim=1)
|
|
1460
|
+
else:
|
|
1461
|
+
reg_mask = None
|
|
1462
|
+
return tokens, attn_mask
|
|
1463
|
+
|
|
1464
|
+
def pop_register_tokens(self, tokens: Tensor) -> tuple[Tensor, Tensor]:
    """Split the leading register tokens off a token sequence.

    Returns:
        ``(tokens, register_tokens)``: the sequence without its first
        ``self.num_register_tokens`` positions along dim 1, and those positions.
    """
    n_reg = self.num_register_tokens
    leading = tokens[:, :n_reg, :]
    remainder = tokens[:, n_reg:, :]
    return remainder, leading
|
|
1469
|
+
|
|
1470
|
+
def get_token_norm_stats(
    self, tokens: Tensor, register_tokens: Tensor
) -> dict[str, float]:
    """Summarize the L2 norms of register and non-register tokens.

    Norms are taken over dim 2 of each input and pooled over all remaining
    positions.

    Returns:
        ``register_{mean,min,max}`` followed by ``nonregister_{mean,min,max,std}``
        and the ``nonregister_{25,75,90,95,99}th`` percentiles, as Python floats.
    """
    # Register token norms, flattened over batch and register slots.
    reg_norms = register_tokens.norm(dim=2).flatten()
    stats: dict[str, float] = {
        "register_mean": reg_norms.mean().item(),
        "register_min": reg_norms.min().item(),
        "register_max": reg_norms.max().item(),
    }

    # Non-register token norms, flattened over batch and sequence positions.
    token_norms = tokens.norm(dim=2).flatten()
    stats["nonregister_mean"] = token_norms.mean().item()
    stats["nonregister_min"] = token_norms.min().item()
    stats["nonregister_max"] = token_norms.max().item()
    stats["nonregister_std"] = token_norms.std().item()

    # torch.quantile needs a floating input and matching-dtype quantile levels.
    percentiles = (25, 75, 90, 95, 99)
    levels = torch.tensor(
        [pct / 100.0 for pct in percentiles], device=token_norms.device
    )
    values = torch.quantile(token_norms.float(), levels).tolist()
    for pct, value in zip(percentiles, values):
        stats[f"nonregister_{pct}th"] = value
    return stats
|
|
1507
|
+
|
|
1508
|
+
def _maybe_remove_masked_tokens(
|
|
1509
|
+
self,
|
|
1510
|
+
tokens: Tensor,
|
|
1511
|
+
mask: Tensor,
|
|
1512
|
+
fast_pass: bool,
|
|
1513
|
+
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
|
1514
|
+
"""Remove masked tokens from the tokens and masks."""
|
|
1515
|
+
if fast_pass and not self.use_flash_attn:
|
|
1516
|
+
# This is the inference fast pass
|
|
1517
|
+
indices = None
|
|
1518
|
+
new_mask = None
|
|
1519
|
+
seq_lengths = None
|
|
1520
|
+
max_seqlen = None
|
|
1521
|
+
bool_mask = None
|
|
1522
|
+
else:
|
|
1523
|
+
bool_mask = mask == MaskValue.ONLINE_ENCODER.value
|
|
1524
|
+
tokens, indices, new_mask, seq_lengths, max_seqlen = (
|
|
1525
|
+
self.remove_masked_tokens(tokens, bool_mask)
|
|
1526
|
+
)
|
|
1527
|
+
return tokens, indices, new_mask, seq_lengths, max_seqlen, bool_mask
|
|
1528
|
+
|
|
1529
|
+
def _maybe_add_removed_tokens(
|
|
1530
|
+
self,
|
|
1531
|
+
tokens: Tensor,
|
|
1532
|
+
indices: Tensor,
|
|
1533
|
+
mask: Tensor,
|
|
1534
|
+
fast_pass: bool,
|
|
1535
|
+
) -> Tensor:
|
|
1536
|
+
"""Add removed tokens to the tokens and masks."""
|
|
1537
|
+
if not fast_pass:
|
|
1538
|
+
tokens, _ = self.add_removed_tokens(tokens, indices, mask)
|
|
1539
|
+
return tokens
|
|
1540
|
+
|
|
1541
|
+
def apply_attn(
    self,
    x: dict[str, Tensor],
    timestamps: Tensor,
    patch_size: int,
    input_res: int,
    token_exit_cfg: dict[str, int] | None = None,
    fast_pass: bool = False,
) -> tuple[dict[str, Tensor], dict[str, Any] | None]:
    """Apply the attention to the tokens and masks.

    Steps:
    1. Split ``x`` into tokens, masks and per-modality dims.
    2. Add the composite encodings and collapse everything into one sequence.
    3. Possibly drop tokens that are not ONLINE_ENCODER.
    4. Run the transformer blocks and apply the final norm.
    5. Scatter the result back into a per-modality dict.

    Args:
        x: Per-modality token tensors plus their ``*_mask`` entries.
        timestamps: Forwarded to ``self.composite_encodings``.
        patch_size: Forwarded to ``self.composite_encodings``.
        input_res: Forwarded to ``self.composite_encodings``.
        token_exit_cfg: Optional per-modality exit depths. A token with exit id
            k > 0 keeps its representation from the input of block k. Exit id 0
            keeps the un-encoded input token. An exit id equal to the number of
            blocks takes the full-depth output.
        fast_pass: Forwarded to the token-removal / attention-mask helpers.

    Returns:
        ``(tokens_per_modality_dict, token_norm_stats)``. The dict holds the
        processed tokens merged with the original masks. ``token_norm_stats``
        is ``None`` unless register tokens are used and
        ``self.log_token_norm_stats`` is set.
    """
    tokens_only_dict, original_masks_dict, modalities_to_dims_dict = (
        self.split_tokens_masks_and_dims(x)
    )
    # create_exit_seqs returns None for a falsy token_exit_cfg, so in that case
    # this is already a no-op but we could remove entirely
    exit_ids_seq = self.create_exit_seqs(
        tokens_only_dict, original_masks_dict, token_exit_cfg
    )
    # exited tokens are just the linear projection (collapsed before any
    # composite encodings are added)
    exited_tokens, _ = self.collapse_and_combine_hwtc(x)

    tokens_dict = self.composite_encodings.forward(
        tokens_only_dict,
        timestamps,
        patch_size,
        input_res,
    )
    tokens_dict.update(original_masks_dict)
    tokens, mask = self.collapse_and_combine_hwtc(tokens_dict)

    tokens, indices, new_mask, seq_lengths, max_seqlen, bool_mask = (
        self._maybe_remove_masked_tokens(tokens, mask, fast_pass)
    )

    if exit_ids_seq is not None:
        # Apply the same removal to the exit ids and the exited tokens so they
        # stay position-aligned with ``tokens``.
        exit_ids_seq, _, _, _, _ = self.remove_masked_tokens(
            exit_ids_seq, bool_mask
        )
        # still linear projections
        exited_tokens, _, _, _, _ = self.remove_masked_tokens(
            exited_tokens, bool_mask
        )

    # Pack x tokens (varlen attention: described by cumulative sequence lengths)
    if self.use_flash_attn:
        cu_seqlens = get_cumulative_sequence_lengths(seq_lengths)
        og_shape = tokens.shape
        tokens = self.pack_tokens(tokens, new_mask)
    else:
        cu_seqlens = None

    attn_mask = self._maybe_get_attn_mask(
        new_mask,
        fast_pass=fast_pass,
    )

    if self.has_register_tokens:
        # NOTE(review): register tokens are concatenated along dim 1, which
        # presumably assumes the unpacked [B, T, D] layout - confirm this is
        # compatible with use_flash_attn (tokens are packed above).
        tokens, attn_mask = self.add_register_tokens_and_masks(tokens, attn_mask)

    # Apply attn with varying encoder depths
    for i_blk, blk in enumerate(self.blocks):
        # Skip the zeroth block because we want to use the exited tokens that don't have encodings as this allows trivial solution of predicting the shared encodings
        if (exit_ids_seq is not None) and (i_blk > 0):
            # this should only ever be called by the target encoder,
            # in a torch.no_grad context
            assert exited_tokens is not None
            # If a token should exit, then we update the exit token with the current token at the same position
            # NOTE(review): here ``tokens`` may carry prepended register tokens
            # and/or be packed, while ``exited_tokens`` is neither - confirm
            # token exit is never combined with register tokens or flash attn.
            exited_tokens = torch.where(
                condition=(exit_ids_seq == i_blk),
                input=tokens,
                other=exited_tokens,
            )
        # attn_mask is passed through as-is (no inversion): a value of True
        # indicates the token *should* take part in attention.
        # WARNING: THIS MAY CHANGE DEPENDING ON THE ATTENTION IMPLEMENTATION

        tokens = blk(
            x=tokens,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            # we will have to specify k and q lens for cross attention
            attn_mask=attn_mask,
        )

    if self.has_register_tokens:
        tokens, register_tokens = self.pop_register_tokens(tokens)
        token_norm_stats = (
            self.get_token_norm_stats(tokens, register_tokens)
            if self.log_token_norm_stats
            else None
        )
    else:
        token_norm_stats = None

    if self.use_flash_attn:
        tokens = self.unpack_tokens(tokens, new_mask, og_shape)

    if exit_ids_seq is not None:
        # this should only ever be called by the target encoder,
        # in a torch.no_grad context
        assert exited_tokens is not None
        # full depth
        # IMPORTANT: write this to x
        # NOTE(review): ``i_blk`` is the leftover loop variable, so
        # ``i_blk + 1 == len(self.blocks)``; if ``self.blocks`` is empty this
        # raises NameError.
        tokens = torch.where(
            condition=(exit_ids_seq == (i_blk + 1)),  # 2 for full depth
            input=tokens,
            other=exited_tokens,
        )
    # we apply the norm before we add the removed tokens,
    # so that the norm is only computed against "real" tokens
    tokens = self.norm(tokens)
    # we don't care about the mask returned by add_removed_tokens, since we will
    # just use the original, unclipped mask here
    tokens = self._maybe_add_removed_tokens(tokens, indices, new_mask, fast_pass)

    tokens_per_modality_dict = self.split_and_expand_per_modality(
        tokens, modalities_to_dims_dict
    )
    # merge original masks and the processed tokens
    tokens_per_modality_dict.update(original_masks_dict)
    return tokens_per_modality_dict, token_norm_stats
|
|
1662
|
+
|
|
1663
|
+
def forward(
    self,
    x: MaskedOlmoEarthSample,
    patch_size: int,
    input_res: int = BASE_GSD,
    token_exit_cfg: dict | None = None,
    fast_pass: bool = False,
) -> dict[str, Any]:
    """Process masked input samples into token representations.

    Args:
        x: Masked input sample containing the data to be encoded
        patch_size: Size of patches to divide the input into
        input_res: Resolution of the input data
        token_exit_cfg: Optional per-modality exit depths. ``None`` or an empty
            dict means no early exit: every token runs through all blocks.
        fast_pass: Whether to always pass None as the mask to the transformer;
            this enables torch-based flash attention and skips mask construction
            and sorting.

    Returns:
        A dict with:
        - ``"tokens_and_masks"``: TokensAndMasks with the encoded tokens and masks.
        - ``"token_norm_stats"``: token-norm statistics, only when non-empty.
        - ``"project_aggregated"``: output of ``self.project_and_aggregate``,
          only when ``fast_pass`` is False.

    Raises:
        ValueError: If a non-empty ``token_exit_cfg`` is combined with ``fast_pass``.
    """
    if fast_pass and token_exit_cfg:
        raise ValueError("token_exit_cfg cannot be set when fast_pass is True")

    patchified_tokens_and_masks = self.patch_embeddings.forward(
        x, patch_size, fast_pass=fast_pass
    )
    # An absent/empty exit config means full depth for every token, matching
    # create_exit_seqs, which ignores a falsy config. Attention is skipped only
    # when every configured modality exits at depth 0.
    run_attention = not token_exit_cfg or any(
        exit_depth > 0 for exit_depth in token_exit_cfg.values()
    )
    if run_attention:
        patchified_tokens_and_masks, token_norm_stats = self.apply_attn(
            x=patchified_tokens_and_masks,
            timestamps=x.timestamps,
            patch_size=patch_size,
            input_res=input_res,
            token_exit_cfg=token_exit_cfg,
            fast_pass=fast_pass,
        )
    else:
        # All tokens exit at depth 0: return the raw patch embeddings.
        token_norm_stats = {}
    output = TokensAndMasks(**patchified_tokens_and_masks)
    output_dict: dict[str, Any] = {
        "tokens_and_masks": output,
    }
    if token_norm_stats:
        output_dict["token_norm_stats"] = token_norm_stats

    if not fast_pass:
        output_dict["project_aggregated"] = self.project_and_aggregate(output)
    return output_dict
|
|
1712
|
+
|
|
1713
|
+
def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
    """Shard the encoder with FSDP.

    The parent class's FSDP setup runs first, then the encoder as a whole is
    wrapped with ``fully_shard``. The small layers (patch embeddings and the
    projection/aggregation head) are intentionally not sharded separately.
    """
    super().apply_fsdp(**fsdp_kwargs)
    fully_shard(self, **fsdp_kwargs)
|
|
1722
|
+
|
|
1723
|
+
def apply_compile(self) -> None:
    """Apply torch.compile to the encoder by compiling each block on its own.

    Only the transformer blocks are compiled (via each block's
    ``apply_compile``); the rest of the encoder stays eager.
    """
    logger.info("Compiling blocks")
    for transformer_block in self.blocks:
        transformer_block.apply_compile()
|
|
1732
|
+
|
|
1733
|
+
|
|
1734
|
+
class PredictorBase(FlexiVitBase):
    """Predictor module that generates predictions from encoded tokens."""

    cross_attn = True

    def __init__(
        self,
        supported_modalities: list[ModalitySpec],
        encoder_embedding_size: int = 128,
        decoder_embedding_size: int = 128,
        depth: int = 2,
        mlp_ratio: float = 2.0,
        num_heads: int = 8,
        max_sequence_length: int = 24,
        drop_path: float = 0.0,
        learnable_channel_embeddings: bool = True,
        random_channel_embeddings: bool = False,
        output_embedding_size: int | None = None,
        use_flash_attn: bool = False,
        qk_norm: bool = False,
        tokenization_config: TokenizationConfig | None = None,
    ):
        """Initialize the predictor.

        Args:
            supported_modalities: modalities this model instantiation supports
            encoder_embedding_size: Size of encoder embeddings
            decoder_embedding_size: Size of decoder embeddings
            depth: Number of transformer layers
            mlp_ratio: Ratio for MLP hidden dimension
            num_heads: Number of attention heads
            max_sequence_length: Maximum sequence length
            drop_path: Drop path rate
            learnable_channel_embeddings: Whether to use learnable channel embeddings
            random_channel_embeddings: Whether to randomly initialize channel embeddings
            output_embedding_size: Size of output embeddings; defaults to
                ``encoder_embedding_size`` when None
            use_flash_attn: Whether to use flash attention
            qk_norm: Whether to apply normalization to Q and K in attention
            tokenization_config: Optional config for custom band groupings
        """
        self.tokenization_config = tokenization_config or TokenizationConfig()
        super().__init__(
            embedding_size=decoder_embedding_size,
            depth=depth,
            mlp_ratio=mlp_ratio,
            num_heads=num_heads,
            max_sequence_length=max_sequence_length,
            drop_path=drop_path,
            learnable_channel_embeddings=learnable_channel_embeddings,
            random_channel_embeddings=random_channel_embeddings,
            supported_modalities=supported_modalities,
            use_flash_attn=use_flash_attn,
            qk_norm=qk_norm,
            tokenization_config=self.tokenization_config,
        )
        self.learnable_channel_embeddings = learnable_channel_embeddings
        self.random_channel_embeddings = random_channel_embeddings
        self.encoder_embedding_size = encoder_embedding_size
        # Projects encoder-width tokens into the decoder width.
        self.encoder_to_decoder_embed = nn.Linear(
            encoder_embedding_size, decoder_embedding_size, bias=True
        )
        if output_embedding_size is None:
            output_embedding_size = encoder_embedding_size
        self.output_embedding_size = output_embedding_size
        self.to_output_embed = nn.Linear(
            decoder_embedding_size, output_embedding_size, bias=True
        )
        # THIS is the learnable mask token
        self.mask_token = nn.Parameter(torch.zeros(decoder_embedding_size))

        self.input_norm = nn.LayerNorm(encoder_embedding_size)
        self.norm = nn.LayerNorm(decoder_embedding_size)
        self.apply(self._init_weights)

    def add_masks(self, x: dict[str, Tensor]) -> dict[str, Tensor]:
        """Replace tokens marked for decoding (MaskValue.DECODER) with the mask token.

        The mask token is broadcast with einops, so any number of leading
        dimensions is supported. The final dimension of each token tensor is
        assumed to be the embedding dimension, matching ``self.mask_token``'s size.

        Args:
            x: Per-modality token tensors together with their ``*_mask`` entries.

        Returns:
            A dict containing only the processed token tensors (no mask entries),
            for each modality selected by ``get_modalities_to_process``.
        """
        output_dict = {}
        available_modalities = return_modalities_from_dict(x)
        modalities_to_process = get_modalities_to_process(
            available_modalities, self.supported_modality_names
        )
        for modality in modalities_to_process:
            x_modality = x[modality]
            mask_name = MaskedOlmoEarthSample.get_masked_modality_name(modality)
            mask_modality = x[mask_name]
            # True where the token must be replaced by the mask token
            decode_mask = mask_modality == MaskValue.DECODER.value

            # Build the einops pattern over all dimensions except the last (embedding)
            spatial_dims = x_modality.shape[:-1]
            pattern_input, dim_dict = self._construct_einops_pattern(spatial_dims)
            mask_token_broadcasted = repeat(self.mask_token, pattern_input, **dim_dict)

            # Where decode_mask is True, use the broadcasted mask token
            x_modality = torch.where(
                decode_mask.unsqueeze(-1).bool(), mask_token_broadcasted, x_modality
            )
            output_dict[modality] = x_modality

        return output_dict

    # TODO: GIVE more explicit function names
    @staticmethod
    def split_x_y(tokens: Tensor, mask: Tensor) -> tuple[Tensor, ...]:
        """Split tokens into tokens to decode (x) and context tokens (y).

        This function:
        1. Remaps MISSING to TARGET_ENCODER_ONLY on a copy of ``mask``, so unused
           tokens end up in the middle of the sorted order. The input ``mask`` is
           not modified.
        2. Stably sorts tokens by mask value in descending order. The slicing
           below assumes DECODER sorts first and ONLINE_ENCODER sorts last.
        3. Takes the tokens to decode (mask value DECODER) from the front.
        4. Takes the context tokens (mask value ONLINE_ENCODER) from the back.
        5. Returns binary masks for x and y, plus indices to revert to the
           original ordering.

        Args:
            tokens: Tokens to split of shape [B, T, D].
            mask: Mask of shape [B, T].

        Returns:
            tokens_to_decode: Tokens to be decoded of shape [B, X_len, D].
            unmasked_tokens: Tokens to be used as context of shape [B, Y_len, D].
            tokens_to_decode_mask: Binary mask for x tokens of shape [B, X_len].
            unmasked_tokens_mask: Binary mask for y tokens of shape [B, Y_len].
            indices: Indices for restoring the original token ordering of shape [B, T].
            seqlens_tokens_to_decode: Sequence lengths of tokens to decode of shape [B].
            seqlens_unmasked_tokens: Sequence lengths of unmasked tokens of shape [B].
            max_length_of_decoded_tokens: Maximum length of decoded tokens (0-dim tensor).
            max_length_of_unmasked_tokens: Maximum length of unmasked tokens (0-dim tensor).
        """
        org_mask_dtype = mask.dtype
        # Set Missing Masks to Target Encoder ONLY so that we can have all unused
        # tokens in the middle. masked_fill is out-of-place: the caller's mask is
        # left untouched.
        missing_mask = mask == MaskValue.MISSING.value
        mask = mask.masked_fill(missing_mask, MaskValue.TARGET_ENCODER_ONLY.value)

        # Sort tokens by mask value (descending order)
        sorted_mask, indices = torch.sort(
            mask.int(), dim=1, descending=True, stable=True
        )
        tokens = tokens.gather(1, indices[:, :, None].expand_as(tokens))

        # Create binary masks for Encoder and Decoder
        binarized_decoder_mask = sorted_mask == MaskValue.DECODER.value
        binarized_online_encoder_mask = sorted_mask == MaskValue.ONLINE_ENCODER.value

        seqlens_unmasked_tokens = binarized_online_encoder_mask.sum(dim=-1)
        max_length_of_unmasked_tokens = seqlens_unmasked_tokens.max()
        seqlens_tokens_to_decode = binarized_decoder_mask.sum(dim=-1)
        max_length_of_decoded_tokens = seqlens_tokens_to_decode.max()

        # x: tokens to decode sit at the front of each sorted row. The x mask is
        # used during reconstruction (combine_x_y) to know which x tokens are real.
        tokens_to_decode = tokens[:, :max_length_of_decoded_tokens]
        tokens_to_decode_mask = binarized_decoder_mask[
            :, :max_length_of_decoded_tokens
        ].to(org_mask_dtype)

        # y: context tokens sit at the back of each sorted row. True values in the
        # y mask take part in attention (no inversion). Use an explicit start
        # index: a ``-n:`` slice would select the whole row when n == 0.
        unmasked_start = tokens.shape[1] - max_length_of_unmasked_tokens
        unmasked_tokens = tokens[:, unmasked_start:]
        unmasked_tokens_mask = binarized_online_encoder_mask[
            :, unmasked_start:
        ].to(org_mask_dtype)

        return (
            tokens_to_decode,
            unmasked_tokens,
            tokens_to_decode_mask,
            unmasked_tokens_mask,
            indices,
            seqlens_tokens_to_decode,
            seqlens_unmasked_tokens,
            max_length_of_decoded_tokens,
            max_length_of_unmasked_tokens,
        )

    @staticmethod
    def combine_x_y(
        tokens_to_decode: Tensor,
        unmasked_tokens: Tensor,
        tokens_to_decode_mask: Tensor,
        unmasked_tokens_mask: Tensor,
        indices: Tensor,
    ) -> Tensor:
        """Reintegrate the separated token sequences into their original order.

        The token masks zero out positions which are not used/needed. The final
        scatter step re-applies the original ordering tracked in ``indices``.

        Args:
            tokens_to_decode: Decoded (query) tokens of shape [B, X_len, D].
            unmasked_tokens: Context (key/value) tokens of shape [B, Y_len, D].
            tokens_to_decode_mask: Binary mask for tokens to decode of shape [B, X_len].
            unmasked_tokens_mask: Binary mask for unmasked tokens of shape [B, Y_len].
            indices: Indices for restoring the original token ordering of shape [B, T].

        Returns:
            A merged tokens tensor of shape [B, T, D] with all tokens in their
            original positions.
        """
        B, T = indices.shape[0], indices.shape[1]
        D = tokens_to_decode.shape[-1]
        tokens = torch.zeros(
            (B, T, D), dtype=tokens_to_decode.dtype, device=tokens_to_decode.device
        )
        # Context tokens occupy the tail of the sorted layout. Use an explicit
        # start index so an empty context does not turn into a whole-row slice.
        num_unmasked = unmasked_tokens.shape[1]
        tokens[:, T - num_unmasked :] = (
            unmasked_tokens * unmasked_tokens_mask.unsqueeze(-1)
        )
        # Decoded tokens occupy the head. Padding is zeroed by the masks, so
        # accumulating handles overlap between the two padded regions.
        tokens[:, : tokens_to_decode.shape[1]] += (
            tokens_to_decode * tokens_to_decode_mask.unsqueeze(-1)
        )
        tokens = tokens.scatter(1, indices[:, :, None].expand_as(tokens), tokens)
        return tokens

    def is_any_data_to_be_decoded(self, modality_mask: Tensor) -> Tensor:
        """Return a 0-dim boolean tensor: True if any entry of ``modality_mask`` is DECODER."""
        return (MaskValue.DECODER.value == modality_mask).any()

    def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
        """Apply FSDP to the model."""
        super().apply_fsdp(**fsdp_kwargs)
        fully_shard(self, **fsdp_kwargs)
|
|
1965
|
+
|
|
1966
|
+
|
|
1967
|
+
class Predictor(PredictorBase):
    """Predictor module that generates predictions from encoded tokens.

    Tokens to decode act as cross-attention queries over the unmasked
    (online-encoder) context tokens, which act as keys/values.
    """

    cross_attn = True

    def apply_attn(
        self,
        x: dict[str, Tensor],
        timestamps: Tensor,
        patch_size: int,
        input_res: int,
    ) -> dict[str, Tensor]:
        """Apply attention to the tokens.

        Args:
            x: Per-modality decoder-space tokens plus their ``*_mask`` entries.
            timestamps: Forwarded to ``self.composite_encodings``.
            patch_size: Forwarded to ``self.composite_encodings``.
            input_res: Forwarded to ``self.composite_encodings``.

        Returns:
            Per-modality tokens in their original layout, merged with the
            original masks.
        """
        tokens_only_dict, original_masks_dict, modalities_to_dims_dict = (
            self.split_tokens_masks_and_dims(x)
        )
        tokens_dict = self.composite_encodings(
            tokens_only_dict, timestamps, patch_size, input_res
        )
        tokens_dict.update(original_masks_dict)
        all_tokens, mask = self.collapse_and_combine_hwtc(tokens_dict)
        # X contains the tokens to decode, Y contains the tokens to attend to for context
        (
            tokens_to_decode,
            unmasked_tokens,
            tokens_to_decode_mask,
            unmasked_tokens_mask,
            indices,
            seqlens_tokens_to_decode,
            seqlens_unmasked_tokens,
            max_length_of_tokens_to_decode,
            max_length_of_unmasked_tokens,
        ) = self.split_x_y(all_tokens, mask)

        # Pack x and y tokens for variable-length (flash) attention
        if self.use_flash_attn:
            og_shape_tokens_to_decode = tokens_to_decode.shape
            tokens_to_decode = self.pack_tokens(
                tokens_to_decode, tokens_to_decode_mask.bool()
            )
            og_shape_unmasked_tokens = unmasked_tokens.shape
            unmasked_tokens = self.pack_tokens(
                unmasked_tokens, unmasked_tokens_mask.bool()
            )
            cu_seqlens_tokens_to_decode = get_cumulative_sequence_lengths(
                seqlens_tokens_to_decode
            )
            cu_seqlens_unmasked_tokens = get_cumulative_sequence_lengths(
                seqlens_unmasked_tokens
            )
        else:
            cu_seqlens_tokens_to_decode = None
            cu_seqlens_unmasked_tokens = None

        # Only the padded (non-flash) path takes an explicit key mask; the varlen
        # path describes padding via the cu_seqlens instead. split_x_y already
        # marks valid context tokens with True, so no inversion is needed.
        key_attn_mask = None if self.use_flash_attn else unmasked_tokens_mask.bool()

        for blk in self.blocks:
            tokens_to_decode = blk(
                x=tokens_to_decode,
                y=unmasked_tokens,
                attn_mask=key_attn_mask,
                cu_seqlens_q=cu_seqlens_tokens_to_decode,
                cu_seqlens_k=cu_seqlens_unmasked_tokens,
                max_seqlen_q=max_length_of_tokens_to_decode,
                max_seqlen_k=max_length_of_unmasked_tokens,
            )

        if self.use_flash_attn:
            tokens_to_decode = self.unpack_tokens(
                tokens_to_decode,
                tokens_to_decode_mask.bool(),
                og_shape_tokens_to_decode,
            )
            unmasked_tokens = self.unpack_tokens(
                unmasked_tokens, unmasked_tokens_mask.bool(), og_shape_unmasked_tokens
            )

        x = self.combine_x_y(
            tokens_to_decode=tokens_to_decode,
            unmasked_tokens=unmasked_tokens,
            tokens_to_decode_mask=tokens_to_decode_mask,
            unmasked_tokens_mask=unmasked_tokens_mask,
            indices=indices,
        )
        tokens_per_modality_dict = self.split_and_expand_per_modality(
            x, modalities_to_dims_dict
        )
        tokens_per_modality_dict.update(original_masks_dict)
        return tokens_per_modality_dict

    def forward(
        self,
        x: TokensAndMasks,
        timestamps: Tensor,
        patch_size: int,
        input_res: int = BASE_GSD,
    ) -> TokensAndMasks:
        """Generate predictions from encoded token representations.

        Args:
            x: TokensAndMasks containing the encoded tokens to make predictions from
            timestamps: Timestamps of the tokens
            patch_size: Patch size of the tokens
            input_res: Input resolution of the tokens

        Returns:
            TokensAndMasks containing the predicted tokens and their masks
        """
        decoder_embedded_dict = x.as_dict(return_none=False)
        # Apply Input Norms and encoder to decoder embeds to each modality
        available_modalities = x.modalities
        modalities_to_process = get_modalities_to_process(
            available_modalities, self.supported_modality_names
        )
        for modality in modalities_to_process:
            x_modality = getattr(x, modality)
            # Although we do not account for missing tokens, both the norm and the
            # projection act on the token dimension only, so there is no mixing
            # with real tokens.
            x_modality = self.input_norm(x_modality)
            x_modality = self.encoder_to_decoder_embed(x_modality)
            masked_modality_name = x.get_masked_modality_name(modality)
            decoder_embedded_dict[modality] = x_modality
            decoder_embedded_dict[masked_modality_name] = getattr(
                x, masked_modality_name
            )

        # Swap tokens marked DECODER for the learnable mask token.
        tokens_only_dict = self.add_masks(decoder_embedded_dict)
        decoder_embedded_dict.update(tokens_only_dict)
        tokens_and_masks = self.apply_attn(
            decoder_embedded_dict, timestamps, patch_size, input_res
        )
        # TODO: Factor this out into a more readable function
        output_dict = {}
        available_modalities = return_modalities_from_dict(tokens_and_masks)
        modalities_to_process = get_modalities_to_process(
            available_modalities, self.supported_modality_names
        )
        for modality in modalities_to_process:
            masked_modality_name = MaskedOlmoEarthSample.get_masked_modality_name(
                modality
            )
            modality_mask = tokens_and_masks[masked_modality_name]
            modality_data = tokens_and_masks[modality]
            # Normalize and project each band set (dim -2) to the output
            # embedding size, then restack along the band-set dimension.
            num_band_sets = self.tokenization_config.get_num_bandsets(modality)
            output_dict[modality] = torch.stack(
                [
                    self.to_output_embed(self.norm(modality_data[..., idx, :]))
                    for idx in range(num_band_sets)
                ],
                dim=-2,
            )
            output_dict[masked_modality_name] = modality_mask
        return TokensAndMasks(**output_dict)
|
|
2121
|
+
|
|
2122
|
+
|
|
2123
|
+
@dataclass
class EncoderConfig(Config):
    """Configuration for the Encoder.

    ``build()`` validates this config and passes its fields to ``Encoder`` as
    keyword arguments. ``supported_modality_names`` is replaced by the resolved
    ``supported_modalities``, and fields whose value is ``None`` are omitted
    (``as_dict(exclude_none=True, ...)``).
    """

    # Names of the modalities to support; resolved to ModalitySpec objects by
    # the ``supported_modalities`` property.
    supported_modality_names: list[str]

    embedding_size: int = 16
    # This is the base patch size for the patch embedder
    max_patch_size: int = 8
    min_patch_size: int = 1
    num_heads: int = 2
    mlp_ratio: float = 1.0
    depth: int = 2
    drop_path: float = 0.1
    max_sequence_length: int = 12
    num_register_tokens: int = 0
    learnable_channel_embeddings: bool = True
    random_channel_embeddings: bool = False
    num_projection_layers: int = 1
    aggregate_then_project: bool = True
    use_flash_attn: bool = False
    frozen_patch_embeddings: bool = False
    qk_norm: bool = False
    log_token_norm_stats: bool = False
    # Optional; when None it is left out of the Encoder kwargs in build().
    tokenization_config: TokenizationConfig | None = None

    def validate(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: If no modality is configured, or if a resolved modality
                is not among ``Modality.values()``.
            Any error raised by ``tokenization_config.validate()`` when a
            tokenization config is set.
        """
        # Resolve the names once; the property rebuilds the list on each access.
        modalities = self.supported_modalities
        if not modalities:
            raise ValueError("At least one modality must be added!")
        # Loop-invariant: compute the set of known modalities a single time.
        known_modalities = Modality.values()
        for modality in modalities:
            if modality not in known_modalities:
                raise ValueError(f"Modality {modality} is not supported")
        if self.tokenization_config is not None:
            self.tokenization_config.validate()

    @property
    def supported_modalities(self) -> list[ModalitySpec]:
        """Get the supported modalities.

        Returns:
            The ModalitySpec objects for ``supported_modality_names``, as
            returned by ``get_modality_specs_from_names``.
        """
        return get_modality_specs_from_names(self.supported_modality_names)

    def build(self) -> "Encoder":
        """Build the encoder.

        Returns:
            An ``Encoder`` constructed from this config's non-None fields, with
            ``supported_modality_names`` swapped for ``supported_modalities``.

        Raises:
            Whatever ``validate()`` raises for an invalid configuration.
        """
        self.validate()
        kwargs = self.as_dict(exclude_none=True, recurse=False)
        # supported_modality_names is replaced by supported_modalities
        kwargs.pop("supported_modality_names")
        kwargs["supported_modalities"] = self.supported_modalities
        # Lazy %-style args: kwargs is only stringified if INFO is enabled.
        logger.info("Encoder kwargs: %s", kwargs)
        return Encoder(**kwargs)
|
|
2174
|
+
|
|
2175
|
+
|
|
2176
|
+
@dataclass
class PredictorConfig(Config):
    """Configuration for the Predictor.

    ``build()`` validates this config and passes its fields to ``Predictor`` as
    keyword arguments. ``supported_modality_names`` is replaced by the resolved
    ``supported_modalities``, and fields whose value is ``None`` are omitted
    (``as_dict(exclude_none=True, ...)``).
    """

    # Names of the modalities to support; resolved to ModalitySpec objects by
    # the ``supported_modalities`` property.
    supported_modality_names: list[str]
    encoder_embedding_size: int = 16
    decoder_embedding_size: int = 16
    depth: int = 2
    mlp_ratio: float = 1.0
    num_heads: int = 2
    max_sequence_length: int = 12
    drop_path: float = 0.0
    learnable_channel_embeddings: bool = True
    random_channel_embeddings: bool = False
    # Optional; when None it is left out of the Predictor kwargs in build().
    output_embedding_size: int | None = None
    use_flash_attn: bool = False
    qk_norm: bool = False
    # Optional; when None it is left out of the Predictor kwargs in build().
    tokenization_config: TokenizationConfig | None = None

    def validate(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: If no modality is configured, or if a resolved modality
                is not among ``Modality.values()``.
            Any error raised by ``tokenization_config.validate()`` when a
            tokenization config is set.
        """
        # Resolve the names once; the property rebuilds the list on each access.
        modalities = self.supported_modalities
        if not modalities:
            raise ValueError("At least one modality must be added!")
        # Loop-invariant: compute the set of known modalities a single time.
        known_modalities = Modality.values()
        for modality in modalities:
            if modality not in known_modalities:
                raise ValueError(f"Modality {modality} is not supported")
        if self.tokenization_config is not None:
            self.tokenization_config.validate()

    @property
    def supported_modalities(self) -> list[ModalitySpec]:
        """Get the supported modalities.

        Returns:
            The ModalitySpec objects for ``supported_modality_names``, as
            returned by ``get_modality_specs_from_names``.
        """
        return get_modality_specs_from_names(self.supported_modality_names)

    def build(self) -> "PredictorBase":
        """Build the predictor.

        Returns:
            A ``Predictor`` constructed from this config's non-None fields, with
            ``supported_modality_names`` swapped for ``supported_modalities``.

        Raises:
            Whatever ``validate()`` raises for an invalid configuration.
        """
        self.validate()
        kwargs = self.as_dict(exclude_none=True, recurse=False)
        # supported_modality_names is replaced by supported_modalities
        kwargs.pop("supported_modality_names")
        kwargs["supported_modalities"] = self.supported_modalities
        # Lazy %-style args: kwargs is only stringified if INFO is enabled.
        logger.info("Predictor kwargs: %s", kwargs)
        return Predictor(**kwargs)
|