olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- olmoearth_pretrain_minimal/__init__.py +16 -0
- olmoearth_pretrain_minimal/model_loader.py +123 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
- olmoearth_pretrain_minimal/test.py +51 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,559 @@
|
|
|
1
|
+
"""Attention Components for OlmoEarth Pretrain."""
|
|
2
|
+
|
|
3
|
+
from logging import getLogger
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
from einops import rearrange
|
|
10
|
+
from torch.distributed.fsdp import fully_shard
|
|
11
|
+
from torch.jit import Final
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
import flash_attn
|
|
15
|
+
except ImportError:
|
|
16
|
+
flash_attn = None
|
|
17
|
+
|
|
18
|
+
logger = getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@torch._dynamo.disable()
def dispatch_flash_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    cu_seqlens: torch.Tensor | None = None,
    cu_seqlens_q: torch.Tensor | None = None,
    cu_seqlens_k: torch.Tensor | None = None,
    max_seqlen: int | None = None,
    max_seqlen_q: int | None = None,
    max_seqlen_k: int | None = None,
    dropout_p: float = 0.0,
    softmax_scale: float | None = None,
    causal: bool = False,
) -> torch.Tensor:
    """Run flash attention, using the variable-length kernel when possible.

    Modeled after olmo core, but this helper does not flatten its inputs. When
    the variable-length kernel is selected, ``q`` must already be packed (3-D).

    The shared ``cu_seqlens`` / ``max_seqlen`` arguments act as defaults for
    whichever query/key-specific variants were not supplied.
    ``flash_attn.flash_attn_varlen_func`` is used only when all four of
    ``cu_seqlens_q``, ``cu_seqlens_k``, ``max_seqlen_q`` and ``max_seqlen_k``
    are known after that fill-in. Otherwise ``flash_attn.flash_attn_func`` is
    called and any partially supplied length information is ignored.

    Args:
        q: Query tensor (packed when the varlen kernel is used).
        k: Key tensor.
        v: Value tensor.
        cu_seqlens: Shared cumulative sequence lengths for queries and keys.
        cu_seqlens_q: Cumulative sequence lengths for queries.
        cu_seqlens_k: Cumulative sequence lengths for keys.
        max_seqlen: Shared maximum sequence length for queries and keys.
        max_seqlen_q: Maximum query sequence length.
        max_seqlen_k: Maximum key sequence length.
        dropout_p: Attention dropout probability passed to the kernel.
        softmax_scale: Softmax scale passed to the kernel.
        causal: Whether to apply a causal mask.

    Returns:
        The output of the selected flash-attn kernel.

    Raises:
        RuntimeError: If the ``flash_attn`` package could not be imported.
    """
    if flash_attn is None:
        raise RuntimeError("flash-attn is required!")

    # The shared lengths fill in any query/key-specific value left unset.
    if cu_seqlens_q is None:
        cu_seqlens_q = cu_seqlens
    if cu_seqlens_k is None:
        cu_seqlens_k = cu_seqlens
    if max_seqlen_q is None:
        max_seqlen_q = max_seqlen
    if max_seqlen_k is None:
        max_seqlen_k = max_seqlen

    length_info = (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
    if any(item is None for item in length_info):
        # Not enough information for the varlen kernel: use the dense one.
        return flash_attn.flash_attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
        )

    assert q.ndim == 3, "q must be pre-packed"
    logger.debug("using varlen")
    return flash_attn.flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class Attention(nn.Module):
