olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- olmoearth_pretrain_minimal/__init__.py +16 -0
- olmoearth_pretrain_minimal/model_loader.py +123 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
- olmoearth_pretrain_minimal/test.py +51 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""Simple set up of latent predictor."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from torch.distributed import DeviceMesh
|
|
11
|
+
from torch.distributed.fsdp import (
|
|
12
|
+
MixedPrecisionPolicy,
|
|
13
|
+
fully_shard,
|
|
14
|
+
register_fsdp_forward_method,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
|
|
18
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import MaskedOlmoEarthSample
|
|
19
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import TokensAndMasks
|
|
20
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.utils import DistributedMixins, unpack_encoder_output
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LatentMIM(nn.Module, DistributedMixins):
    """Latent MIM Style model.

    Holds an online ``encoder``, a ``decoder``, an optional ``reconstructor`` and a
    ``target_encoder``. The ``target_encoder`` is created as a deep copy of
    ``encoder`` with every parameter frozen (``requires_grad = False``).
    """

    # Class-level capability flag; not read anywhere in this class (presumably
    # consumed by training code — confirm).
    supports_multiple_modalities_at_once = True

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        reconstructor: torch.nn.Module | None = None,
    ):
        """Initialize the Latent MIM Style.

        Args:
            encoder: The encoder to use.
            decoder: The decoder to use.
            reconstructor: Optional reconstructor for auto-encoding. When ``None``,
                ``forward`` returns ``None`` in the reconstruction slot.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.reconstructor = reconstructor
        # Independent, frozen copy of the online encoder. Any update of its
        # weights (e.g. EMA) would have to happen outside this class — not shown here.
        self.target_encoder = deepcopy(self.encoder)
        for p in self.target_encoder.parameters():
            p.requires_grad = False

    def forward(
        self, x: MaskedOlmoEarthSample, patch_size: int
    ) -> tuple[
        TokensAndMasks,
        TokensAndMasks,
        torch.Tensor,
        TokensAndMasks | None,
        dict[str, Any],
    ]:
        """Forward pass for the Latent MIM Style.

        Args:
            x: Masked input sample; its ``timestamps`` are passed to the
                reconstructor and the decoder.
            patch_size: Patch size forwarded to every sub-module.

        Returns:
            A 5-tuple ``(latent, decoded, latent_projected_and_pooled,
            reconstructed, extra_metrics)``:
                latent: embeddings from encoder
                decoded: predictions from decoder for masked tokens
                latent_projected_and_pooled: pooled tokens for contrastive loss
                reconstructed: MAE predictions if a reconstructor is set, else None
                extra_metrics: contains ``"token_norm_stats"`` when the encoder
                    reports it, otherwise empty
        """
        # TODO: Input and outputs here are not consistent between encoder and decoder; need a TokensAndMasks++
        output_dict = self.encoder(x, patch_size=patch_size)
        # Pulled out before unpacking so it is reported as a metric instead of
        # being forwarded to the decoder as a keyword argument.
        token_norm_stats = output_dict.pop("token_norm_stats", None)
        latent, latent_projected_and_pooled, decoder_kwargs = unpack_encoder_output(
            output_dict
        )
        extra_metrics: dict[str, Any] = {}
        if token_norm_stats is not None:
            extra_metrics["token_norm_stats"] = token_norm_stats
        reconstructed = None
        # Compare with None explicitly: nn.Module subclasses that define __len__
        # (e.g. container modules) can be falsy even when a module is present.
        if self.reconstructor is not None:
            reconstructed = self.reconstructor(latent, x.timestamps, patch_size)
        decoded = self.decoder(
            latent, timestamps=x.timestamps, patch_size=patch_size, **decoder_kwargs
        )
        return (
            latent,
            decoded,
            latent_projected_and_pooled,
            reconstructed,
            extra_metrics,
        )

    def apply_fsdp(
        self,
        dp_mesh: DeviceMesh | None = None,
        param_dtype: torch.dtype | None = None,
        reduce_dtype: torch.dtype = torch.float32,
        prefetch_factor: int = 0,
    ) -> None:
        """Apply FSDP to the sub-modules, then to the whole model.

        Args:
            dp_mesh: Device mesh passed to ``fully_shard`` as ``mesh``.
            param_dtype: Parameter dtype of the mixed-precision policy.
            reduce_dtype: Gradient-reduction dtype of the mixed-precision policy.
            prefetch_factor: Accepted for interface compatibility; currently unused.
        """
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        fsdp_config = dict(mesh=dp_mesh, mp_policy=mp_policy)

        self.encoder.apply_fsdp(**fsdp_config)
        self.decoder.apply_fsdp(**fsdp_config)
        self.target_encoder.apply_fsdp(**fsdp_config)
        if self.reconstructor is not None:
            self.reconstructor.apply_fsdp(**fsdp_config)
        # TODO: More finegrained wrapping of the encoder transformer layers next time
        fully_shard(self, **fsdp_config)
        # target_encoder.forward is called directly (not through this module's
        # forward), so register it with FSDP explicitly.
        register_fsdp_forward_method(self.target_encoder, "forward")

    def apply_compile(self) -> None:
        """Apply torch.compile to the encoder, decoder and target encoder."""
        # NOTE(review): unlike apply_fsdp, the reconstructor (if any) is not
        # compiled here — confirm this is intentional.
        logger.info("Applying torch.compile to the model")
        self.encoder.apply_compile()
        logger.info("Applied torch.compile to the encoder")
        self.decoder.apply_compile()
        logger.info("Applied torch.compile to the decoder")
        self.target_encoder.apply_compile()
        logger.info("Applied torch.compile to the target encoder")
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass
class LatentMIMConfig(Config):
    """Config that validates encoder/decoder compatibility and builds a LatentMIM."""

    # Config used to build the online encoder (also copied into the target encoder).
    encoder_config: Config
    # Config used to build the decoder.
    decoder_config: Config
    # Optional config for the reconstructor; None means no reconstructor is built.
    reconstructor_config: Config | None = None

    def validate(self) -> None:
        """Check that the encoder and decoder configurations agree.

        Raises:
            ValueError: If the supported modalities, the max sequence lengths, or the
                encoder embedding size seen by the decoder do not match.
        """
        enc_cfg = self.encoder_config
        dec_cfg = self.decoder_config

        if enc_cfg.supported_modalities != dec_cfg.supported_modalities:
            raise ValueError("Encoder and decoder must support the same modalities")

        if enc_cfg.max_sequence_length != dec_cfg.max_sequence_length:
            raise ValueError(
                "Encoder and decoder must have the same max sequence length"
            )

        if enc_cfg.embedding_size != dec_cfg.encoder_embedding_size:
            raise ValueError("Encoder embedding size must be consistent!")

    def build(self) -> "LatentMIM":
        """Validate, then build the encoder, decoder and optional reconstructor.

        Returns:
            A LatentMIM wrapping the freshly built sub-modules.
        """
        self.validate()

        # Sub-modules are built in this fixed order: encoder, decoder, reconstructor.
        encoder = self.encoder_config.build()
        decoder = self.decoder_config.build()
        if self.reconstructor_config is None:
            reconstructor = None
        else:
            reconstructor = self.reconstructor_config.build()

        return LatentMIM(encoder=encoder, decoder=decoder, reconstructor=reconstructor)
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Tokenization configuration for custom band grouping strategies.
|
|
2
|
+
|
|
3
|
+
This module allows customizing how bands are grouped into tokens for each modality,
|
|
4
|
+
enabling experiments with different tokenization strategies (e.g., per-band tokens,
|
|
5
|
+
spectral groupings, etc.).
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
>>> from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.tokenization import TokenizationConfig, ModalityTokenization
|
|
9
|
+
>>> from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality
|
|
10
|
+
>>>
|
|
11
|
+
>>> # Create config with per-band tokenization for Sentinel-2
|
|
12
|
+
>>> s2_bands = Modality.SENTINEL2_L2A.band_order
|
|
13
|
+
>>> config = TokenizationConfig(
|
|
14
|
+
... overrides={
|
|
15
|
+
... Modality.SENTINEL2_L2A.name: ModalityTokenization(
|
|
16
|
+
... band_groups=[[b] for b in s2_bands]
|
|
17
|
+
... )
|
|
18
|
+
... }
|
|
19
|
+
... )
|
|
20
|
+
>>>
|
|
21
|
+
>>> # Use default tokenization for other modalities
|
|
22
|
+
>>> num_bandsets = config.get_num_bandsets(Modality.SENTINEL1.name)
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
|
+
|
|
27
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality, ModalitySpec
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class ModalityTokenization:
    """Custom tokenization configuration for a single modality.

    Specifies how bands should be grouped into tokens. Each band_group
    becomes a separate token.

    Args:
        band_groups: List of band groups, where each group is a list of band names.
    """

    band_groups: list[list[str]]

    def compute_indices(self, base_modality: ModalitySpec) -> list[list[int]]:
        """Map band names to indices based on the base modality's band order.

        Args:
            base_modality: The ModalitySpec that defines the canonical band order.

        Returns:
            List of index lists, one per band group, in ``band_groups`` order.

        Raises:
            ValueError: If a band name doesn't exist in the modality's band_order.
                The message lists the valid bands in their canonical order.
        """
        name_to_idx = {name: i for i, name in enumerate(base_modality.band_order)}
        result = []
        for group in self.band_groups:
            group_indices = []
            for band in group:
                if band not in name_to_idx:
                    raise ValueError(
                        f"Band '{band}' not found in modality '{base_modality.name}'. "
                        f"Valid bands: {list(base_modality.band_order)}"
                    )
                group_indices.append(name_to_idx[band])
            result.append(group_indices)
        return result

    def get_num_bands_per_group(self) -> list[int]:
        """Get the number of bands in each group."""
        return [len(group) for group in self.band_groups]

    @property
    def num_band_sets(self) -> int:
        """Get the number of band sets (token groups)."""
        return len(self.band_groups)

    def validate(self, base_modality: ModalitySpec) -> None:
        """Validate that all band names exist in the modality.

        Delegates to ``compute_indices`` so both methods raise the same error. The
        message lists valid bands in canonical order; formatting a ``set`` made
        that order vary between interpreter runs (string hash randomization).

        Args:
            base_modality: The ModalitySpec to validate against.

        Raises:
            ValueError: If a band name doesn't exist in the modality's band_order.
        """
        self.compute_indices(base_modality)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@dataclass
class TokenizationConfig:
    """Configuration for custom tokenization strategies.

    Allows overriding the default bandset groupings for specific modalities.
    Modalities without overrides use their default bandset configuration
    from ModalitySpec.
    """

    # Per-modality custom band groupings, keyed by modality name.
    overrides: dict[str, ModalityTokenization] = field(default_factory=dict)
    # Memoized get_bandset_indices results, keyed by modality name. Excluded from
    # __init__, repr and (compare=False) __eq__. Otherwise merely calling
    # get_bandset_indices would make two otherwise-identical configs compare unequal.
    _bandset_indices_cache: dict[str, list[list[int]]] = field(
        default_factory=dict, init=False, repr=False, compare=False
    )

    def get_bandset_indices(self, modality_name: str) -> list[list[int]]:
        """Get band indices for tokenization, using override or default.

        Results are cached per modality name. The returned list is the cached
        object itself, so callers should not mutate it. The cache is not
        invalidated if ``overrides`` is modified after a lookup.

        Args:
            modality_name: Name of the modality.

        Returns:
            List of index lists, one per bandset/token group.

        Raises:
            ValueError: If modality_name is invalid or band names don't exist.
        """
        # Check cache first
        if modality_name in self._bandset_indices_cache:
            return self._bandset_indices_cache[modality_name]

        try:
            base_spec = Modality.get(modality_name)
        except (AttributeError, AssertionError) as e:
            raise ValueError(f"Invalid modality: {modality_name}") from e

        if modality_name in self.overrides:
            result = self.overrides[modality_name].compute_indices(base_spec)
        else:
            result = base_spec.bandsets_as_indices()

        # Cache the result
        self._bandset_indices_cache[modality_name] = result
        return result

    def get_num_bandsets(self, modality_name: str) -> int:
        """Get number of bandsets (tokens per spatial location).

        For overridden modalities the override's count is returned without
        looking the name up in ``Modality``.

        Args:
            modality_name: Name of the modality.

        Returns:
            Number of bandsets.

        Raises:
            ValueError: If modality_name is invalid.
        """
        if modality_name in self.overrides:
            return self.overrides[modality_name].num_band_sets
        try:
            return Modality.get(modality_name).num_band_sets
        except (AttributeError, AssertionError) as e:
            raise ValueError(f"Invalid modality: {modality_name}") from e

    def get_num_bands_per_bandset(self, modality_name: str) -> list[int]:
        """Get the number of bands in each bandset.

        Args:
            modality_name: Name of the modality.

        Returns:
            List of band counts, one per bandset.

        Raises:
            ValueError: If modality_name is invalid.
        """
        if modality_name in self.overrides:
            return self.overrides[modality_name].get_num_bands_per_group()
        try:
            base_spec = Modality.get(modality_name)
        except (AttributeError, AssertionError) as e:
            raise ValueError(f"Invalid modality: {modality_name}") from e
        return [len(bs.bands) for bs in base_spec.band_sets]

    def validate(self) -> None:
        """Validate all overrides against their modalities.

        Raises:
            ValueError: If any modality name or band name is invalid.
        """
        for modality_name, tokenization in self.overrides.items():
            try:
                base_spec = Modality.get(modality_name)
            except (AttributeError, AssertionError) as e:
                # Chain the lookup failure, as the other methods of this class do.
                raise ValueError(
                    f"Invalid modality name in overrides: '{modality_name}'. "
                    f"Valid modalities: {Modality.names()}"
                ) from e
            tokenization.validate(base_spec)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Utilities for the nn module."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch.distributed import DeviceMesh
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def unpack_encoder_output(
    output_dict: dict[str, Any],
) -> tuple[Any, Any, dict[str, Any]]:
    """Unpack the output of an encoder.

    The caller's dict is not modified; a shallow copy is unpacked instead.

    Args:
        output_dict (dict[str, Any]): The output of an encoder.

    Returns:
        ``(latent, latent_projected_and_pooled, decoder_kwargs)``:
            latent: the ``"tokens_and_masks"`` entry, or None if absent.
            latent_projected_and_pooled: the ``"project_aggregated"`` entry, or
                None if absent.
            decoder_kwargs: a new dict with every remaining entry except
                ``"token_norm_stats"``.
    """
    remaining = dict(output_dict)
    latent = remaining.pop("tokens_and_masks", None)
    latent_projected_and_pooled = remaining.pop("project_aggregated", None)
    # Removed so it is never forwarded to the decoder as a keyword argument.
    remaining.pop("token_norm_stats", None)
    # Pass through all other outputs that might be specific to an encoder decoder pair
    return latent, latent_projected_and_pooled, remaining
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_cumulative_sequence_lengths(seq_lengths: torch.Tensor) -> torch.Tensor:
    """Return int32 cumulative offsets of the non-zero sequence lengths.

    Zero-length entries are dropped before summing, and a leading 0 is prepended.
    For example, ``[3, 0, 2]`` becomes ``[0, 3, 5]``.

    Args:
        seq_lengths (torch.Tensor): The sequence lengths of a tensor.

    Returns:
        torch.Tensor: int32 tensor on the same device as ``seq_lengths``, with one
        more element than there are non-zero lengths.
    """
    non_empty = seq_lengths.masked_select(seq_lengths != 0)
    running_totals = torch.cumsum(non_empty, 0, dtype=torch.int32)
    leading_zero = torch.zeros(1, dtype=torch.int32, device=seq_lengths.device)
    return torch.cat([leading_zero, running_totals])
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# TODO: maybe this should just be functional or something
|
|
49
|
+
class DistributedMixins:
|
|
50
|
+
"""Mixin for distributed training."""
|
|
51
|
+
|
|
52
|
+
def apply_ddp(
|
|
53
|
+
self,
|
|
54
|
+
dp_mesh: DeviceMesh | None = None,
|
|
55
|
+
compile_enabled: bool = False,
|
|
56
|
+
autograd_compile_enabled: bool = False,
|
|
57
|
+
find_unused_parameters: bool = True,
|
|
58
|
+
) -> None:
|
|
59
|
+
"""Apply DDP to the model.
|
|
60
|
+
|
|
61
|
+
.. warning::
|
|
62
|
+
Usually this does not need to be called directly, as :meth:`TransformerConfig.build()`
|
|
63
|
+
will call it for you.
|
|
64
|
+
"""
|
|
65
|
+
from torch.distributed._composable.replicate import replicate
|
|
66
|
+
|
|
67
|
+
# Adapted from
|
|
68
|
+
# https://github.com/pytorch/torchtitan/blob/90c889e972b56b9faadebbb78fc985dedc537ed9/torchtitan/parallelisms/parallelize_llama.py#L328
|
|
69
|
+
if compile_enabled:
|
|
70
|
+
if autograd_compile_enabled:
|
|
71
|
+
torch._dynamo.config.optimize_ddp = (
|
|
72
|
+
"python_reducer_without_compiled_forward" # type: ignore
|
|
73
|
+
)
|
|
74
|
+
else:
|
|
75
|
+
torch._dynamo.config.optimize_ddp = "ddp_optimizer" # type: ignore
|
|
76
|
+
# Forwards kwargs to torch DDP class, find_unused_parameters=True is required for MAE
|
|
77
|
+
# Small performance hit could be possible for other models
|
|
78
|
+
replicate(
|
|
79
|
+
self,
|
|
80
|
+
device_mesh=dp_mesh,
|
|
81
|
+
bucket_cap_mb=100,
|
|
82
|
+
find_unused_parameters=find_unused_parameters,
|
|
83
|
+
)
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
"""OlmoEarth Pretrain v1 model initialization.
|
|
2
|
+
|
|
3
|
+
This module provides a simple interface to initialize OlmoEarth v1 models.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Literal
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality
|
|
13
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import EncoderConfig, PredictorConfig
|
|
14
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.latent_mim import LatentMIM, LatentMIMConfig
|
|
15
|
+
|
|
16
|
+
# Model size configurations matching the official OlmoEarth v1 models
|
|
17
|
+
MODEL_SIZE_CONFIGS = {
|
|
18
|
+
"nano_shallow_decoder": {
|
|
19
|
+
"decoder_depth": 4,
|
|
20
|
+
"encoder_embedding_size": 128,
|
|
21
|
+
"decoder_embedding_size": 128,
|
|
22
|
+
"encoder_depth": 4,
|
|
23
|
+
"encoder_num_heads": 8,
|
|
24
|
+
"decoder_num_heads": 8,
|
|
25
|
+
"mlp_ratio": 4.0,
|
|
26
|
+
},
|
|
27
|
+
"tiny_shallow_decoder": {
|
|
28
|
+
"decoder_depth": 4,
|
|
29
|
+
"encoder_embedding_size": 192,
|
|
30
|
+
"decoder_embedding_size": 192,
|
|
31
|
+
"encoder_depth": 12,
|
|
32
|
+
"encoder_num_heads": 3,
|
|
33
|
+
"decoder_num_heads": 3,
|
|
34
|
+
"mlp_ratio": 4.0,
|
|
35
|
+
},
|
|
36
|
+
"base_shallow_decoder": {
|
|
37
|
+
"decoder_depth": 4,
|
|
38
|
+
"encoder_embedding_size": 768,
|
|
39
|
+
"decoder_embedding_size": 768,
|
|
40
|
+
"encoder_depth": 12,
|
|
41
|
+
"encoder_num_heads": 12,
|
|
42
|
+
"decoder_num_heads": 12,
|
|
43
|
+
"mlp_ratio": 4.0,
|
|
44
|
+
},
|
|
45
|
+
"large_shallow_decoder": {
|
|
46
|
+
"decoder_depth": 4,
|
|
47
|
+
"encoder_embedding_size": 1024,
|
|
48
|
+
"decoder_embedding_size": 1024,
|
|
49
|
+
"encoder_depth": 24,
|
|
50
|
+
"encoder_num_heads": 16,
|
|
51
|
+
"decoder_num_heads": 16,
|
|
52
|
+
"mlp_ratio": 4.0,
|
|
53
|
+
},
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
# Default modalities used in OlmoEarth v1 training, stored by name in this order.
DEFAULT_MODALITIES = [
    modality.name
    for modality in (
        Modality.SENTINEL2_L2A,
        Modality.SENTINEL1,
        Modality.LANDSAT,
        Modality.WORLDCOVER,
        Modality.SRTM,
        Modality.OPENSTREETMAP_RASTER,
        Modality.WRI_CANOPY_HEIGHT_MAP,
        Modality.CDL,
        Modality.WORLDCEREAL,
    )
]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class OlmoEarthPretrain_v1(torch.nn.Module):
|
|
71
|
+
"""OlmoEarth Pretrain v1 model.
|
|
72
|
+
|
|
73
|
+
This class provides a simple interface to initialize OlmoEarth v1 models
|
|
74
|
+
directly from the repository. Models are initialized with random weights.
|
|
75
|
+
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
def __init__(
|
|
79
|
+
self,
|
|
80
|
+
model_size: Literal["nano", "tiny", "base", "large"] = "nano",
|
|
81
|
+
supported_modality_names: list[str] | None = None,
|
|
82
|
+
max_patch_size: int = 8,
|
|
83
|
+
max_sequence_length: int = 12,
|
|
84
|
+
drop_path: float = 0.1,
|
|
85
|
+
) -> None:
|
|
86
|
+
"""Initialize an OlmoEarth Pretrain v1 model.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
model_size: Size of the model. Options: "nano", "tiny", "base", "large".
|
|
90
|
+
supported_modality_names: List of modality names to support. If None,
|
|
91
|
+
uses the default modalities from OlmoEarth v1 training.
|
|
92
|
+
max_patch_size: Maximum patch size for the encoder.
|
|
93
|
+
max_sequence_length: Maximum sequence length.
|
|
94
|
+
drop_path: Drop path rate for regularization.
|
|
95
|
+
"""
|
|
96
|
+
super().__init__()
|
|
97
|
+
|
|
98
|
+
# Map user-facing model size to internal config key with shallow_decoder suffix
|
|
99
|
+
config_key = f"{model_size}_shallow_decoder"
|
|
100
|
+
if config_key not in MODEL_SIZE_CONFIGS:
|
|
101
|
+
raise ValueError(
|
|
102
|
+
f"Invalid model_size: {model_size}. "
|
|
103
|
+
f"Must be one of {['nano', 'tiny', 'base', 'large']}"
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
if supported_modality_names is None:
|
|
107
|
+
supported_modality_names = DEFAULT_MODALITIES
|
|
108
|
+
|
|
109
|
+
model_config = MODEL_SIZE_CONFIGS[config_key]
|
|
110
|
+
|
|
111
|
+
# Build encoder config
|
|
112
|
+
encoder_config = EncoderConfig(
|
|
113
|
+
embedding_size=model_config["encoder_embedding_size"],
|
|
114
|
+
num_heads=model_config["encoder_num_heads"],
|
|
115
|
+
depth=model_config["encoder_depth"],
|
|
116
|
+
mlp_ratio=model_config["mlp_ratio"],
|
|
117
|
+
supported_modality_names=supported_modality_names,
|
|
118
|
+
max_patch_size=max_patch_size,
|
|
119
|
+
drop_path=drop_path,
|
|
120
|
+
max_sequence_length=max_sequence_length,
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
# Build decoder config
|
|
124
|
+
decoder_config = PredictorConfig(
|
|
125
|
+
encoder_embedding_size=model_config["encoder_embedding_size"],
|
|
126
|
+
decoder_embedding_size=model_config["decoder_embedding_size"],
|
|
127
|
+
depth=model_config["decoder_depth"],
|
|
128
|
+
mlp_ratio=model_config["mlp_ratio"],
|
|
129
|
+
num_heads=model_config["decoder_num_heads"],
|
|
130
|
+
supported_modality_names=supported_modality_names,
|
|
131
|
+
max_sequence_length=max_sequence_length,
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
# Build model config and initialize the model
|
|
135
|
+
model_config_obj = LatentMIMConfig(
|
|
136
|
+
encoder_config=encoder_config,
|
|
137
|
+
decoder_config=decoder_config,
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
self.model = model_config_obj.build()
|
|
141
|
+
|
|
142
|
+
def forward(self, *args, **kwargs):
|
|
143
|
+
"""Forward pass through the model."""
|
|
144
|
+
return self.model(*args, **kwargs)
|
|
145
|
+
|
|
146
|
+
def __getattr__(self, name: str):
|
|
147
|
+
"""Delegate attribute access to the underlying model."""
|
|
148
|
+
try:
|
|
149
|
+
return super().__getattr__(name)
|
|
150
|
+
except AttributeError:
|
|
151
|
+
return getattr(self.model, name)
|
|
152
|
+
|