olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- olmoearth_pretrain_minimal/__init__.py +16 -0
- olmoearth_pretrain_minimal/model_loader.py +123 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
- olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
- olmoearth_pretrain_minimal/test.py +51 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
- olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Root package for the OlmoEarth Pretrain Minimal library."""
|
|
2
|
+
|
|
3
|
+
from olmoearth_pretrain_minimal.model_loader import (
|
|
4
|
+
ModelID,
|
|
5
|
+
load_model_from_id,
|
|
6
|
+
load_model_from_path,
|
|
7
|
+
)
|
|
8
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1 import OlmoEarthPretrain_v1
|
|
9
|
+
|
|
10
|
+
# Names re-exported at the package root; this is what `from olmoearth_pretrain_minimal import *` exposes.
__all__ = [
    "OlmoEarthPretrain_v1",
    "ModelID",
    "load_model_from_id",
    "load_model_from_path",
]
|
|
16
|
+
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""Load the OlmoEarth models from Hugging Face.
|
|
2
|
+
|
|
3
|
+
This module works with or without olmo-core installed:
|
|
4
|
+
- Without olmo-core: inference-only mode (loading pre-trained models)
|
|
5
|
+
- With olmo-core: full functionality including training
|
|
6
|
+
|
|
7
|
+
The weights were converted from a distributed checkpoint into a single .pth file like this:
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from olmo_core.config import Config
|
|
15
|
+
from olmo_core.distributed.checkpoint import load_model_and_optim_state
|
|
16
|
+
|
|
17
|
+
checkpoint_path = Path("/weka/dfive-default/helios/checkpoints/joer/nano_lr0.001_wd0.002/step370000")
|
|
18
|
+
with (checkpoint_path / "config.json").open() as f:
|
|
19
|
+
config_dict = json.load(f)
|
|
20
|
+
model_config = Config.from_dict(config_dict["model"])
|
|
21
|
+
|
|
22
|
+
model = model_config.build()
|
|
23
|
+
|
|
24
|
+
train_module_dir = checkpoint_path / "model_and_optim"
|
|
25
|
+
load_model_and_optim_state(str(train_module_dir), model)
|
|
26
|
+
torch.save(model.state_dict(), "OlmoEarth-v1-Nano.pth")
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
import json
|
|
30
|
+
from enum import StrEnum
|
|
31
|
+
from os import PathLike
|
|
32
|
+
|
|
33
|
+
import torch
|
|
34
|
+
from huggingface_hub import hf_hub_download
|
|
35
|
+
from upath import UPath
|
|
36
|
+
|
|
37
|
+
from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
|
|
38
|
+
|
|
39
|
+
CONFIG_FILENAME = "config.json"
|
|
40
|
+
WEIGHTS_FILENAME = "weights.pth"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ModelID(StrEnum):
|
|
44
|
+
"""OlmoEarth pre-trained model ID."""
|
|
45
|
+
|
|
46
|
+
OLMOEARTH_V1_NANO = "OlmoEarth-v1-Nano"
|
|
47
|
+
OLMOEARTH_V1_TINY = "OlmoEarth-v1-Tiny"
|
|
48
|
+
OLMOEARTH_V1_BASE = "OlmoEarth-v1-Base"
|
|
49
|
+
OLMOEARTH_V1_LARGE = "OlmoEarth-v1-Large"
|
|
50
|
+
|
|
51
|
+
def repo_id(self) -> str:
|
|
52
|
+
"""Return the Hugging Face repo ID for this model."""
|
|
53
|
+
return f"allenai/{self.value}"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def load_model_from_id(model_id: ModelID, load_weights: bool = True) -> torch.nn.Module:
    """Build a pre-trained OlmoEarth model from Hugging Face, optionally with weights.

    Args:
        model_id: which pre-trained model to fetch.
        load_weights: when True (the default), the weights file (``WEIGHTS_FILENAME``)
            is downloaded from the model's Hugging Face repo and loaded into the
            model. When False, that download is skipped and the parameters keep
            whatever values the model constructor gave them. The config file
            (``CONFIG_FILENAME``) is downloaded from Hugging Face in either case.

    Returns:
        The constructed ``torch.nn.Module``.
    """
    model = _load_model_from_config(_resolve_artifact_path(model_id, CONFIG_FILENAME))
    if load_weights:
        weights_location = _resolve_artifact_path(model_id, WEIGHTS_FILENAME)
        model.load_state_dict(_load_state_dict(weights_location))
    return model
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def load_model_from_path(
    model_path: PathLike | str, load_weights: bool = True
) -> torch.nn.Module:
    """Initialize a model from a directory and optionally load its weights.

    The directory is expected to contain ``CONFIG_FILENAME`` (``config.json``)
    and, when ``load_weights`` is True, ``WEIGHTS_FILENAME`` (``weights.pth``).
    The path is wrapped in ``UPath``, so any location ``UPath`` can open is
    accepted, not only local directories.

    Args:
        model_path: directory containing the model artifacts.
        load_weights: whether to load the weights. Set False to skip reading the
            weights file and keep the model's freshly initialized parameters.
            The config file is always read, because it is needed to build the
            model.

    Returns:
        The constructed ``torch.nn.Module``.
    """
    config_fpath = _resolve_artifact_path(model_path, CONFIG_FILENAME)
    model = _load_model_from_config(config_fpath)

    if not load_weights:
        return model

    state_dict_fpath = _resolve_artifact_path(model_path, WEIGHTS_FILENAME)
    state_dict = _load_state_dict(state_dict_fpath)
    model.load_state_dict(state_dict)
    return model
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _resolve_artifact_path(
    model_id_or_path: ModelID | PathLike | str, filename: str
) -> UPath:
    """Locate ``filename`` for a model given either its ID or its directory.

    A ``ModelID`` is resolved by calling ``hf_hub_download`` on the model's
    Hugging Face repo and wrapping the returned path in ``UPath``. Any other
    value is treated as a directory, and ``filename`` is joined onto it without
    checking that the file exists.
    """
    if not isinstance(model_id_or_path, ModelID):
        return UPath(model_id_or_path) / filename
    downloaded = hf_hub_download(repo_id=model_id_or_path.repo_id(), filename=filename)  # nosec
    return UPath(downloaded)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _load_model_from_config(path: UPath) -> torch.nn.Module:
    """Read a JSON config file and build the model described by its ``"model"`` entry.

    The file is opened in binary mode so that ``json`` detects the encoding
    (UTF-8/16/32) itself, instead of depending on the platform's default text
    encoding, which is not UTF-8 on every system.

    Args:
        path: location of the ``config.json`` file.

    Returns:
        The module produced by ``Config.from_dict(config["model"]).build()``.

    Raises:
        KeyError: if the JSON object has no ``"model"`` entry.
    """
    with path.open("rb") as f:
        config_dict = json.load(f)
    model_config = Config.from_dict(config_dict["model"])
    return model_config.build()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _load_state_dict(path: UPath) -> dict[str, torch.Tensor]:
    """Load a model state dict from ``path``, mapping every tensor onto the CPU.

    ``weights_only=True`` restricts unpickling to tensors and plain containers.
    A tampered weights file, such as one downloaded from a remote hub, therefore
    cannot execute arbitrary code while it is loaded. This is the default only
    on newer torch releases, so it is passed explicitly.
    """
    with path.open("rb") as f:
        state_dict = torch.load(f, map_location="cpu", weights_only=True)
    return state_dict
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""OlmoEarth Pretrain neural network modules."""
|