|
|
87
|
+
"""Multi-head attention module with optional cross-attention support.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
dim: Input dimension
|
|
91
|
+
num_heads: Number of attention heads. Defaults to 8.
|
|
92
|
+
qkv_bias: Enable bias for QKV projections. Defaults to False.
|
|
93
|
+
qk_norm: Apply normalization to Q and K. Defaults to False.
|
|
94
|
+
attn_drop: Attention dropout rate. Defaults to 0.0.
|
|
95
|
+
proj_drop: Output projection dropout rate. Defaults to 0.0.
|
|
96
|
+
norm_layer: Normalization layer. Defaults to nn.LayerNorm.
|
|
97
|
+
cross_attn: Enable cross-attention. Defaults to False.
|
|
98
|
+
"""
|
|
99
|
+
|
|
100
|
+
fast_attn: Final[bool]
|
|
101
|
+
|
|
102
|
+
def __init__(
|
|
103
|
+
self,
|
|
104
|
+
dim: int,
|
|
105
|
+
num_heads: int = 8,
|
|
106
|
+
qkv_bias: bool = False,
|
|
107
|
+
qk_norm: bool = False,
|
|
108
|
+
attn_drop: float = 0.0,
|
|
109
|
+
proj_drop: float = 0.0,
|
|
110
|
+
norm_layer: nn.Module = nn.LayerNorm,
|
|
111
|
+
cross_attn: bool = False,
|
|
112
|
+
use_flash_attn: bool = False,
|
|
113
|
+
) -> None:
|
|
114
|
+
"""Initialize the attention module.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
dim: Input dimension
|
|
118
|
+
num_heads: Number of attention heads
|
|
119
|
+
qkv_bias: Enable bias for QKV projections
|
|
120
|
+
qk_norm: Apply normalization to Q and K
|
|
121
|
+
attn_drop: Attention dropout rate
|
|
122
|
+
proj_drop: Output projection dropout rate
|
|
123
|
+
norm_layer: Normalization layer
|
|
124
|
+
cross_attn: Enable cross-attention
|
|
125
|
+
use_flash_attn: Use flash attention
|
|
126
|
+
"""
|
|
127
|
+
super().__init__()
|
|
128
|
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
|
129
|
+
self.num_heads = num_heads
|
|
130
|
+
self.head_dim = dim // num_heads
|
|
131
|
+
self.scale = self.head_dim**-0.5
|
|
132
|
+
|
|
133
|
+
self.cross_attn = cross_attn
|
|
134
|
+
self.use_flash_attn = use_flash_attn
|
|
135
|
+
self.fast_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
|
136
|
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
|
137
|
+
self.k = nn.Linear(dim, dim, bias=qkv_bias)
|
|
138
|
+
self.v = nn.Linear(dim, dim, bias=qkv_bias)
|
|
139
|
+
|
|
140
|
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
141
|
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
|
142
|
+
self.attn_drop = nn.Dropout(attn_drop)
|
|
143
|
+
self.proj = nn.Linear(dim, dim)
|
|
144
|
+
self.proj_drop = nn.Dropout(proj_drop)
|
|
145
|
+
|
|
146
|
+
def sdpa(
|
|
147
|
+
self,
|
|
148
|
+
q: torch.Tensor,
|
|
149
|
+
k: torch.Tensor,
|
|
150
|
+
v: torch.Tensor,
|
|
151
|
+
n: int,
|
|
152
|
+
cu_seqlens: torch.Tensor | None = None,
|
|
153
|
+
cu_seqlens_q: torch.Tensor | None = None,
|
|
154
|
+
cu_seqlens_k: torch.Tensor | None = None,
|
|
155
|
+
max_seqlen: int | None = None,
|
|
156
|
+
max_seqlen_q: int | None = None,
|
|
157
|
+
max_seqlen_k: int | None = None,
|
|
158
|
+
attn_mask: torch.Tensor | None = None,
|
|
159
|
+
) -> torch.Tensor:
|
|
160
|
+
"""Compute scaled dot product attention.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
q: Query tensor of shape (B, H, N, D)
|
|
164
|
+
k: Key tensor of shape (B, H, N, D)
|
|
165
|
+
v: Value tensor of shape (B, H, N, D)
|
|
166
|
+
n: Number of tokens
|
|
167
|
+
attn_mask: Attention mask. Defaults to None.
|
|
168
|
+
cu_seqlens: Optional cumulative sequence lengths for the input tensor needed for varlen flash attention
|
|
169
|
+
cu_seqlens_q: Optional cumulative sequence lengths for the query tensor, needed for cross varlen flash attention
|
|
170
|
+
cu_seqlens_k: Optional cumulative sequence lengths for the key tensor, needed for cross varlen flash attention
|
|
171
|
+
max_seqlen: Optional maximum sequence length for the input tensor, needed for varlen flash attention
|
|
172
|
+
max_seqlen_q: Optional maximum sequence length for the query tensor, needed for cross varlen flash attention
|
|
173
|
+
max_seqlen_k: Optional maximum sequence length for the key tensor, needed for cross varlen flash attention
|
|
174
|
+
|
|
175
|
+
Returns:
|
|
176
|
+
Output tensor of shape (B, H, N, D)
|
|
177
|
+
"""
|
|
178
|
+
if self.use_flash_attn:
|
|
179
|
+
x = dispatch_flash_attn(
|
|
180
|
+
q,
|
|
181
|
+
k,
|
|
182
|
+
v,
|
|
183
|
+
cu_seqlens=cu_seqlens,
|
|
184
|
+
cu_seqlens_q=cu_seqlens_q,
|
|
185
|
+
cu_seqlens_k=cu_seqlens_k,
|
|
186
|
+
max_seqlen=max_seqlen,
|
|
187
|
+
max_seqlen_q=max_seqlen_q,
|
|
188
|
+
max_seqlen_k=max_seqlen_k,
|
|
189
|
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
|
190
|
+
softmax_scale=self.scale,
|
|
191
|
+
causal=False,
|
|
192
|
+
)
|
|
193
|
+
# Output is (B, Nq, H, D), transpose back to (B, H, Nq, D)
|
|
194
|
+
# matching the transpose of the other attention implementations that need to be transposed back
|
|
195
|
+
x = x.transpose(1, 2)
|
|
196
|
+
elif self.fast_attn:
|
|
197
|
+
if attn_mask is not None:
|
|
198
|
+
attn_mask = attn_mask[:, None, None].repeat((1, self.num_heads, n, 1))
|
|
199
|
+
x = F.scaled_dot_product_attention(
|
|
200
|
+
q,
|
|
201
|
+
k,
|
|
202
|
+
v,
|
|
203
|
+
# a value of True indicates that the element should take part in attention
|
|
204
|
+
attn_mask=attn_mask,
|
|
205
|
+
dropout_p=self.attn_drop.p,
|
|
206
|
+
)
|
|
207
|
+
else:
|
|
208
|
+
# Backward Compatible for older PyTorch versions
|
|
209
|
+
if attn_mask is not None:
|
|
210
|
+
raise NotImplementedError
|
|
211
|
+
q = q * self.scale
|
|
212
|
+
attn = q @ k.transpose(-2, -1)
|
|
213
|
+
attn = attn.softmax(dim=-1)
|
|
214
|
+
attn = self.attn_drop(attn)
|
|
215
|
+
x = attn @ v
|
|
216
|
+
|
|
217
|
+
return x
|
|
218
|
+
|
|
219
|
+
def forward(
|
|
220
|
+
self,
|
|
221
|
+
x: torch.Tensor,
|
|
222
|
+
y: torch.Tensor | None = None,
|
|
223
|
+
cu_seqlens: torch.Tensor | None = None,
|
|
224
|
+
cu_seqlens_q: torch.Tensor | None = None,
|
|
225
|
+
cu_seqlens_k: torch.Tensor | None = None,
|
|
226
|
+
max_seqlen: int | None = None,
|
|
227
|
+
max_seqlen_q: int | None = None,
|
|
228
|
+
max_seqlen_k: int | None = None,
|
|
229
|
+
attn_mask: torch.Tensor | None = None,
|
|
230
|
+
) -> torch.Tensor:
|
|
231
|
+
"""Forward pass.
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
x: Input tensor of shape (B, N, C) or (B* N , C) if packed
|
|
235
|
+
y: Second input for cross-attention. Defaults to None.
|
|
236
|
+
attn_mask: Attention mask. Defaults to None.
|
|
237
|
+
cu_seqlens: Optional cumulative sequence lengths for the input tensor needed for varlen flash attention
|
|
238
|
+
cu_seqlens_q: Optional cumulative sequence lengths for the query tensor, needed for cross varlen flash attention
|
|
239
|
+
cu_seqlens_k: Optional cumulative sequence lengths for the key tensor, needed for cross varlen flash attention
|
|
240
|
+
max_seqlen: Optional maximum sequence length for the input tensor, needed for varlen flash attention
|
|
241
|
+
max_seqlen_q: Optional maximum sequence length for the query tensor, needed for cross varlen flash attention
|
|
242
|
+
max_seqlen_k: Optional maximum sequence length for the key tensor, needed for cross varlen flash attention
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
Output tensor of shape (B, N, C) or (B* N , C) if packed
|
|
246
|
+
"""
|
|
247
|
+
original_shape = x.shape
|
|
248
|
+
|
|
249
|
+
q = self.q(x)
|
|
250
|
+
|
|
251
|
+
if y is None:
|
|
252
|
+
assert not self.cross_attn
|
|
253
|
+
k = self.k(x)
|
|
254
|
+
v = self.v(x)
|
|
255
|
+
else:
|
|
256
|
+
assert self.cross_attn
|
|
257
|
+
k = self.k(y)
|
|
258
|
+
v = self.v(y)
|
|
259
|
+
if not self.use_flash_attn:
|
|
260
|
+
q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
|
|
261
|
+
k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
|
|
262
|
+
v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
|
|
263
|
+
else:
|
|
264
|
+
q = rearrange(q, "bn (h d) -> bn h d", h=self.num_heads)
|
|
265
|
+
# Flash attention only supports k v heads that divide the number of query heads
|
|
266
|
+
k = rearrange(k, "bn (h d) -> bn h d", h=self.num_heads)
|
|
267
|
+
v = rearrange(v, "bn (h d) -> bn h d", h=self.num_heads)
|
|
268
|
+
# logger.info(f"q shape: {q.shape} k shape: {k.shape} v shape: {v.shape}")
|
|
269
|
+
|
|
270
|
+
q, k = self.q_norm(q), self.k_norm(k)
|
|
271
|
+
x = self.sdpa(
|
|
272
|
+
q,
|
|
273
|
+
k,
|
|
274
|
+
v,
|
|
275
|
+
n=original_shape[
|
|
276
|
+
-2
|
|
277
|
+
], # supposed to be the number of tokens in each sample with padding
|
|
278
|
+
cu_seqlens=cu_seqlens,
|
|
279
|
+
cu_seqlens_q=cu_seqlens_q,
|
|
280
|
+
cu_seqlens_k=cu_seqlens_k,
|
|
281
|
+
max_seqlen=max_seqlen,
|
|
282
|
+
max_seqlen_q=max_seqlen_q,
|
|
283
|
+
max_seqlen_k=max_seqlen_k,
|
|
284
|
+
attn_mask=attn_mask,
|
|
285
|
+
)
|
|
286
|
+
x = x.transpose(1, 2).reshape(original_shape)
|
|
287
|
+
x = self.proj(x)
|
|
288
|
+
x = self.proj_drop(x)
|
|
289
|
+
return x
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class Mlp(nn.Module):
|
|
293
|
+
"""MLP module used in Vision Transformer, MLP-Mixer and related networks.
|
|
294
|
+
|
|
295
|
+
Args:
|
|
296
|
+
in_features: Number of input features
|
|
297
|
+
hidden_features: Hidden dimension. Defaults to None.
|
|
298
|
+
out_features: Output dimension. Defaults to None.
|
|
299
|
+
act_layer: Activation layer. Defaults to nn.GELU.
|
|
300
|
+
bias: Enable bias in linear layers. Defaults to True.
|
|
301
|
+
drop: Dropout rate. Defaults to 0.0.
|
|
302
|
+
"""
|
|
303
|
+
|
|
304
|
+
def __init__(
|
|
305
|
+
self,
|
|
306
|
+
in_features: int,
|
|
307
|
+
hidden_features: int | None = None,
|
|
308
|
+
out_features: int | None = None,
|
|
309
|
+
act_layer: nn.Module = nn.GELU,
|
|
310
|
+
bias: bool = True,
|
|
311
|
+
drop: float = 0.0,
|
|
312
|
+
) -> None:
|
|
313
|
+
"""Initialize the MLP module.
|
|
314
|
+
|
|
315
|
+
Args:
|
|
316
|
+
in_features: Number of input features
|
|
317
|
+
hidden_features: Hidden dimension. Defaults to None.
|
|
318
|
+
out_features: Output dimension. Defaults to None.
|
|
319
|
+
act_layer: Activation layer. Defaults to nn.GELU.
|
|
320
|
+
bias: Enable bias in linear layers. Defaults to True.
|
|
321
|
+
drop: Dropout rate. Defaults to 0.0.
|
|
322
|
+
"""
|
|
323
|
+
super().__init__()
|
|
324
|
+
out_features = out_features or in_features
|
|
325
|
+
hidden_features = hidden_features or in_features
|
|
326
|
+
|
|
327
|
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
|
328
|
+
self.act = act_layer()
|
|
329
|
+
self.drop1 = nn.Dropout(drop)
|
|
330
|
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
|
331
|
+
self.drop2 = nn.Dropout(drop)
|
|
332
|
+
|
|
333
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
334
|
+
"""Forward pass.
|
|
335
|
+
|
|
336
|
+
Args:
|
|
337
|
+
x: Input tensor
|
|
338
|
+
|
|
339
|
+
Returns:
|
|
340
|
+
Output tensor
|
|
341
|
+
"""
|
|
342
|
+
x = self.fc1(x)
|
|
343
|
+
x = self.act(x)
|
|
344
|
+
x = self.drop1(x)
|
|
345
|
+
x = self.fc2(x)
|
|
346
|
+
x = self.drop2(x)
|
|
347
|
+
return x
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class LayerScale(nn.Module):
    """Multiply the last dimension by a learnable per-channel vector ``gamma``.

    Args:
        dim: Size of the scaled (last) dimension.
        init_values: Initial value of every entry of ``gamma``. Defaults to 1e-5.
        inplace: If True, scale the input tensor in place. Defaults to False.
    """

    def __init__(
        self, dim: int, init_values: float = 1e-5, inplace: bool = False
    ) -> None:
        """Create ``gamma`` filled with ``init_values``.

        Args:
            dim: Size of the scaled dimension.
            init_values: Initial scaling value.
            inplace: Whether ``forward`` mutates its input.
        """
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by ``gamma``.

        Args:
            x: Input tensor whose last dimension is ``dim``.

        Returns:
            The scaled tensor. When ``inplace`` is set, this is ``x`` itself,
            modified in place.
        """
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch while training.

    Each sample (index along the first dimension) is kept with probability
    ``1 - drop_prob`` and divided by that probability, or set to zero. In eval
    mode, or when ``drop_prob`` is None or 0, the input is returned unchanged
    (the same object).

    Args:
        drop_prob: Probability of dropping a sample's path; None or 0 disables it.

    References:
        Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
    """

    def __init__(self, drop_prob: float) -> None:
        """Store the drop probability.

        Args:
            drop_prob: Probability of dropping a sample's path (None or 0 disables).
        """
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply stochastic depth to ``x``.

        Args:
            x: Input tensor of any shape (B, ...).

        Returns:
            Tensor with the same shape as ``x``.
        """
        if not self.training or not self.drop_prob:
            return x

        keep_prob = 1 - self.drop_prob
        # One draw per sample, broadcast over all remaining dims. floor(keep_prob + U[0, 1))
        # is 1 with probability keep_prob and 0 otherwise.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.floor(
            keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        )
        return x.div(keep_prob) * keep_mask
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
class Block(nn.Module):
    """Pre-norm transformer block: (self- or cross-) attention followed by an MLP.

    Each sub-layer is applied residually as
    ``x + drop_path(layer_scale(sublayer(norm(x))))``.

    Args:
        dim: Input dimension
        num_heads: Number of attention heads
        mlp_ratio: Ratio of mlp hidden dim to input dim. Default: 4.0
        qkv_bias: Add bias to qkv projections. Default: False
        qk_norm: Apply normalization to q,k. Default: False
        drop: Dropout rate. Default: 0.0
        attn_drop: Attention dropout rate. Default: 0.0
        drop_path: Drop path rate. Default: 0.0
        init_values: Layer scale initialization value; falsy disables layer scale. Default: None
        act_layer: Activation layer. Default: nn.GELU
        norm_layer: Normalization layer. Default: nn.LayerNorm
        cross_attn: Whether to use cross attention. Default: False
        use_flash_attn: Whether to use flash attention. Default: False
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        init_values: float | None = None,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        cross_attn: bool = False,
        use_flash_attn: bool = False,
    ) -> None:
        """Build the attention and MLP sub-layers.

        Args:
            dim: Input dimension
            num_heads: Number of attention heads
            mlp_ratio: Ratio of mlp hidden dim to input dim
            qkv_bias: Add bias to qkv projections
            qk_norm: Apply normalization to q,k
            drop: Dropout rate
            attn_drop: Attention dropout rate
            drop_path: Drop path rate
            init_values: Layer scale initialization value
            act_layer: Activation layer
            norm_layer: Normalization layer
            cross_attn: Whether to use cross attention
            use_flash_attn: Whether to use flash attention
        """
        super().__init__()

        def make_layer_scale() -> nn.Module:
            # A falsy init_values (None or 0) turns layer scaling off.
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        # Keep this creation order: it fixes parameter-initialization RNG
        # consumption and the state_dict key order.
        # Attention sub-layer.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=drop,
            norm_layer=norm_layer,
            cross_attn=cross_attn,
            use_flash_attn=use_flash_attn,
        )
        self.ls1 = make_layer_scale()
        # A single drop-path module is shared by both residual branches.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # MLP sub-layer.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2 = make_layer_scale()

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor | None = None,
        cu_seqlens: torch.Tensor | None = None,
        cu_seqlens_q: torch.Tensor | None = None,
        cu_seqlens_k: torch.Tensor | None = None,
        max_seqlen: int | None = None,
        max_seqlen_q: int | None = None,
        max_seqlen_k: int | None = None,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention and then the MLP, each as a residual update.

        Args:
            x: Input tensor of shape (B, N, C), or packed tokens on the flash path
            y: Optional context tensor for cross attention of shape (B, M, C)
            attn_mask: Optional attention mask tensor
            cu_seqlens: Optional cumulative sequence lengths for the input tensor needed for varlen flash attention
            cu_seqlens_q: Optional cumulative sequence lengths for the query tensor, needed for cross varlen flash attention
            cu_seqlens_k: Optional cumulative sequence lengths for the key tensor, needed for cross varlen flash attention
            max_seqlen: Optional maximum sequence length for the input tensor, needed for varlen flash attention
            max_seqlen_q: Optional maximum sequence length for the query tensor, needed for cross varlen flash attention
            max_seqlen_k: Optional maximum sequence length for the key tensor, needed for cross varlen flash attention

        Returns:
            Output tensor with the same shape as ``x``.
        """
        attn_out = self.attn(
            x=self.norm1(x),
            y=y,
            cu_seqlens=cu_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen=max_seqlen,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            attn_mask=attn_mask,
        )
        x = x + self.drop_path(self.ls1(attn_out))

        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(self.ls2(mlp_out))

    def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
        """Apply FSDP to the model."""
        fully_shard(self, **fsdp_kwargs)

    def apply_compile(self) -> None:
        """Apply torch.compile to the model."""
        self.compile(dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""A collection of functions for creating position encodings for the OlmoEarth Pretrain model.
|
|
2
|
+
|
|
3
|
+
These functions are based on the following repository:
|
|
4
|
+
https://github.com/bair-climate-initiative/scale-mae/blob/main/mae/util/pos_embed.py
|
|
5
|
+
|
|
6
|
+
They cover the following:
|
|
7
|
+
- 2D sinusoidal position encoding (for spatial data)
|
|
8
|
+
- 1D sinusoidal position encoding (for temporal data)
|
|
9
|
+
- Month encoding (for temporal data)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_1d_sincos_pos_encoding(pos: torch.Tensor, encoding_dim: int) -> torch.Tensor:
    """Encode scalar positions with concatenated sine and cosine features.

    The first ``encoding_dim // 2`` output columns are ``sin(pos * omega_i)`` and
    the last ``encoding_dim // 2`` columns are ``cos(pos * omega_i)``, where
    ``omega_i = 10000 ** -(i / encoding_dim / 2)`` for ``i = 0 .. encoding_dim // 2 - 1``.

    NOTE(review): the exponent evaluates to ``i / (2 * encoding_dim)``, whereas the
    scale-mae / MAE reference cited in this module appears to use
    ``i / (encoding_dim / 2)``. Pretrained checkpoints presumably depend on the
    current formula, so confirm before changing it.

    Args:
        pos: positions to encode (a time or space dimension); flattened to size (L,)
        encoding_dim: output dimension for each position; must be even

    Returns:
        encoding: position encoding for the given positions: size (L, D)
    """
    assert encoding_dim % 2 == 0, f"encoding_dim must be even, got {encoding_dim}"
    half_dim = encoding_dim // 2
    exponent = torch.arange(half_dim, device=pos.device) / encoding_dim / 2.0
    frequencies = 1.0 / 10000**exponent  # (D/2,)

    flat_pos = pos.reshape(-1)  # (L,)
    phases = torch.einsum("l,d->ld", flat_pos, frequencies)  # (L, D/2) outer product
    return torch.cat([phases.sin(), phases.cos()], dim=1)  # (L, D)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_2d_sincos_pos_encoding(grid: torch.Tensor, encoding_dim: int) -> torch.Tensor:
    """Encode 2-D grid coordinates by concatenating two 1-D sin/cos encodings.

    ``grid[0]`` and ``grid[1]`` are each flattened and encoded with
    ``encoding_dim // 2`` dimensions. The two encodings are then concatenated
    along the feature axis (``grid[0]`` first).

    Args:
        grid: a grid of positions to be encoded: size 2 x h x w (any trailing
            shape works; it is flattened)
        encoding_dim: output dimension for each position; must be even

    Returns:
        encoding: position encoding for the given grid: size (h*w, D)
    """
    assert encoding_dim % 2 == 0

    half_dim = encoding_dim // 2
    first_coord_enc, second_coord_enc = (
        get_1d_sincos_pos_encoding(grid[axis], half_dim) for axis in (0, 1)
    )  # each (h*w, D/2)
    return torch.cat([first_coord_enc, second_coord_enc], dim=1)  # (h*w, D)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_2d_sincos_pos_encoding_with_resolution(
    grid_size: int,
    res: torch.Tensor,
    encoding_dim: int,
    device: torch.device,
    cls_token: bool = False,
) -> torch.Tensor:
    """Get 2D sin cos position encodings for a square grid, one set per resolution.

    The integer grid coordinates are multiplied by each entry of ``res``, so a
    ``res`` with ``n`` entries yields ``n`` independent encodings along the
    leading dimension.

    Args:
        grid_size: number of positions along each of the grid height and width
        res: 1-D tensor of size n; each entry scales the grid coordinates (for
            example, the size of a pixel in meters)
        encoding_dim: output dimension for each position; must be even
        device: device on which the coordinate grid is created
        cls_token: if True, prepend an all-zero encoding for a cls token

    Returns:
        encoding: tensor of size (n, H*W, D), or (n, 1 + H*W, D) when
        ``cls_token`` is True, where H == W == grid_size. The cls padding has
        the same dtype and device as the encodings.
    """
    grid_h = torch.arange(grid_size, device=device)
    grid_w = torch.arange(grid_size, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # two (h, w) coordinate maps
    grid = torch.stack(grid, dim=0)  # 2 x h x w

    # Scale the coordinates by every resolution: one grid per entry of `res`.
    grid = torch.einsum("chw,n->cnhw", grid, res)  # 2 x n x h x w
    _, n, h, w = grid.shape
    pos_embed = get_2d_sincos_pos_encoding(grid, encoding_dim)  # (n*h*w, D)
    pos_embed = pos_embed.reshape(n, h * w, encoding_dim)
    if cls_token:
        # Match the encoding's dtype so concatenation does not silently promote
        # it (e.g. bf16 -> float32) through a default-dtype zeros tensor.
        cls_slot = torch.zeros(
            [n, 1, encoding_dim], dtype=pos_embed.dtype, device=pos_embed.device
        )
        pos_embed = torch.cat([cls_slot, pos_embed], dim=1)
    return pos_embed
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def get_month_encoding_table(encoding_dim: int) -> torch.Tensor:
    """Build a sinusoidal lookup table for the 12 months, indexed from 0-11.

    Month ``m`` maps to the angle ``2 * pi * m / 12``. Every one of the first
    ``encoding_dim // 2`` columns holds ``sin`` of that angle, and every one of
    the remaining ``encoding_dim // 2`` columns holds its ``cos``. Columns within
    a half are identical; there is no per-column frequency.

    Args:
        encoding_dim: output dimension for each month; must be even

    Returns:
        month_table: tensor of size (12, D)
    """
    assert encoding_dim % 2 == 0
    # 13 angles (0..12) cover one full turn. The last one (2*pi, i.e. month 0
    # again) is dropped after the trig functions are applied.
    angles = torch.arange(0, 13) / (12 / (2 * np.pi))

    dim_per_table = encoding_dim // 2
    tiled = torch.stack([angles] * dim_per_table, dim=-1)  # (13, D/2)
    sin_part = torch.sin(tiled)[:-1]
    cos_part = torch.cos(tiled)[:-1]
    return torch.cat([sin_part, cos_part], dim=-1)  # (M, D)
|