olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. olmoearth_pretrain_minimal/__init__.py +16 -0
  2. olmoearth_pretrain_minimal/model_loader.py +123 -0
  3. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
  4. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
  5. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
  6. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
  7. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
  8. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
  9. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
  10. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
  11. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
  12. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
  13. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
  14. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
  15. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
  16. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
  17. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
  18. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
  19. olmoearth_pretrain_minimal/test.py +51 -0
  20. olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
  21. olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
  22. olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
  23. olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
  24. olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,166 @@
1
+ """Simple set up of latent predictor."""
2
+
3
+ import logging
4
+ from copy import deepcopy
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.distributed import DeviceMesh
11
+ from torch.distributed.fsdp import (
12
+ MixedPrecisionPolicy,
13
+ fully_shard,
14
+ register_fsdp_forward_method,
15
+ )
16
+
17
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
18
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import MaskedOlmoEarthSample
19
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import TokensAndMasks
20
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.utils import DistributedMixins, unpack_encoder_output
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class LatentMIM(nn.Module, DistributedMixins):
    """Latent MIM Style."""

    # Capability flag read by training code.
    supports_multiple_modalities_at_once = True

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        reconstructor: torch.nn.Module | None = None,
    ):
        """Initialize the Latent MIM Style.

        Args:
            encoder: The encoder to use.
            decoder: The decoder to use.
            reconstructor: Optional reconstructor for auto-encoding.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.reconstructor = reconstructor
        # The target encoder starts as a frozen copy of the online encoder;
        # its parameters are excluded from gradient updates.
        self.target_encoder = deepcopy(self.encoder)
        for p in self.target_encoder.parameters():
            p.requires_grad = False

    def forward(
        self, x: MaskedOlmoEarthSample, patch_size: int
    ) -> tuple[
        TokensAndMasks,
        TokensAndMasks,
        torch.Tensor,
        TokensAndMasks | None,
        dict[str, Any],
    ]:
        """Forward pass for the Latent MIM Style.

        Args:
            x: Masked input sample; its ``timestamps`` are forwarded to the
                decoder (and reconstructor, if any).
            patch_size: Patch size forwarded to encoder/decoder.

        Returns:
            latent: embeddings from encoder
            decoded: predictions from decoder for masked tokens
            latent_projected_and_pooled: pooled tokens for contrastive loss
            reconstructed: MAE predictions if enabled
            extra_metrics: auxiliary metrics (currently token norm stats, when
                the encoder provides them)
        """
        # TODO: Input And outputs here are not consistent between encoder and decoder need a tokensandmaks++
        output_dict = self.encoder(x, patch_size=patch_size)
        token_norm_stats = output_dict.pop("token_norm_stats", None)
        latent, latent_projected_and_pooled, decoder_kwargs = unpack_encoder_output(
            output_dict
        )
        extra_metrics = {}
        if token_norm_stats is not None:
            extra_metrics["token_norm_stats"] = token_norm_stats
        reconstructed = None
        # FIX: explicit None check. Truthiness of nn.Module subclasses that
        # define __len__ (e.g. an empty nn.Sequential) is False even when a
        # reconstructor was supplied, which would silently skip reconstruction.
        if self.reconstructor is not None:
            reconstructed = self.reconstructor(latent, x.timestamps, patch_size)
        decoded = self.decoder(
            latent, timestamps=x.timestamps, patch_size=patch_size, **decoder_kwargs
        )
        return (
            latent,
            decoded,
            latent_projected_and_pooled,
            reconstructed,
            extra_metrics,
        )

    def apply_fsdp(
        self,
        dp_mesh: DeviceMesh | None = None,
        param_dtype: torch.dtype | None = None,
        reduce_dtype: torch.dtype = torch.float32,
        prefetch_factor: int = 0,
    ) -> None:
        """Apply FSDP to the model.

        Args:
            dp_mesh: Data-parallel device mesh; None uses the default mesh.
            param_dtype: Parameter dtype for mixed precision (None keeps default).
            reduce_dtype: Dtype used for gradient reduction.
            prefetch_factor: Currently unused; kept for interface compatibility.
        """
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype
        )
        fsdp_config = dict(mesh=dp_mesh, mp_policy=mp_policy)

        self.encoder.apply_fsdp(**fsdp_config)
        self.decoder.apply_fsdp(**fsdp_config)
        self.target_encoder.apply_fsdp(**fsdp_config)
        # FIX: explicit None check (see forward() for rationale).
        if self.reconstructor is not None:
            self.reconstructor.apply_fsdp(**fsdp_config)
        # TODO: More finegrained wrapping of the encoder transformer layers next time
        fully_shard(self, **fsdp_config)
        # The target encoder is invoked outside this module's forward(), so its
        # forward method must be registered for FSDP's pre/post hooks to fire.
        register_fsdp_forward_method(self.target_encoder, "forward")

    def apply_compile(self) -> None:
        """Apply torch.compile to the model."""
        logger.info("Applying torch.compile to the model")
        self.encoder.apply_compile()
        logger.info("Applied torch.compile to the encoder")
        self.decoder.apply_compile()
        logger.info("Applied torch.compile to the decoder")
        self.target_encoder.apply_compile()
        logger.info("Applied torch.compile to the target encoder")
122
+
123
+
124
@dataclass
class LatentMIMConfig(Config):
    """Configuration for the Latent Predictor."""

    encoder_config: Config
    decoder_config: Config
    reconstructor_config: Config | None = None

    def validate(self) -> None:
        """Validate the configuration."""
        enc, dec = self.encoder_config, self.decoder_config
        if enc.supported_modalities != dec.supported_modalities:
            raise ValueError("Encoder and decoder must support the same modalities")
        if enc.max_sequence_length != dec.max_sequence_length:
            raise ValueError(
                "Encoder and decoder must have the same max sequence length"
            )
        if enc.embedding_size != dec.encoder_embedding_size:
            raise ValueError("Encoder embedding size must be consistent!")

    def build(self) -> "LatentMIM":
        """Build the Latent Predictor."""
        self.validate()
        # Build order matters if sub-config builds have side effects:
        # encoder, then decoder, then (optionally) reconstructor.
        encoder = self.encoder_config.build()
        decoder = self.decoder_config.build()
        reconstructor = None
        if self.reconstructor_config is not None:
            reconstructor = self.reconstructor_config.build()
        return LatentMIM(
            encoder=encoder,
            decoder=decoder,
            reconstructor=reconstructor,
        )
@@ -0,0 +1,194 @@
1
+ """Tokenization configuration for custom band grouping strategies.
2
+
3
+ This module allows customizing how bands are grouped into tokens for each modality,
4
+ enabling experiments with different tokenization strategies (e.g., per-band tokens,
5
+ spectral groupings, etc.).
6
+
7
+ Example:
8
+ >>> from olmoearth_pretrain.nn.tokenization import TokenizationConfig, ModalityTokenization
9
+ >>> from olmoearth_pretrain.utils.constants import Modality
10
+ >>>
11
+ >>> # Create config with per-band tokenization for Sentinel-2
12
+ >>> s2_bands = Modality.SENTINEL2_L2A.band_order
13
+ >>> config = TokenizationConfig(
14
+ ... overrides={
15
+ ... Modality.SENTINEL2_L2A.name: ModalityTokenization(
16
+ ... band_groups=[[b] for b in s2_bands]
17
+ ... )
18
+ ... }
19
+ ... )
20
+ >>>
21
+ >>> # Use default tokenization for other modalities
22
+ >>> num_bandsets = config.get_num_bandsets(Modality.SENTINEL1.name)
23
+ """
24
+
25
+ from dataclasses import dataclass, field
26
+
27
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality, ModalitySpec
28
+
29
+
30
@dataclass
class ModalityTokenization:
    """Custom tokenization configuration for a single modality.

    Specifies how bands should be grouped into tokens; each band group
    becomes a separate token.

    Args:
        band_groups: List of band groups, where each group is a list of band names.
    """

    band_groups: list[list[str]]

    def compute_indices(self, base_modality: ModalitySpec) -> list[list[int]]:
        """Map band names to indices based on the base modality's band order.

        Args:
            base_modality: The ModalitySpec that defines the canonical band order.

        Returns:
            List of index lists, one per band group.

        Raises:
            ValueError: If a band name doesn't exist in the modality's band_order.
        """
        position = {band: idx for idx, band in enumerate(base_modality.band_order)}
        indices: list[list[int]] = []
        for group in self.band_groups:
            row: list[int] = []
            for band in group:
                if band not in position:
                    raise ValueError(
                        f"Band '{band}' not found in modality '{base_modality.name}'. "
                        f"Valid bands: {list(base_modality.band_order)}"
                    )
                row.append(position[band])
            indices.append(row)
        return indices

    def get_num_bands_per_group(self) -> list[int]:
        """Get the number of bands in each group."""
        return [len(g) for g in self.band_groups]

    @property
    def num_band_sets(self) -> int:
        """Get the number of band sets (token groups)."""
        return len(self.band_groups)

    def validate(self, base_modality: ModalitySpec) -> None:
        """Validate that all band names exist in the modality.

        Args:
            base_modality: The ModalitySpec to validate against.

        Raises:
            ValueError: If a band name doesn't exist in the modality's band_order.
        """
        valid_bands = set(base_modality.band_order)
        # Flatten the groups and reject the first unknown band encountered.
        for band in (b for group in self.band_groups for b in group):
            if band not in valid_bands:
                raise ValueError(
                    f"Band '{band}' not found in modality '{base_modality.name}'. "
                    f"Valid bands: {valid_bands}"
                )
95
+
96
+
97
@dataclass
class TokenizationConfig:
    """Configuration for custom tokenization strategies.

    Allows overriding the default bandset groupings for specific modalities.
    Modalities without overrides use their default bandset configuration
    from ModalitySpec.
    """

    overrides: dict[str, ModalityTokenization] = field(default_factory=dict)
    # Per-modality cache of computed bandset indices; excluded from init/repr.
    _bandset_indices_cache: dict[str, list[list[int]]] = field(
        default_factory=dict, init=False, repr=False
    )

    def get_bandset_indices(self, modality_name: str) -> list[list[int]]:
        """Get band indices for tokenization, using override or default.

        Args:
            modality_name: Name of the modality.

        Returns:
            List of index lists, one per bandset/token group.

        Raises:
            ValueError: If modality_name is invalid or band names don't exist.
        """
        # Check cache first
        if modality_name in self._bandset_indices_cache:
            return self._bandset_indices_cache[modality_name]

        try:
            base_spec = Modality.get(modality_name)
        except (AttributeError, AssertionError) as e:
            raise ValueError(f"Invalid modality: {modality_name}") from e

        if modality_name in self.overrides:
            result = self.overrides[modality_name].compute_indices(base_spec)
        else:
            result = base_spec.bandsets_as_indices()

        # Cache the result
        self._bandset_indices_cache[modality_name] = result
        return result

    def get_num_bandsets(self, modality_name: str) -> int:
        """Get number of bandsets (tokens per spatial location).

        Args:
            modality_name: Name of the modality.

        Returns:
            Number of bandsets.

        Raises:
            ValueError: If modality_name is invalid.
        """
        if modality_name in self.overrides:
            return self.overrides[modality_name].num_band_sets
        try:
            return Modality.get(modality_name).num_band_sets
        except (AttributeError, AssertionError) as e:
            raise ValueError(f"Invalid modality: {modality_name}") from e

    def get_num_bands_per_bandset(self, modality_name: str) -> list[int]:
        """Get the number of bands in each bandset.

        Args:
            modality_name: Name of the modality.

        Returns:
            List of band counts, one per bandset.

        Raises:
            ValueError: If modality_name is invalid.
        """
        if modality_name in self.overrides:
            return self.overrides[modality_name].get_num_bands_per_group()
        try:
            base_spec = Modality.get(modality_name)
        except (AttributeError, AssertionError) as e:
            raise ValueError(f"Invalid modality: {modality_name}") from e
        return [len(bs.bands) for bs in base_spec.band_sets]

    def validate(self) -> None:
        """Validate all overrides against their modalities.

        Raises:
            ValueError: If any modality name or band name is invalid.
        """
        for modality_name, tokenization in self.overrides.items():
            try:
                base_spec = Modality.get(modality_name)
            except (AttributeError, AssertionError) as e:
                # FIX: chain the original error with "from e" — previously this
                # raise dropped the cause and was inconsistent with every other
                # method in this class.
                raise ValueError(
                    f"Invalid modality name in overrides: '{modality_name}'. "
                    f"Valid modalities: {Modality.names()}"
                ) from e
            tokenization.validate(base_spec)
@@ -0,0 +1,83 @@
1
+ """Utilities for the nn module."""
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+ from torch.distributed import DeviceMesh
7
+
8
+
9
def unpack_encoder_output(
    output_dict: dict[str, Any],
) -> tuple:
    """Unpack the output of an encoder.

    Args:
        output_dict (dict[str, Any]): The output of an encoder.

    Returns:
        tuple[TokensAndMasks, TokensAndMasks, dict[str, Any]]: The unpacked output.
    """
    tokens = output_dict.pop("tokens_and_masks", None)
    pooled = output_dict.pop("project_aggregated", None)
    # Drop encoder-internal stats; every remaining key is treated as a
    # decoder kwarg specific to this encoder/decoder pair.
    output_dict.pop("token_norm_stats", None)
    return tokens, pooled, output_dict
27
+
28
+
29
def get_cumulative_sequence_lengths(seq_lengths: torch.Tensor) -> torch.Tensor:
    """Get the cumulative sequence lengths of a tensor.

    Zero-length entries are dropped before accumulation, and a leading 0 is
    prepended so the result can serve as offsets into a packed sequence.

    Args:
        seq_lengths (torch.Tensor): The sequence lengths of a tensor.

    Returns:
        torch.Tensor: The cumulative sequence lengths of a tensor.
    """
    nonzero = seq_lengths.masked_select(seq_lengths != 0)
    offsets = nonzero.cumsum(0, dtype=torch.int32)
    zero = torch.tensor([0], dtype=torch.int32, device=seq_lengths.device)
    return torch.cat((zero, offsets))
46
+
47
+
48
# TODO: maybe this should just be functional or something
class DistributedMixins:
    """Mixin for distributed training.

    Provides DDP wiring for models that inherit it alongside nn.Module.
    """

    def apply_ddp(
        self,
        dp_mesh: DeviceMesh | None = None,
        compile_enabled: bool = False,
        autograd_compile_enabled: bool = False,
        find_unused_parameters: bool = True,
    ) -> None:
        """Apply DDP to the model.

        Args:
            dp_mesh: Device mesh to replicate over; None uses the default.
            compile_enabled: Whether torch.compile is in use; selects how
                dynamo optimizes DDP bucketing (see branches below).
            autograd_compile_enabled: With compile enabled, switches dynamo to
                the python reducer without a compiled forward.
            find_unused_parameters: Forwarded to DDP; True is required for MAE
                (see comment below).

        .. warning::
            Usually this does not need to be called directly, as :meth:`TransformerConfig.build()`
            will call it for you.
        """
        # NOTE(review): private torch API — subject to change across versions.
        from torch.distributed._composable.replicate import replicate

        # Adapted from
        # https://github.com/pytorch/torchtitan/blob/90c889e972b56b9faadebbb78fc985dedc537ed9/torchtitan/parallelisms/parallelize_llama.py#L328
        # Mutates process-global dynamo config; affects all DDP modules.
        if compile_enabled:
            if autograd_compile_enabled:
                torch._dynamo.config.optimize_ddp = (
                    "python_reducer_without_compiled_forward"  # type: ignore
                )
            else:
                torch._dynamo.config.optimize_ddp = "ddp_optimizer"  # type: ignore
        # Forwards kwargs to torch DDP class, find_unused_parameters=True is required for MAE
        # Small performance hit could be possible for other models
        replicate(
            self,
            device_mesh=dp_mesh,
            bucket_cap_mb=100,
            find_unused_parameters=find_unused_parameters,
        )
@@ -0,0 +1,152 @@
1
+ """OlmoEarth Pretrain v1 model initialization.
2
+
3
+ This module provides a simple interface to initialize OlmoEarth v1 models.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Literal
9
+
10
+ import torch
11
+
12
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality
13
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import EncoderConfig, PredictorConfig
14
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.latent_mim import LatentMIM, LatentMIMConfig
15
+
16
# Model size configurations matching the official OlmoEarth v1 models
# Keys follow the "<size>_shallow_decoder" scheme used by OlmoEarthPretrain_v1:
# every variant pairs a full-depth encoder with a shallow (4-layer) decoder,
# and the decoder embedding size mirrors the encoder embedding size.
MODEL_SIZE_CONFIGS = {
    "nano_shallow_decoder": {
        "decoder_depth": 4,
        "encoder_embedding_size": 128,
        "decoder_embedding_size": 128,
        "encoder_depth": 4,
        "encoder_num_heads": 8,
        "decoder_num_heads": 8,
        "mlp_ratio": 4.0,
    },
    "tiny_shallow_decoder": {
        "decoder_depth": 4,
        "encoder_embedding_size": 192,
        "decoder_embedding_size": 192,
        "encoder_depth": 12,
        "encoder_num_heads": 3,
        "decoder_num_heads": 3,
        "mlp_ratio": 4.0,
    },
    "base_shallow_decoder": {
        "decoder_depth": 4,
        "encoder_embedding_size": 768,
        "decoder_embedding_size": 768,
        "encoder_depth": 12,
        "encoder_num_heads": 12,
        "decoder_num_heads": 12,
        "mlp_ratio": 4.0,
    },
    "large_shallow_decoder": {
        "decoder_depth": 4,
        "encoder_embedding_size": 1024,
        "decoder_embedding_size": 1024,
        "encoder_depth": 24,
        "encoder_num_heads": 16,
        "decoder_num_heads": 16,
        "mlp_ratio": 4.0,
    },
}

