olmoearth-pretrain-minimal 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. olmoearth_pretrain_minimal/__init__.py +16 -0
  2. olmoearth_pretrain_minimal/model_loader.py +123 -0
  3. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/__init__.py +6 -0
  4. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/__init__.py +1 -0
  5. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/attention.py +559 -0
  6. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/encodings.py +115 -0
  7. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_patch_embed.py +304 -0
  8. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/flexi_vit.py +2219 -0
  9. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/latent_mim.py +166 -0
  10. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/tokenization.py +194 -0
  11. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/nn/utils.py +83 -0
  12. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/olmoearth_pretrain_v1.py +152 -0
  13. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/__init__.py +2 -0
  14. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/config.py +264 -0
  15. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/constants.py +519 -0
  16. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/datatypes.py +165 -0
  17. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/decorators.py +75 -0
  18. olmoearth_pretrain_minimal/olmoearth_pretrain_v1/utils/types.py +8 -0
  19. olmoearth_pretrain_minimal/test.py +51 -0
  20. olmoearth_pretrain_minimal-0.0.1.dist-info/METADATA +326 -0
  21. olmoearth_pretrain_minimal-0.0.1.dist-info/RECORD +24 -0
  22. olmoearth_pretrain_minimal-0.0.1.dist-info/WHEEL +5 -0
  23. olmoearth_pretrain_minimal-0.0.1.dist-info/licenses/LICENSE +204 -0
  24. olmoearth_pretrain_minimal-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,16 @@
1
+ """Root package for the OlmoEarth Pretrain Minimal library."""
2
+
3
+ from olmoearth_pretrain_minimal.model_loader import (
4
+ ModelID,
5
+ load_model_from_id,
6
+ load_model_from_path,
7
+ )
8
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1 import OlmoEarthPretrain_v1
9
+
10
+ __all__ = [
11
+ "OlmoEarthPretrain_v1",
12
+ "ModelID",
13
+ "load_model_from_id",
14
+ "load_model_from_path",
15
+ ]
16
+
@@ -0,0 +1,123 @@
1
+ """Load the OlmoEarth models from Hugging Face.
2
+
3
+ This module works with or without olmo-core installed:
4
+ - Without olmo-core: inference-only mode (loading pre-trained models)
5
+ - With olmo-core: full functionality including training
6
+
7
+ The weights were converted from a distributed checkpoint to a .pth file as follows:
8
+
9
+ import json
10
+ from pathlib import Path
11
+
12
+ import torch
13
+
14
+ from olmo_core.config import Config
15
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state
16
+
17
+ checkpoint_path = Path("/weka/dfive-default/helios/checkpoints/joer/nano_lr0.001_wd0.002/step370000")
18
+ with (checkpoint_path / "config.json").open() as f:
19
+ config_dict = json.load(f)
20
+ model_config = Config.from_dict(config_dict["model"])
21
+
22
+ model = model_config.build()
23
+
24
+ train_module_dir = checkpoint_path / "model_and_optim"
25
+ load_model_and_optim_state(str(train_module_dir), model)
26
+ torch.save(model.state_dict(), "OlmoEarth-v1-Nano.pth")
27
+ """
28
+
29
+ import json
30
+ from enum import StrEnum
31
+ from os import PathLike
32
+
33
+ import torch
34
+ from huggingface_hub import hf_hub_download
35
+ from upath import UPath
36
+
37
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.config import Config
38
+
39
+ CONFIG_FILENAME = "config.json"
40
+ WEIGHTS_FILENAME = "weights.pth"
41
+
42
+
43
+ class ModelID(StrEnum):
44
+ """OlmoEarth pre-trained model ID."""
45
+
46
+ OLMOEARTH_V1_NANO = "OlmoEarth-v1-Nano"
47
+ OLMOEARTH_V1_TINY = "OlmoEarth-v1-Tiny"
48
+ OLMOEARTH_V1_BASE = "OlmoEarth-v1-Base"
49
+ OLMOEARTH_V1_LARGE = "OlmoEarth-v1-Large"
50
+
51
+ def repo_id(self) -> str:
52
+ """Return the Hugging Face repo ID for this model."""
53
+ return f"allenai/{self.value}"
54
+
55
+
56
def load_model_from_id(model_id: ModelID, load_weights: bool = True) -> torch.nn.Module:
    """Build a pre-trained OlmoEarth model from its Hugging Face repository.

    Args:
        model_id: the released model to build.
        load_weights: when True (the default), fetch the weights file and load
            it into the model. When False, the weights download is skipped and
            the parameters are left randomly initialized; the config.json is
            still downloaded from Hugging Face either way.

    Returns:
        The model built from the repository's config, with the pre-trained
        weights loaded unless ``load_weights`` is False.
    """
    model = _load_model_from_config(_resolve_artifact_path(model_id, CONFIG_FILENAME))

    if load_weights:
        weights_fpath = _resolve_artifact_path(model_id, WEIGHTS_FILENAME)
        model.load_state_dict(_load_state_dict(weights_fpath))

    return model
