olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- olmoearth_pretrain_minimal/__init__.py +16 -0
- olmoearth_pretrain_minimal/model_loader.py +123 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
- olmoearth_pretrain_minimal/test.py +51 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""Standalone config handling for olmoearth_pretrain_minimal.
|
|
2
|
+
|
|
3
|
+
This module provides a minimal Config class for inference-only mode.
|
|
4
|
+
It does not depend on olmo-core and supports loading models from JSON configs.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class MyConfig(Config):
|
|
11
|
+
...
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import warnings
|
|
17
|
+
from dataclasses import dataclass, fields, is_dataclass
|
|
18
|
+
from importlib import import_module
|
|
19
|
+
from typing import Any, TypeVar
|
|
20
|
+
|
|
21
|
+
# olmo-core is not used in the minimal package
|
|
22
|
+
OLMO_CORE_AVAILABLE = False  # hard-coded flag: this module does not import olmo-core
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
C = TypeVar("C", bound="_StandaloneConfig")  # used by from_dict() so it is typed as returning the class it is called on
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
|
|
29
|
+
class _StandaloneConfig:
|
|
30
|
+
"""Minimal Config for inference-only mode without olmo-core.
|
|
31
|
+
|
|
32
|
+
This provides just enough functionality to deserialize model configs from JSON
|
|
33
|
+
and build models. It intentionally does NOT support:
|
|
34
|
+
- OmegaConf-based merging
|
|
35
|
+
- CLI overrides via dotlist
|
|
36
|
+
- YAML loading
|
|
37
|
+
- Validation beyond what dataclasses provide
|
|
38
|
+
|
|
39
|
+
For full functionality, install olmo-core.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
CLASS_NAME_FIELD = "_CLASS_"
|
|
43
|
+
|
|
44
|
+
@classmethod
|
|
45
|
+
def _resolve_class(cls, class_name: str) -> type | None:
|
|
46
|
+
"""Resolve a fully-qualified class name to a class object."""
|
|
47
|
+
if "." not in class_name:
|
|
48
|
+
return None
|
|
49
|
+
|
|
50
|
+
# Map old package paths to new ones for compatibility
|
|
51
|
+
# Handle both "helios" (old name) and "olmoearth_pretrain" package names
|
|
52
|
+
if class_name.startswith("helios."):
|
|
53
|
+
class_name = class_name.replace("helios.", "olmoearth_pretrain_minimal.olmoearth_pretrain_v1.", 1)
|
|
54
|
+
# Fix common typos in config files
|
|
55
|
+
class_name = class_name.replace("flexihelios", "flexi_vit")
|
|
56
|
+
elif class_name.startswith("olmoearth_pretrain."):
|
|
57
|
+
class_name = class_name.replace("olmoearth_pretrain.", "olmoearth_pretrain_minimal.olmoearth_pretrain_v1.", 1)
|
|
58
|
+
|
|
59
|
+
*modules, cls_name = class_name.split(".")
|
|
60
|
+
module_name = ".".join(modules)
|
|
61
|
+
try:
|
|
62
|
+
module = import_module(module_name)
|
|
63
|
+
return getattr(module, cls_name)
|
|
64
|
+
except (ImportError, AttributeError):
|
|
65
|
+
return None
|
|
66
|
+
|
|
67
|
+
@classmethod
|
|
68
|
+
def _clean_data(cls, data: Any) -> Any:
|
|
69
|
+
"""Recursively clean data, resolving _CLASS_ fields to actual instances."""
|
|
70
|
+
if isinstance(data, dict):
|
|
71
|
+
# Check if this dict represents a config class
|
|
72
|
+
class_name = data.get(cls.CLASS_NAME_FIELD)
|
|
73
|
+
|
|
74
|
+
# First, recursively clean all nested values
|
|
75
|
+
# This will resolve nested configs that have _CLASS_ fields
|
|
76
|
+
cleaned = {}
|
|
77
|
+
for k, v in data.items():
|
|
78
|
+
if k != cls.CLASS_NAME_FIELD:
|
|
79
|
+
cleaned_value = cls._clean_data(v)
|
|
80
|
+
cleaned[k] = cleaned_value
|
|
81
|
+
|
|
82
|
+
if class_name is not None:
|
|
83
|
+
resolved_cls = cls._resolve_class(class_name)
|
|
84
|
+
if resolved_cls is not None and is_dataclass(resolved_cls):
|
|
85
|
+
# Get the field names for this dataclass
|
|
86
|
+
field_names = {f.name for f in fields(resolved_cls)}
|
|
87
|
+
# Filter to only include valid fields
|
|
88
|
+
valid_kwargs = {
|
|
89
|
+
k: v for k, v in cleaned.items() if k in field_names
|
|
90
|
+
}
|
|
91
|
+
# Ensure nested dicts that should be Config instances are resolved
|
|
92
|
+
# The recursive _clean_data() should have resolved them, but resolve any remaining dicts
|
|
93
|
+
for key, value in list(valid_kwargs.items()):
|
|
94
|
+
if isinstance(value, dict) and not is_dataclass(value):
|
|
95
|
+
# Try to resolve as Config using from_dict
|
|
96
|
+
if cls.CLASS_NAME_FIELD in value:
|
|
97
|
+
nested_class_name = value[cls.CLASS_NAME_FIELD]
|
|
98
|
+
nested_resolved_cls = cls._resolve_class(nested_class_name)
|
|
99
|
+
if nested_resolved_cls is not None and is_dataclass(nested_resolved_cls):
|
|
100
|
+
nested_dict = {k: v for k, v in value.items() if k != cls.CLASS_NAME_FIELD}
|
|
101
|
+
valid_kwargs[key] = nested_resolved_cls.from_dict(nested_dict)
|
|
102
|
+
else:
|
|
103
|
+
raise ValueError(
|
|
104
|
+
f"Could not resolve nested config class '{nested_class_name}' for field '{key}'"
|
|
105
|
+
)
|
|
106
|
+
try:
|
|
107
|
+
return resolved_cls(**valid_kwargs)
|
|
108
|
+
except TypeError as e:
|
|
109
|
+
raise TypeError(
|
|
110
|
+
f"Failed to instantiate {class_name}: {e}"
|
|
111
|
+
) from e
|
|
112
|
+
# If class resolution failed, keep _CLASS_ field in dict for from_dict() to retry
|
|
113
|
+
cleaned[cls.CLASS_NAME_FIELD] = class_name
|
|
114
|
+
return cleaned
|
|
115
|
+
|
|
116
|
+
elif isinstance(data, list | tuple):
|
|
117
|
+
cleaned_items = [cls._clean_data(item) for item in data]
|
|
118
|
+
return type(data)(cleaned_items)
|
|
119
|
+
|
|
120
|
+
else:
|
|
121
|
+
return data
|
|
122
|
+
|
|
123
|
+
@classmethod
|
|
124
|
+
def from_dict(
|
|
125
|
+
cls: type[C], data: dict[str, Any], overrides: list[str] | None = None
|
|
126
|
+
) -> C:
|
|
127
|
+
"""Deserialize from a dictionary, handling nested _CLASS_ fields.
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
data: Dictionary representation of the config.
|
|
131
|
+
overrides: Ignored in standalone mode (requires olmo-core for dotlist support).
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
An instance of the config class.
|
|
135
|
+
|
|
136
|
+
Note:
|
|
137
|
+
The `overrides` parameter is accepted for API compatibility but ignored.
|
|
138
|
+
Install olmo-core for full override support.
|
|
139
|
+
"""
|
|
140
|
+
if overrides:
|
|
141
|
+
warnings.warn(
|
|
142
|
+
"Config overrides are not supported in standalone mode. "
|
|
143
|
+
"Install olmo-core for full functionality.",
|
|
144
|
+
UserWarning,
|
|
145
|
+
stacklevel=2,
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
cleaned = cls._clean_data(data)
|
|
149
|
+
|
|
150
|
+
# If _clean_data resolved a config class instance (from _CLASS_ field), return it directly
|
|
151
|
+
if is_dataclass(cleaned) and not isinstance(cleaned, type):
|
|
152
|
+
return cleaned
|
|
153
|
+
elif isinstance(cleaned, cls):
|
|
154
|
+
return cleaned
|
|
155
|
+
elif isinstance(cleaned, dict):
|
|
156
|
+
# Check if the dict has a _CLASS_ field that we should try to resolve
|
|
157
|
+
if cls.CLASS_NAME_FIELD in cleaned:
|
|
158
|
+
class_name = cleaned[cls.CLASS_NAME_FIELD]
|
|
159
|
+
resolved_cls = cls._resolve_class(class_name)
|
|
160
|
+
if resolved_cls is not None and is_dataclass(resolved_cls):
|
|
161
|
+
config_dict = {k: v for k, v in cleaned.items() if k != cls.CLASS_NAME_FIELD}
|
|
162
|
+
return resolved_cls.from_dict(config_dict)
|
|
163
|
+
else:
|
|
164
|
+
raise ValueError(
|
|
165
|
+
f"Could not resolve class '{class_name}' from _CLASS_ field. "
|
|
166
|
+
f"Make sure the class exists and is importable."
|
|
167
|
+
)
|
|
168
|
+
# No _CLASS_ field, try to create base Config instance
|
|
169
|
+
field_names = {f.name for f in fields(cls)}
|
|
170
|
+
valid_kwargs = {k: v for k, v in cleaned.items() if k in field_names}
|
|
171
|
+
return cls(**valid_kwargs)
|
|
172
|
+
else:
|
|
173
|
+
raise TypeError(f"Expected dict or config instance, got {type(cleaned)}")
|
|
174
|
+
|
|
175
|
+
def as_dict(
|
|
176
|
+
self,
|
|
177
|
+
*,
|
|
178
|
+
exclude_none: bool = False,
|
|
179
|
+
exclude_private_fields: bool = False,
|
|
180
|
+
include_class_name: bool = False,
|
|
181
|
+
json_safe: bool = False,
|
|
182
|
+
recurse: bool = True,
|
|
183
|
+
) -> dict[str, Any]:
|
|
184
|
+
"""Convert to a dictionary.
|
|
185
|
+
|
|
186
|
+
Args:
|
|
187
|
+
exclude_none: Don't include values that are None.
|
|
188
|
+
exclude_private_fields: Don't include private fields (starting with _).
|
|
189
|
+
include_class_name: Include _CLASS_ field with fully-qualified class name.
|
|
190
|
+
json_safe: Convert non-JSON-safe types to strings.
|
|
191
|
+
recurse: Recursively convert nested dataclasses.
|
|
192
|
+
|
|
193
|
+
Returns:
|
|
194
|
+
Dictionary representation of this config.
|
|
195
|
+
"""
|
|
196
|
+
|
|
197
|
+
def convert(obj: Any) -> Any:
|
|
198
|
+
if is_dataclass(obj) and not isinstance(obj, type):
|
|
199
|
+
result = {}
|
|
200
|
+
if include_class_name:
|
|
201
|
+
result[self.CLASS_NAME_FIELD] = (
|
|
202
|
+
f"{obj.__class__.__module__}.{obj.__class__.__name__}"
|
|
203
|
+
)
|
|
204
|
+
for field in fields(obj):
|
|
205
|
+
if exclude_private_fields and field.name.startswith("_"):
|
|
206
|
+
continue
|
|
207
|
+
value = getattr(obj, field.name)
|
|
208
|
+
if exclude_none and value is None:
|
|
209
|
+
continue
|
|
210
|
+
if recurse:
|
|
211
|
+
value = convert(value)
|
|
212
|
+
result[field.name] = value
|
|
213
|
+
return result
|
|
214
|
+
elif isinstance(obj, dict):
|
|
215
|
+
return {k: convert(v) if recurse else v for k, v in obj.items()}
|
|
216
|
+
elif isinstance(obj, list | tuple | set):
|
|
217
|
+
converted = [convert(item) if recurse else item for item in obj]
|
|
218
|
+
if json_safe:
|
|
219
|
+
return converted
|
|
220
|
+
return type(obj)(converted)
|
|
221
|
+
elif obj is None or isinstance(obj, float | int | bool | str):
|
|
222
|
+
return obj
|
|
223
|
+
elif json_safe:
|
|
224
|
+
return str(obj)
|
|
225
|
+
else:
|
|
226
|
+
return obj
|
|
227
|
+
|
|
228
|
+
return convert(self)
|
|
229
|
+
|
|
230
|
+
def as_config_dict(self) -> dict[str, Any]:
|
|
231
|
+
"""Convert to a JSON-safe dictionary suitable for serialization.
|
|
232
|
+
|
|
233
|
+
This is a convenience wrapper around as_dict() with settings appropriate
|
|
234
|
+
for saving configs to JSON files.
|
|
235
|
+
"""
|
|
236
|
+
return self.as_dict(
|
|
237
|
+
exclude_none=True,
|
|
238
|
+
exclude_private_fields=True,
|
|
239
|
+
include_class_name=True,
|
|
240
|
+
json_safe=True,
|
|
241
|
+
recurse=True,
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
def validate(self) -> None:
|
|
245
|
+
"""Validate the config. Override in subclasses."""
|
|
246
|
+
pass
|
|
247
|
+
|
|
248
|
+
def build(self) -> Any:
|
|
249
|
+
"""Build the object this config represents.
|
|
250
|
+
|
|
251
|
+
Subclasses must implement this method.
|
|
252
|
+
|
|
253
|
+
Raises:
|
|
254
|
+
NotImplementedError: Always, unless overridden by subclass.
|
|
255
|
+
"""
|
|
256
|
+
raise NotImplementedError("Subclasses must implement build()")
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
# === The unified export ===
|
|
260
|
+
# Always use standalone config for minimal package (no olmo-core dependency)
|
|
261
|
+
Config = _StandaloneConfig
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
__all__ = ["Config", "OLMO_CORE_AVAILABLE"]  # names exported by a star-import of this module
|