# Default modalities used in OlmoEarth v1 training
# (used when OlmoEarthPretrain_v1 is constructed without an explicit list).
DEFAULT_MODALITIES = [
    Modality.SENTINEL2_L2A.name,
    Modality.SENTINEL1.name,
    Modality.LANDSAT.name,
    Modality.WORLDCOVER.name,
    Modality.SRTM.name,
    Modality.OPENSTREETMAP_RASTER.name,
    Modality.WRI_CANOPY_HEIGHT_MAP.name,
    Modality.CDL.name,
    Modality.WORLDCEREAL.name,
]
68
+
69
+
70
+ class OlmoEarthPretrain_v1(torch.nn.Module):
71
+ """OlmoEarth Pretrain v1 model.
72
+
73
+ This class provides a simple interface to initialize OlmoEarth v1 models
74
+ directly from the repository. Models are initialized with random weights.
75
+
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ model_size: Literal["nano", "tiny", "base", "large"] = "nano",
81
+ supported_modality_names: list[str] | None = None,
82
+ max_patch_size: int = 8,
83
+ max_sequence_length: int = 12,
84
+ drop_path: float = 0.1,
85
+ ) -> None:
86
+ """Initialize an OlmoEarth Pretrain v1 model.
87
+
88
+ Args:
89
+ model_size: Size of the model. Options: "nano", "tiny", "base", "large".
90
+ supported_modality_names: List of modality names to support. If None,
91
+ uses the default modalities from OlmoEarth v1 training.
92
+ max_patch_size: Maximum patch size for the encoder.
93
+ max_sequence_length: Maximum sequence length.
94
+ drop_path: Drop path rate for regularization.
95
+ """
96
+ super().__init__()
97
+
98
+ # Map user-facing model size to internal config key with shallow_decoder suffix
99
+ config_key = f"{model_size}_shallow_decoder"
100
+ if config_key not in MODEL_SIZE_CONFIGS:
101
+ raise ValueError(
102
+ f"Invalid model_size: {model_size}. "
103
+ f"Must be one of {['nano', 'tiny', 'base', 'large']}"
104
+ )
105
+
106
+ if supported_modality_names is None:
107
+ supported_modality_names = DEFAULT_MODALITIES
108
+
109
+ model_config = MODEL_SIZE_CONFIGS[config_key]
110
+
111
+ # Build encoder config
112
+ encoder_config = EncoderConfig(
113
+ embedding_size=model_config["encoder_embedding_size"],
114
+ num_heads=model_config["encoder_num_heads"],
115
+ depth=model_config["encoder_depth"],
116
+ mlp_ratio=model_config["mlp_ratio"],
117
+ supported_modality_names=supported_modality_names,
118
+ max_patch_size=max_patch_size,
119
+ drop_path=drop_path,
120
+ max_sequence_length=max_sequence_length,
121
+ )
122
+
123
+ # Build decoder config
124
+ decoder_config = PredictorConfig(
125
+ encoder_embedding_size=model_config["encoder_embedding_size"],
126
+ decoder_embedding_size=model_config["decoder_embedding_size"],
127
+ depth=model_config["decoder_depth"],
128
+ mlp_ratio=model_config["mlp_ratio"],
129
+ num_heads=model_config["decoder_num_heads"],
130
+ supported_modality_names=supported_modality_names,
131
+ max_sequence_length=max_sequence_length,
132
+ )
133
+
134
+ # Build model config and initialize the model
135
+ model_config_obj = LatentMIMConfig(
136
+ encoder_config=encoder_config,
137
+ decoder_config=decoder_config,
138
+ )
139
+
140
+ self.model = model_config_obj.build()
141
+
142
+ def forward(self, *args, **kwargs):
143
+ """Forward pass through the model."""
144
+ return self.model(*args, **kwargs)
145
+
146
+ def __getattr__(self, name: str):
147
+ """Delegate attribute access to the underlying model."""
148
+ try:
149
+ return super().__getattr__(name)
150
+ except AttributeError:
151
+ return getattr(self.model, name)
152
+
@@ -0,0 +1,2 @@
1
+ """Utility modules for OlmoEarth Pretrain."""
2
+