75
+
76
+
77
def load_model_from_path(
    model_path: PathLike | str, load_weights: bool = True
) -> torch.nn.Module:
    """Build an OlmoEarth model from a directory holding its artifacts.

    Args:
        model_path: the directory that contains the model's ``config.json``
            and, when weights are requested, its ``weights.pth``.
        load_weights: when True (the default), read the weights file and load
            it into the model. When False, the weights file is not read and
            the parameters are left randomly initialized; the config.json is
            still read from ``model_path`` either way.

    Returns:
        The model built from the config, with the stored weights loaded unless
        ``load_weights`` is False.
    """
    model = _load_model_from_config(_resolve_artifact_path(model_path, CONFIG_FILENAME))

    if load_weights:
        weights_fpath = _resolve_artifact_path(model_path, WEIGHTS_FILENAME)
        model.load_state_dict(_load_state_dict(weights_fpath))

    return model
97
+
98
+
99
def _resolve_artifact_path(
    model_id_or_path: ModelID | PathLike | str, filename: str
) -> UPath:
    """Locate one artifact file of a model, downloading it from Hugging Face if needed.

    Args:
        model_id_or_path: either a ``ModelID`` (the file is fetched from that
            model's Hugging Face repo) or a directory path (the file is
            looked up directly inside it).
        filename: name of the artifact file within the repo or directory.

    Returns:
        The path of the requested artifact.
    """
    # ModelID is a StrEnum, so its members are also ``str`` instances; the
    # ModelID check has to be the one that decides which branch is taken.
    if not isinstance(model_id_or_path, ModelID):
        return UPath(model_id_or_path) / filename

    downloaded = hf_hub_download(repo_id=model_id_or_path.repo_id(), filename=filename)  # nosec
    return UPath(downloaded)
109
+
110
+
111
def _load_model_from_config(path: UPath) -> torch.nn.Module:
    """Build a model from the JSON config file at ``path``.

    Only the ``"model"`` entry of the JSON document is used: it is turned into
    a ``Config`` and the result of that config's ``build()`` is returned.
    """
    with path.open() as config_file:
        model_section = json.load(config_file)["model"]
    return Config.from_dict(model_section).build()
117
+
118
+
119
def _load_state_dict(path: UPath) -> dict[str, torch.Tensor]:
    """Load the model state dict from the specified path.

    Args:
        path: location of a state dict file written by ``torch.save``.

    Returns:
        The state dict, with every tensor mapped onto the CPU.
    """
    with path.open("rb") as f:
        # ``torch.load`` unpickles its input, and this file may have been
        # downloaded. ``weights_only=True`` limits unpickling to tensors and
        # plain containers, so a tampered checkpoint is rejected instead of
        # being able to run arbitrary code.
        return torch.load(f, map_location="cpu", weights_only=True)
@@ -0,0 +1,6 @@
1
+ """OlmoEarth Pretrain v1 model package."""
2
+
3
+ from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.olmoearth_pretrain_v1 import OlmoEarthPretrain_v1
4
+
5
+ __all__ = ["OlmoEarthPretrain_v1"]
6
+
@@ -0,0 +1 @@
1
+ """OlmoEarth Pretrain neural network modules."""