olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. olmoearth_pretrain_minimal/__init__.py +16 -0
  2. olmoearth_pretrain_minimal/model_loader.py +123 -0
  3. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
  4. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
  5. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
  6. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
  7. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
  8. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
  9. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
  10. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
  11. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
  12. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
  13. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
  14. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
  15. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
  16. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
  17. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
  18. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
  19. olmoearth_pretrain_minimal/test.py +51 -0
  20. olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
  21. olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
  22. olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
  23. olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
  24. olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py
@@ -0,0 +1,559 @@
+ """Attention Components for OlmoEarth Pretrain."""
+
+ from logging import getLogger
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.distributed.fsdp import fully_shard
+ from torch.jit import Final
+
+ try:
+     import flash_attn
+ except ImportError:
+     flash_attn = None
+
+ logger = getLogger(__name__)
+
+
+ @torch._dynamo.disable()
+ def dispatch_flash_attn(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     *,
+     cu_seqlens: torch.Tensor | None = None,
+     cu_seqlens_q: torch.Tensor | None = None,
+     cu_seqlens_k: torch.Tensor | None = None,
+     max_seqlen: int | None = None,
+     max_seqlen_q: int | None = None,
+     max_seqlen_k: int | None = None,
+     dropout_p: float = 0.0,
+     softmax_scale: float | None = None,
+     causal: bool = False,
+ ) -> torch.Tensor:
+     """Dispatch flash attention.
+
+     Modeled after OLMo-core, but does not flatten the inputs internally.
+     """
+     if flash_attn is None:
+         raise RuntimeError("flash-attn is required!")
+
+     if cu_seqlens is not None:
+         if cu_seqlens_q is None:
+             cu_seqlens_q = cu_seqlens
+         if cu_seqlens_k is None:
+             cu_seqlens_k = cu_seqlens
+     if max_seqlen is not None:
+         if max_seqlen_q is None:
+             max_seqlen_q = max_seqlen
+         if max_seqlen_k is None:
+             max_seqlen_k = max_seqlen
+
+     varlen = all(
+         x is not None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
+     )
+
+     if varlen:
+         assert q.ndim == 3, "q must be pre-packed"
+         logger.debug("using varlen")
+
+         return flash_attn.flash_attn_varlen_func(
+             q,
+             k,
+             v,
+             cu_seqlens_q,
+             cu_seqlens_k,
+             max_seqlen_q,
+             max_seqlen_k,
+             dropout_p=dropout_p,
+             softmax_scale=softmax_scale,
+             causal=causal,
+         )
+     else:
+         return flash_attn.flash_attn_func(
+             q,
+             k,
+             v,
+             dropout_p=dropout_p,
+             softmax_scale=softmax_scale,
+             causal=causal,
+         )
+
+
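Editorial note (not part of the package): the varlen path above follows flash-attn's packed convention, where q/k/v are (total_tokens, H, D) and per-sample boundaries are described by int32 cumulative lengths. A minimal sketch of building cu_seqlens and max_seqlen from hypothetical per-sample token counts:

    import torch

    # Hypothetical token counts for a packed batch of three samples.
    seq_lens = torch.tensor([7, 12, 5], dtype=torch.int32)

    # flash-attn's varlen kernels expect cumulative lengths with a leading zero,
    # shape (batch + 1,), plus the maximum per-sample length.
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    max_seqlen = int(seq_lens.max())

    print(cu_seqlens.tolist())  # [0, 7, 19, 24]
    print(max_seqlen)           # 12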
+ class Attention(nn.Module):
+     """Multi-head attention module with optional cross-attention support.
+
+     Args:
+         dim: Input dimension
+         num_heads: Number of attention heads. Defaults to 8.
+         qkv_bias: Enable bias for QKV projections. Defaults to False.
+         qk_norm: Apply normalization to Q and K. Defaults to False.
+         attn_drop: Attention dropout rate. Defaults to 0.0.
+         proj_drop: Output projection dropout rate. Defaults to 0.0.
+         norm_layer: Normalization layer. Defaults to nn.LayerNorm.
+         cross_attn: Enable cross-attention. Defaults to False.
+     """
+
+     fast_attn: Final[bool]
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         qk_norm: bool = False,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+         norm_layer: nn.Module = nn.LayerNorm,
+         cross_attn: bool = False,
+         use_flash_attn: bool = False,
+     ) -> None:
+         """Initialize the attention module.
+
+         Args:
+             dim: Input dimension
+             num_heads: Number of attention heads
+             qkv_bias: Enable bias for QKV projections
+             qk_norm: Apply normalization to Q and K
+             attn_drop: Attention dropout rate
+             proj_drop: Output projection dropout rate
+             norm_layer: Normalization layer
+             cross_attn: Enable cross-attention
+             use_flash_attn: Use flash attention
+         """
+         super().__init__()
+         assert dim % num_heads == 0, "dim should be divisible by num_heads"
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim**-0.5
+
+         self.cross_attn = cross_attn
+         self.use_flash_attn = use_flash_attn
+         self.fast_attn = hasattr(torch.nn.functional, "scaled_dot_product_attention")
+         self.q = nn.Linear(dim, dim, bias=qkv_bias)
+         self.k = nn.Linear(dim, dim, bias=qkv_bias)
+         self.v = nn.Linear(dim, dim, bias=qkv_bias)
+
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def sdpa(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         n: int,
+         cu_seqlens: torch.Tensor | None = None,
+         cu_seqlens_q: torch.Tensor | None = None,
+         cu_seqlens_k: torch.Tensor | None = None,
+         max_seqlen: int | None = None,
+         max_seqlen_q: int | None = None,
+         max_seqlen_k: int | None = None,
+         attn_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Compute scaled dot-product attention.
+
+         Args:
+             q: Query tensor of shape (B, H, N, D)
+             k: Key tensor of shape (B, H, N, D)
+             v: Value tensor of shape (B, H, N, D)
+             n: Number of tokens
+             attn_mask: Attention mask. Defaults to None.
+             cu_seqlens: Optional cumulative sequence lengths for the input tensor, needed for varlen flash attention
+             cu_seqlens_q: Optional cumulative sequence lengths for the query tensor, needed for cross varlen flash attention
+             cu_seqlens_k: Optional cumulative sequence lengths for the key tensor, needed for cross varlen flash attention
+             max_seqlen: Optional maximum sequence length for the input tensor, needed for varlen flash attention
+             max_seqlen_q: Optional maximum sequence length for the query tensor, needed for cross varlen flash attention
+             max_seqlen_k: Optional maximum sequence length for the key tensor, needed for cross varlen flash attention
+
+         Returns:
+             Output tensor of shape (B, H, N, D)
+         """
+         if self.use_flash_attn:
+             x = dispatch_flash_attn(
+                 q,
+                 k,
+                 v,
+                 cu_seqlens=cu_seqlens,
+                 cu_seqlens_q=cu_seqlens_q,
+                 cu_seqlens_k=cu_seqlens_k,
+                 max_seqlen=max_seqlen,
+                 max_seqlen_q=max_seqlen_q,
+                 max_seqlen_k=max_seqlen_k,
+                 dropout_p=self.attn_drop.p if self.training else 0.0,
+                 softmax_scale=self.scale,
+                 causal=False,
+             )
+             # Output is (B, Nq, H, D); transpose back to (B, H, Nq, D) to match the other
+             # attention implementations, whose outputs are transposed back by the caller.
+             x = x.transpose(1, 2)
+         elif self.fast_attn:
+             if attn_mask is not None:
+                 attn_mask = attn_mask[:, None, None].repeat((1, self.num_heads, n, 1))
+             x = F.scaled_dot_product_attention(
+                 q,
+                 k,
+                 v,
+                 # a value of True indicates that the element should take part in attention
+                 attn_mask=attn_mask,
+                 dropout_p=self.attn_drop.p,
+             )
+         else:
+             # Backward-compatible path for older PyTorch versions
+             if attn_mask is not None:
+                 raise NotImplementedError
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+             attn = attn.softmax(dim=-1)
+             attn = self.attn_drop(attn)
+             x = attn @ v
+
+         return x
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         y: torch.Tensor | None = None,
+         cu_seqlens: torch.Tensor | None = None,
+         cu_seqlens_q: torch.Tensor | None = None,
+         cu_seqlens_k: torch.Tensor | None = None,
+         max_seqlen: int | None = None,
+         max_seqlen_q: int | None = None,
+         max_seqlen_k: int | None = None,
+         attn_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x: Input tensor of shape (B, N, C), or (B*N, C) if packed
+             y: Second input for cross-attention. Defaults to None.
+             attn_mask: Attention mask. Defaults to None.
+             cu_seqlens: Optional cumulative sequence lengths for the input tensor, needed for varlen flash attention
+             cu_seqlens_q: Optional cumulative sequence lengths for the query tensor, needed for cross varlen flash attention
+             cu_seqlens_k: Optional cumulative sequence lengths for the key tensor, needed for cross varlen flash attention
+             max_seqlen: Optional maximum sequence length for the input tensor, needed for varlen flash attention
+             max_seqlen_q: Optional maximum sequence length for the query tensor, needed for cross varlen flash attention
+             max_seqlen_k: Optional maximum sequence length for the key tensor, needed for cross varlen flash attention
+
+         Returns:
+             Output tensor of shape (B, N, C), or (B*N, C) if packed
+         """
+         original_shape = x.shape
+
+         q = self.q(x)
+
+         if y is None:
+             assert not self.cross_attn
+             k = self.k(x)
+             v = self.v(x)
+         else:
+             assert self.cross_attn
+             k = self.k(y)
+             v = self.v(y)
+         if not self.use_flash_attn:
+             q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
+             k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
+             v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
+         else:
+             q = rearrange(q, "bn (h d) -> bn h d", h=self.num_heads)
+             # Flash attention only supports k/v head counts that divide the number of query heads
+             k = rearrange(k, "bn (h d) -> bn h d", h=self.num_heads)
+             v = rearrange(v, "bn (h d) -> bn h d", h=self.num_heads)
+         # logger.info(f"q shape: {q.shape} k shape: {k.shape} v shape: {v.shape}")
+
+         q, k = self.q_norm(q), self.k_norm(k)
+         x = self.sdpa(
+             q,
+             k,
+             v,
+             n=original_shape[-2],  # the number of tokens in each sample, including padding
+             cu_seqlens=cu_seqlens,
+             cu_seqlens_q=cu_seqlens_q,
+             cu_seqlens_k=cu_seqlens_k,
+             max_seqlen=max_seqlen,
+             max_seqlen_q=max_seqlen_q,
+             max_seqlen_k=max_seqlen_k,
+             attn_mask=attn_mask,
+         )
+         x = x.transpose(1, 2).reshape(original_shape)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
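Editorial note: a rough usage sketch (not shipped in the wheel), assuming the SDPA path (use_flash_attn=False) and the module path shown in the file list above:

    import torch
    from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.attention import Attention

    attn = Attention(dim=64, num_heads=8, qkv_bias=True)
    attn.eval()

    x = torch.randn(2, 16, 64)   # (B, N, C) tokens
    with torch.no_grad():
        out = attn(x)            # self-attention; pass y only when cross_attn=True
    print(out.shape)             # torch.Size([2, 16, 64])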
+ class Mlp(nn.Module):
+     """MLP module used in Vision Transformer, MLP-Mixer and related networks.
+
+     Args:
+         in_features: Number of input features
+         hidden_features: Hidden dimension. Defaults to None.
+         out_features: Output dimension. Defaults to None.
+         act_layer: Activation layer. Defaults to nn.GELU.
+         bias: Enable bias in linear layers. Defaults to True.
+         drop: Dropout rate. Defaults to 0.0.
+     """
+
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: int | None = None,
+         out_features: int | None = None,
+         act_layer: nn.Module = nn.GELU,
+         bias: bool = True,
+         drop: float = 0.0,
+     ) -> None:
+         """Initialize the MLP module.
+
+         Args:
+             in_features: Number of input features
+             hidden_features: Hidden dimension. Defaults to None.
+             out_features: Output dimension. Defaults to None.
+             act_layer: Activation layer. Defaults to nn.GELU.
+             bias: Enable bias in linear layers. Defaults to True.
+             drop: Dropout rate. Defaults to 0.0.
+         """
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.drop1 = nn.Dropout(drop)
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop2 = nn.Dropout(drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x: Input tensor
+
+         Returns:
+             Output tensor
+         """
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop1(x)
+         x = self.fc2(x)
+         x = self.drop2(x)
+         return x
+
+
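Editorial note, illustrative only: with the 4x expansion that Block applies via mlp_ratio, the MLP maps C -> 4C -> C per token:

    import torch

    mlp = Mlp(in_features=64, hidden_features=256)   # 4x expansion; drop defaults to 0.0
    out = mlp(torch.randn(2, 16, 64))
    print(out.shape)  # torch.Size([2, 16, 64])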
+ class LayerScale(nn.Module):
+     """Learnable scaling layer.
+
+     Args:
+         dim: Input dimension
+         init_values: Initial scaling value. Defaults to 1e-5.
+         inplace: Perform scaling operation in-place. Defaults to False.
+     """
+
+     def __init__(
+         self, dim: int, init_values: float = 1e-5, inplace: bool = False
+     ) -> None:
+         """Initialize the LayerScale module.
+
+         Args:
+             dim: Input dimension
+             init_values: Initial scaling value
+             inplace: Perform scaling operation in-place
+         """
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x: Input tensor
+
+         Returns:
+             Scaled output tensor
+         """
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
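Editorial note, illustrative only: LayerScale starts as a tiny per-channel gate, so a freshly initialized residual branch contributes almost nothing until gamma is learned:

    import torch

    ls = LayerScale(dim=64, init_values=1e-5)
    print(ls.gamma.shape)                          # torch.Size([64]); one learnable scale per channel
    print(ls(torch.randn(2, 16, 64)).abs().max())  # tiny, on the order of 1e-5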
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample when applied in the main path of residual blocks.
+
+     This is a regularization technique that randomly drops entire layers/paths during training
+     to prevent overfitting. During inference, all paths are kept.
+
+     Args:
+         drop_prob: Probability of dropping the path.
+
+     References:
+         Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+     """
+
+     def __init__(self, drop_prob: float) -> None:
+         """Initialize the DropPath module.
+
+         Args:
+             drop_prob: Probability of dropping the path.
+         """
+         super().__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass applying stochastic depth to the input tensor.
+
+         Args:
+             x: Input tensor of any shape (B, ...)
+
+         Returns:
+             Tensor with the same shape as the input, with paths randomly dropped during training
+         """
+         if self.drop_prob is None or self.drop_prob == 0.0 or not self.training:
+             return x
+
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # (B, 1, 1, ...)
+         random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+         random_tensor.floor_()  # binarize
+         return x.div(keep_prob) * random_tensor
+
+
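Editorial note (a sketch, not package code): dividing by keep_prob in the forward pass keeps the expected activation unchanged while whole samples are dropped at rate drop_prob:

    import torch

    torch.manual_seed(0)
    dp = DropPath(drop_prob=0.25)
    dp.train()

    x = torch.ones(10_000, 4)                    # large batch so the averages are stable
    out = dp(x)
    print((out == 0).all(dim=1).float().mean())  # ~0.25 of rows are dropped
    print(out.mean())                            # ~1.0; survivors are scaled by 1 / 0.75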
+ class Block(nn.Module):
+     """Transformer block with self/cross attention and MLP.
+
+     Args:
+         dim: Input dimension
+         num_heads: Number of attention heads
+         mlp_ratio: Ratio of MLP hidden dim to input dim. Default: 4.0
+         qkv_bias: Add bias to qkv projections. Default: False
+         qk_norm: Apply normalization to q, k. Default: False
+         drop: Dropout rate. Default: 0.0
+         attn_drop: Attention dropout rate. Default: 0.0
+         drop_path: Drop path rate. Default: 0.0
+         init_values: Layer scale initialization value. Default: None
+         act_layer: Activation layer. Default: nn.GELU
+         norm_layer: Normalization layer. Default: nn.LayerNorm
+         cross_attn: Whether to use cross attention. Default: False
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         qk_norm: bool = False,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         drop_path: float = 0.0,
+         init_values: float | None = None,
+         act_layer: nn.Module = nn.GELU,
+         norm_layer: nn.Module = nn.LayerNorm,
+         cross_attn: bool = False,
+         use_flash_attn: bool = False,
+     ) -> None:
+         """Initialize the Transformer block.
+
+         Args:
+             dim: Input dimension
+             num_heads: Number of attention heads
+             mlp_ratio: Ratio of MLP hidden dim to input dim
+             qkv_bias: Add bias to qkv projections
+             qk_norm: Apply normalization to q, k
+             drop: Dropout rate
+             attn_drop: Attention dropout rate
+             drop_path: Drop path rate
+             init_values: Layer scale initialization value
+             act_layer: Activation layer
+             norm_layer: Normalization layer
+             cross_attn: Whether to use cross attention
+             use_flash_attn: Whether to use flash attention
+         """
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             qk_norm=qk_norm,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             norm_layer=norm_layer,
+             cross_attn=cross_attn,
+             use_flash_attn=use_flash_attn,
+         )
+         self.ls1 = (
+             LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         )
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         self.mlp = Mlp(
+             in_features=dim,
+             hidden_features=int(dim * mlp_ratio),
+             act_layer=act_layer,
+             drop=drop,
+         )
+         self.ls2 = (
+             LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         y: torch.Tensor | None = None,
+         cu_seqlens: torch.Tensor | None = None,
+         cu_seqlens_q: torch.Tensor | None = None,
+         cu_seqlens_k: torch.Tensor | None = None,
+         max_seqlen: int | None = None,
+         max_seqlen_q: int | None = None,
+         max_seqlen_k: int | None = None,
+         attn_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x: Input tensor of shape (B, N, C)
+             y: Optional context tensor for cross attention of shape (B, M, C)
+             attn_mask: Optional attention mask tensor
+             cu_seqlens: Optional cumulative sequence lengths for the input tensor, needed for varlen flash attention
+             cu_seqlens_q: Optional cumulative sequence lengths for the query tensor, needed for cross varlen flash attention
+             cu_seqlens_k: Optional cumulative sequence lengths for the key tensor, needed for cross varlen flash attention
+             max_seqlen: Optional maximum sequence length for the input tensor, needed for varlen flash attention
+             max_seqlen_q: Optional maximum sequence length for the query tensor, needed for cross varlen flash attention
+             max_seqlen_k: Optional maximum sequence length for the key tensor, needed for cross varlen flash attention
+
+         Returns:
+             Output tensor of shape (B, N, C)
+         """
+         x = x + self.drop_path(
+             self.ls1(
+                 self.attn(
+                     x=self.norm1(x),
+                     y=y,
+                     cu_seqlens=cu_seqlens,
+                     cu_seqlens_q=cu_seqlens_q,
+                     cu_seqlens_k=cu_seqlens_k,
+                     max_seqlen=max_seqlen,
+                     max_seqlen_q=max_seqlen_q,
+                     max_seqlen_k=max_seqlen_k,
+                     attn_mask=attn_mask,
+                 )
+             )
+         )
+
+         x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
+         return x
+
+     def apply_fsdp(self, **fsdp_kwargs: Any) -> None:
+         """Apply FSDP to the model."""
+         fully_shard(self, **fsdp_kwargs)
+
+     def apply_compile(self) -> None:
+         """Apply torch.compile to the model."""
+         self.compile(dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)
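Editorial note: for orientation, a minimal (hypothetical) way to stack these blocks into a small pre-norm encoder; the package's actual encoder lives in flexi_vit.py:

    import torch
    import torch.nn as nn

    blocks = nn.ModuleList(
        Block(dim=64, num_heads=8, mlp_ratio=4.0, drop_path=0.1) for _ in range(4)
    )

    x = torch.randn(2, 16, 64)      # (B, N, C) patch tokens
    for blk in blocks:
        x = blk(x)                  # x + attn(norm1(x)), then x + mlp(norm2(x))
    print(x.shape)                  # torch.Size([2, 16, 64])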
olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py
@@ -0,0 +1,115 @@
+ """A collection of functions for creating position encodings for the OlmoEarth Pretrain model.
+
+ These functions are based on the following repository:
+ https://github.com/bair-climate-initiative/scale-mae/blob/main/mae/util/pos_embed.py
+
+ They cover the following:
+ - 2D sinusoidal position encoding (for spatial data)
+ - 1D sinusoidal position encoding (for temporal data)
+ - Month encoding (for temporal data)
+ """
+
+ import numpy as np
+ import torch
+
+
+ def get_1d_sincos_pos_encoding(pos: torch.Tensor, encoding_dim: int) -> torch.Tensor:
+     """Get 1D sin-cos position encoding for a given set of positions.
+
+     Args:
+         pos: a list of positions to be encoded: size (L,); this can be a time or space dimension
+         encoding_dim: output dimension for each position
+     Returns:
+         encoding: position encoding for the given positions: size (L, D)
+     """
+     assert encoding_dim % 2 == 0, f"encoding_dim must be even, got {encoding_dim}"
+     omega = torch.arange(encoding_dim // 2, device=pos.device) / encoding_dim / 2.0
+     omega = 1.0 / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (L,)
+     out = torch.einsum("l,d->ld", pos, omega)  # (L, D/2), outer product
+     encoding_sin = torch.sin(out)  # (L, D/2)
+     encoding_cos = torch.cos(out)  # (L, D/2)
+
+     encoding = torch.cat([encoding_sin, encoding_cos], dim=1)  # (L, D)
+     return encoding
+
+
+ def get_2d_sincos_pos_encoding(grid: torch.Tensor, encoding_dim: int) -> torch.Tensor:
39
+ """Get 2D sin cos position encoding for a given grid of positions.
40
+
41
+ Args:
42
+ grid: a grid of positions to be encoded: size 2 x h x w
43
+ encoding_dim: output dimension for each position
44
+ Returns:
45
+ encoding: position encoding for the given grid: size (h*w, D)
46
+ """
47
+ assert encoding_dim % 2 == 0
48
+
49
+ # use half of dimensions to encode grid_h
50
+ encoding_dim_1d = encoding_dim // 2
51
+ emb_h = get_1d_sincos_pos_encoding(grid[0], encoding_dim_1d) # (h*w, D/2)
52
+ emb_w = get_1d_sincos_pos_encoding(grid[1], encoding_dim_1d) # (h*w, D/2)
53
+
54
+ emb = torch.cat([emb_h, emb_w], dim=1) # (h*w, D)
55
+ return emb
56
+
57
+
58
+ def get_2d_sincos_pos_encoding_with_resolution(
59
+ grid_size: int,
60
+ res: torch.Tensor,
61
+ encoding_dim: int,
62
+ device: torch.device,
63
+ cls_token: bool = False,
64
+ ) -> torch.Tensor:
65
+ """Get 2D sin cos position encoding for a given grid of positions with resolution.
66
+
67
+ Args:
68
+ grid_size: int of the grid height and width
69
+ res: array of size n, representing the resolution of a pixel (say, in meters),
70
+ where n is the number of spatial dimensions
71
+ encoding_dim: output dimension for each position
72
+ cls_token: whether to add a cls token to the encoding
73
+ device: device to run the encoding on
74
+ Returns:
75
+ encoding: position encoding for the given grid: size (H*W, D)
76
+ """
77
+ # TODO: What happens when the res array is bigger than 1?
78
+ grid_h = torch.arange(grid_size, device=device)
79
+ grid_w = torch.arange(grid_size, device=device)
80
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # (h_grid, w_grid)
81
+ grid = torch.stack(grid, dim=0) # 2 x h x w
82
+
83
+ # create resolution scaled grid
84
+ grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w
85
+ _, n, h, w = grid.shape
86
+ pos_embed = get_2d_sincos_pos_encoding(grid, encoding_dim) # (nxH*W, D/2)
87
+ pos_embed = pos_embed.reshape(n, h * w, encoding_dim)
88
+ if cls_token:
89
+ pos_embed = torch.cat(
90
+ [
91
+ torch.zeros([n, 1, encoding_dim], device=pos_embed.device),
92
+ pos_embed,
93
+ ],
94
+ dim=1,
95
+ )
96
+ return pos_embed
97
+
98
+
99
+ def get_month_encoding_table(encoding_dim: int) -> torch.Tensor:
100
+ """Sinusoid month encoding table, for 12 months indexed from 0-11.
101
+
102
+ Args:
103
+ encoding_dim: output dimension for each position
104
+ Returns:
105
+ month_table: position encoding for the given grid: size (M, D)
106
+ """
107
+ assert encoding_dim % 2 == 0
108
+ angles = torch.arange(0, 13) / (12 / (2 * np.pi))
109
+
110
+ dim_per_table = encoding_dim // 2
111
+ sin_table = torch.sin(torch.stack([angles for _ in range(dim_per_table)], axis=-1))
112
+ cos_table = torch.cos(torch.stack([angles for _ in range(dim_per_table)], axis=-1))
113
+ month_table = torch.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)
114
+
115
+ return month_table # (M, D)
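Editorial note (a sketch, assuming the table is used as a lookup by integer month 0-11): because the angles wrap around a yearly circle, adjacent months get similar encodings:

    import torch

    month_table = get_month_encoding_table(encoding_dim=32)    # (12, 32)
    months = torch.tensor([0, 5, 11])                          # Jan, Jun, Dec
    month_enc = month_table[months]                            # (3, 32)
    # December sits next to January on the circle, June sits opposite it:
    print(torch.dist(month_enc[0], month_enc[2]) < torch.dist(month_enc[0], month_enc[1]))  # True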