EquFlash 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- equflash-0.0.1/GGNN/__init__.py +0 -0
- equflash-0.0.1/GGNN/_cli.py +113 -0
- equflash-0.0.1/GGNN/common/__init__.py +0 -0
- equflash-0.0.1/GGNN/common/calculator.py +206 -0
- equflash-0.0.1/GGNN/common/lammps_mliap/__init__.py +0 -0
- equflash-0.0.1/GGNN/common/lammps_mliap/lmp_mliap_wrapper.py +211 -0
- equflash-0.0.1/GGNN/common/logger.py +102 -0
- equflash-0.0.1/GGNN/common/statistics.py +151 -0
- equflash-0.0.1/GGNN/common/torchsim_model.py +184 -0
- equflash-0.0.1/GGNN/common/utils.py +524 -0
- equflash-0.0.1/GGNN/datasets/__init__.py +0 -0
- equflash-0.0.1/GGNN/datasets/ase_arrays_dataset.py +56 -0
- equflash-0.0.1/GGNN/datasets/lmdb_dataset.py +72 -0
- equflash-0.0.1/GGNN/datasets/samplers.py +213 -0
- equflash-0.0.1/GGNN/main.py +7 -0
- equflash-0.0.1/GGNN/model/EquFlash/__init__.py +3 -0
- equflash-0.0.1/GGNN/model/EquFlash/equflash.py +218 -0
- equflash-0.0.1/GGNN/model/EquFlash/equflash_comp.py +264 -0
- equflash-0.0.1/GGNN/model/EquFlash/equflash_pol.py +476 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/__init__.py +0 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/_ghost_exchange_base.py +79 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/activation.py +24 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/atomic_reduce.py +227 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/convolution.py +224 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/edge_embedding.py +231 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/equivariant_gate.py +56 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/feed_foward.py +300 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/force_output.py +230 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/interaction_blocks.py +103 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/node_embedding.py +43 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/scale.py +37 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/symmetric_contraction.py +32 -0
- equflash-0.0.1/GGNN/model/EquFlash/nn/util.py +17 -0
- equflash-0.0.1/GGNN/model/EquFlash/statistics.py +151 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/__init__.py +3 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/equflash_v2.py +263 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/equflash_v2_comp.py +298 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/__init__.py +0 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/_ghost_exchange_base.py +79 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/activation.py +24 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/atomic_reduce.py +60 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/convolution.py +229 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/edge_embedding.py +171 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/force_output.py +87 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/interaction_blocks.py +222 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/linear.py +209 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/node_embedding.py +45 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/nonlinears.py +305 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/normalization.py +136 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/scale.py +37 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/skip.py +77 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/nn/util.py +30 -0
- equflash-0.0.1/GGNN/model/EquFlashV2/statistics.py +151 -0
- equflash-0.0.1/GGNN/model/__init__.py +0 -0
- equflash-0.0.1/GGNN/modules/__init__.py +0 -0
- equflash-0.0.1/GGNN/modules/evaluator_ext.py +610 -0
- equflash-0.0.1/GGNN/modules/loss.py +17 -0
- equflash-0.0.1/GGNN/modules/loss_ext.py +805 -0
- equflash-0.0.1/GGNN/preprocessing/__init__.py +0 -0
- equflash-0.0.1/GGNN/preprocessing/atoms_arrays_to_graphs.py +72 -0
- equflash-0.0.1/GGNN/preprocessing/atoms_to_graphs.py +105 -0
- equflash-0.0.1/GGNN/scripts/__init__.py +0 -0
- equflash-0.0.1/GGNN/scripts/create_lmp_mliap_file.py +44 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/equflashv2-45M-oam.yml +85 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/install.sh +5 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/merge_results.py +35 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/ml_relaxation.py +139 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/optimizable.py +712 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/optimizers/base_optimizer.py +82 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/optimizers/fire.py +157 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/optimizers/lbfgs.py +142 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/requirements.txt +35 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/test_equflash_discovery.py +204 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/test_equflash_kappa.py +367 -0
- equflash-0.0.1/GGNN/scripts/matbench_submission/thermal_conductivity.py +80 -0
- equflash-0.0.1/GGNN/trainer/__init__.py +13 -0
- equflash-0.0.1/GGNN/trainer/trainer.py +1754 -0
- equflash-0.0.1/PKG-INFO +350 -0
- equflash-0.0.1/README.md +312 -0
- equflash-0.0.1/pyproject.toml +77 -0
|
File without changes
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file is derived from fairchem-core.
|
|
3
|
+
Original source: https://github.com/facebookresearch/fairchem
|
|
4
|
+
|
|
5
|
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
6
|
+
Licensed under the MIT License. See LICENSE for details.
|
|
7
|
+
|
|
8
|
+
Modifications copyright (c) 2026 Samsung Electronics.
|
|
9
|
+
Licensed under CC BY-NC-SA 4.0. See LICENSE for details.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import copy
|
|
15
|
+
import logging
|
|
16
|
+
from typing import TYPE_CHECKING
|
|
17
|
+
|
|
18
|
+
from submitit import AutoExecutor
|
|
19
|
+
from submitit.helpers import Checkpointable, DelayedSubmission
|
|
20
|
+
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
|
21
|
+
|
|
22
|
+
from fairchem.core.common import distutils
|
|
23
|
+
from fairchem.core.common.flags import flags
|
|
24
|
+
from fairchem.core.common.utils import (
|
|
25
|
+
build_config,
|
|
26
|
+
create_grid,
|
|
27
|
+
save_experiment_log,
|
|
28
|
+
setup_logging,
|
|
29
|
+
)
|
|
30
|
+
from GGNN.common.utils import new_trainer_context
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
import argparse
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Runner(Checkpointable):
    """Checkpointable callable that sets up and runs one task from a config."""

    def __init__(self) -> None:
        # Populated by __call__; read back by checkpoint().
        self.config = None

    def __call__(self, config: dict) -> None:
        """Open a trainer context for ``config``, then set up and run its task.

        The context's config, task and trainer are stored on the instance so
        that ``checkpoint`` can use them afterwards.
        """
        with new_trainer_context(config=config) as trainer_ctx:
            self.config = trainer_ctx.config
            self.task = trainer_ctx.task
            self.trainer = trainer_ctx.trainer
            self.task.setup(self.trainer)
            self.task.run()

    def checkpoint(self, *args, **kwargs):
        """Save the training state and return a resubmission of a fresh Runner.

        The saved checkpoint path and the trainer's timestamp id are written
        into the config that is handed to the new submission.
        """
        self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
        self.config["checkpoint"] = self.task.chkpt_path
        self.config["timestamp_id"] = self.trainer.timestamp_id

        run_logger = self.trainer.logger
        if run_logger is not None:
            run_logger.mark_preempting()

        logging.info(
            f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}'
        )
        return DelayedSubmission(Runner(), self.config)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def runner_wrapper(config: dict):
    """Create a ``Runner`` and invoke it once with ``config``."""
    runner = Runner()
    runner(config)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def main(
    args: argparse.Namespace | None = None, override_args: list[str] | None = None
):
    """Run the main fairchem program.

    When ``args`` is not given, the command line is parsed with the parser
    from ``flags``. With ``args.submit`` set, one job per config is submitted
    through a submitit ``AutoExecutor``; otherwise the config is run in the
    current process via ``runner_wrapper``.
    """
    setup_logging()

    if args is None:
        parser: argparse.ArgumentParser = flags.get_parser()
        args, override_args = parser.parse_known_args()

    assert (
        args.num_gpus > 0
    ), "num_gpus is used to determine number ranks, so it must be at least 1"
    config = build_config(args, override_args)

    if not args.submit:
        # Run locally on a single node, n-processes
        runner_wrapper(config)
        return

    # Run on cluster
    extra_slurm_args = config.get("slurm", None)  # additional slurm arguments
    if args.sweep_yml:
        job_configs = create_grid(config, args.sweep_yml)
    else:
        job_configs = [config]

    logging.info(f"Submitting {len(job_configs)} jobs")
    slurm_executor = AutoExecutor(folder=args.logdir / "%j", slurm_max_num_timeout=3)
    slurm_executor.update_parameters(
        name=args.identifier,
        mem_gb=args.slurm_mem,
        timeout_min=args.slurm_timeout * 60,
        slurm_partition=args.slurm_partition,
        gpus_per_node=args.num_gpus,
        cpus_per_task=(config["optim"]["num_workers"] + 1),
        tasks_per_node=args.num_gpus,
        nodes=args.num_nodes,
        slurm_additional_parameters=extra_slurm_args,
        slurm_qos=args.slurm_qos,
        slurm_account=args.slurm_account,
    )

    # Record the executor settings in every config that gets submitted.
    for job_config in job_configs:
        job_config["slurm"] = copy.deepcopy(slurm_executor.parameters)
        job_config["slurm"]["folder"] = str(slurm_executor.folder)

    jobs = slurm_executor.map_array(Runner(), job_configs)
    logging.info(f"Submitted jobs: {', '.join(job.job_id for job in jobs)}")
    log_file = save_experiment_log(args, jobs, job_configs)
    logging.info(f"Experiment log saved to: {log_file}")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
if __name__ == "__main__":
|
|
113
|
+
main()
|
|
File without changes
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file is derived from fairchem-core.
|
|
3
|
+
Original source: https://github.com/facebookresearch/fairchem
|
|
4
|
+
|
|
5
|
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
6
|
+
Licensed under the MIT License. See LICENSE for details.
|
|
7
|
+
|
|
8
|
+
Modifications copyright (c) 2026 Samsung Electronics.
|
|
9
|
+
Licensed under CC BY-NC-SA 4.0. See LICENSE for details.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from fairchem.core.common.relaxation.ase_utils import OCPCalculator
|
|
13
|
+
from ..trainer.trainer import Trainer
|
|
14
|
+
|
|
15
|
+
import copy
|
|
16
|
+
import logging
|
|
17
|
+
from types import MappingProxyType
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
from ase import Atoms
|
|
22
|
+
from ase.calculators.calculator import Calculator
|
|
23
|
+
from ase.calculators.singlepoint import SinglePointCalculator
|
|
24
|
+
from ase.constraints import FixAtoms
|
|
25
|
+
from ase.geometry import wrap_positions
|
|
26
|
+
|
|
27
|
+
from fairchem.core.common.registry import registry
|
|
28
|
+
from fairchem.core.common.utils import (
|
|
29
|
+
load_config,
|
|
30
|
+
setup_imports,
|
|
31
|
+
setup_logging,
|
|
32
|
+
update_config,
|
|
33
|
+
)
|
|
34
|
+
from fairchem.core.datasets import data_list_collater
|
|
35
|
+
from fairchem.core.models.model_registry import model_name_to_local_file
|
|
36
|
+
from fairchem.core.preprocessing import AtomsToGraphs
|
|
37
|
+
from collections import OrderedDict
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def convert_compiled_ckpt(ckpt):
    """Rewrite a checkpoint saved from a ``*_comp`` model for the plain model.

    If ``ckpt["config"]["model"]["name"]`` does not end with ``"_comp"`` the
    checkpoint is returned untouched. Otherwise, in place:

    * the trailing ``"_comp"`` is removed from the model name (only the
      suffix, so names that merely contain ``"_comp"`` elsewhere survive);
    * ``ckpt["state_dict"]`` is replaced by an ``OrderedDict`` holding only
      the keys that start with ``"module.model"`` and do not contain
      ``"z_to_onehot_tensor"``, with the first ``".model."`` segment dropped;
    * the ``"task"`` entry of the config is removed when present.

    Args:
        ckpt: mapping with at least ``"config"`` (containing ``"model"`` with
            a ``"name"``) and, for ``*_comp`` models, ``"state_dict"``.

    Returns:
        The same ``ckpt`` object.
    """
    suffix = "_comp"
    model_cfg = ckpt["config"]["model"]
    if not model_cfg["name"].endswith(suffix):
        return ckpt

    # Strip only the trailing suffix; str.replace would also eat an inner
    # "_comp" (e.g. "my_compound_comp").
    model_cfg["name"] = model_cfg["name"][: -len(suffix)]

    converted = OrderedDict()
    for key, value in ckpt["state_dict"].items():
        if key.startswith("module.model") and "z_to_onehot_tensor" not in key:
            # count=1: remove the wrapper's ".model." only, not a nested
            # submodule that happens to be called "model" as well.
            converted[key.replace(".model.", ".", 1)] = value
    ckpt["state_dict"] = converted

    # NOTE(review): the flattened source does not show whether the task entry
    # was dropped for every checkpoint or only for "_comp" ones; this keeps
    # it scoped to "_comp" checkpoints -- confirm.
    ckpt["config"].pop("task", None)
    return ckpt
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
if TYPE_CHECKING:
|
|
56
|
+
from pathlib import Path
|
|
57
|
+
|
|
58
|
+
from torch_geometric.data import Batch
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class UCalculator(OCPCalculator):
    """ASE based calculator using an OCP model, backed by the GGNN ``Trainer``."""

    def __init__(
        self,
        config_yml: str | None = None,
        checkpoint_path: str | None = None,
        model_name: str | None = None,
        local_cache: str | None = None,
        trainer: str | None = None,
        cpu: bool = True,
        seed: int | None = None,
        only_output: list[str] | None = None,
    ) -> None:
        """Build the trainer in inference mode and optionally load a checkpoint.

        Args:
            config_yml: path to a config file (loaded with ``load_config``) or
                an already loaded config mapping.
            checkpoint_path: checkpoint to load. When ``config_yml`` is not
                given, the config is taken from this checkpoint.
            model_name: name resolved to a local file with
                ``model_name_to_local_file``; mutually exclusive with
                ``checkpoint_path`` and requires ``local_cache``.
            local_cache: cache directory used to resolve ``model_name``.
            trainer: value stored under ``config["trainer"]``; defaults to the
                config's own value or ``"default"``.
            cpu: forwarded to the trainer.
            seed: seed passed to ``trainer.set_seed``; falls back to
                ``trainer.config["cmd"]["seed"]``.
            only_output: if given, every other entry of ``config["outputs"]``
                is removed before the trainer is built.

        Raises:
            RuntimeError: both ``model_name`` and ``checkpoint_path`` given.
            NotImplementedError: ``model_name`` given without ``local_cache``.
            ValueError: the config file has conflicting duplicate parameters.
        """
        setup_imports()
        setup_logging()
        Calculator.__init__(self)

        if model_name is not None:
            if checkpoint_path is not None:
                raise RuntimeError(
                    "model_name and checkpoint_path were both specified, please use only one at a time"
                )
            if local_cache is None:
                raise NotImplementedError(
                    "Local cache must be set when specifying a model name"
                )
            checkpoint_path = model_name_to_local_file(
                model_name=model_name, local_cache=local_cache
            )

        # Either the config path or the checkpoint path needs to be provided
        assert config_yml or checkpoint_path is not None

        checkpoint = None
        if config_yml is not None:
            if isinstance(config_yml, str):
                config, duplicates_warning, duplicates_error = load_config(config_yml)
                if len(duplicates_warning) > 0:
                    logging.warning(
                        f"Overwritten config parameters from included configs "
                        f"(non-included parameters take precedence): {duplicates_warning}"
                    )
                if len(duplicates_error) > 0:
                    raise ValueError(
                        f"Conflicting (duplicate) parameters in simultaneously "
                        f"included configs: {duplicates_error}"
                    )
            else:
                config = config_yml

            # Drop a leading "compiled_" from the model name. Only the prefix
            # is removed; str.replace would also strip later occurrences.
            # NOTE(review): applied to both file and dict configs -- confirm
            # this matches the intended nesting.
            compiled_prefix = "compiled_"
            if config["model"]["name"].startswith(compiled_prefix):
                config["model"]["name"] = config["model"]["name"][
                    len(compiled_prefix) :
                ]

            # Only keeps the train data that might have normalizer values
            if isinstance(config["dataset"], list):
                config["dataset"] = config["dataset"][0]
            elif isinstance(config["dataset"], dict):
                config["dataset"] = config["dataset"].get("train", None)
        else:
            # Loads the config from the checkpoint directly (always on CPU).
            # NOTE: weights_only=False unpickles arbitrary objects; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(
                checkpoint_path, map_location=torch.device("cpu"), weights_only=False
            )
            checkpoint = convert_compiled_ckpt(checkpoint)
            config = checkpoint["config"]

        if trainer is not None:
            config["trainer"] = trainer
        else:
            config["trainer"] = config.get("trainer", "default")

        if "model_attributes" in config:
            config["model_attributes"]["name"] = config.pop("model")
            config["model"] = config["model_attributes"]

        # Calculate the edge indices on the fly
        config["model"]["otf_graph"] = True

        ### backwards compatibility with OCP v<2.0
        config = update_config(config)

        # self.config keeps the full config (including the dataset source);
        # the working copy handed to the trainer has the source removed.
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = str(checkpoint_path)
        dataset_cfg = config["dataset"]
        if isinstance(dataset_cfg, dict):
            # The dataset may be None (no "train" entry above) or lack "src";
            # neither case should abort building the calculator.
            dataset_cfg.pop("src", None)

        # some models that are published have configs that include tasks
        # which are not output by the model
        if only_output is not None:
            assert isinstance(
                only_output, list
            ), "only output must be a list of targets to output"
            for key in only_output:
                assert (
                    key in config["outputs"]
                ), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}"
            remove_outputs = set(config["outputs"].keys()) - set(only_output)
            for key in remove_outputs:
                config["outputs"].pop(key)

        # The GGNN Trainer is always used; config["trainer"] is only recorded.
        self.trainer = Trainer(
            task=config.get("task", {}),
            model=config["model"],
            dataset=[config["dataset"]],
            outputs=config["outputs"],
            loss_functions=config["loss_functions"],
            evaluation_metrics=config["evaluation_metrics"],
            optimizer=config["optim"],
            identifier="",
            slurm=config.get("slurm", {}),
            local_rank=config.get("local_rank", 0),
            is_debug=config.get("is_debug", True),
            cpu=cpu,
            amp=config.get("amp", False),
            inference_only=True,
        )

        if checkpoint_path is not None:
            self.load_checkpoint(checkpoint_path=checkpoint_path, checkpoint=checkpoint)

        seed = seed if seed is not None else self.trainer.config["cmd"]["seed"]
        if seed is None:
            logging.warning(
                "No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run"
            )
        else:
            self.trainer.set_seed(seed)

        # Temporarily implemented: prefer the config flag and only fall back
        # to the model attribute when the config does not carry it.
        if "otf_graph" in config["model"]:
            r_edges = not config["model"]["otf_graph"]
        else:
            r_edges = not self.trainer.model.otf_graph

        self.a2g = AtomsToGraphs(
            r_energy=False,
            r_forces=False,
            r_distances=False,
            r_pbc=True,
            r_edges=r_edges,
        )
        self.implemented_properties = list(self.config["outputs"].keys())
|
|
File without changes
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
|
9
|
+
except ModuleNotFoundError:
|
|
10
|
+
raise ImportError(
|
|
11
|
+
"LAMMPS ML-IAP has to be installed in the Python environment for NequIP's ML-IAP integration. "
|
|
12
|
+
"See https://nequip.readthedocs.io/en/latest/integrations/lammps/mliap.html for installation instructions."
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from typing import List
|
|
16
|
+
from GGNN.common.calculator import UCalculator,convert_compiled_ckpt
|
|
17
|
+
from GGNN.model.SevenNet.sevennet import SevenNet, convert_compiled_to_original_sevennet
|
|
18
|
+
|
|
19
|
+
from GGNN.model.EquFlash import EquFlash
|
|
20
|
+
from fairchem.core.common.utils import (
|
|
21
|
+
match_state_dict,
|
|
22
|
+
)
|
|
23
|
+
from fairchem.core.common.registry import registry
|
|
24
|
+
from fairchem.core.modules.normalization.element_references import (
|
|
25
|
+
create_element_references,
|
|
26
|
+
)
|
|
27
|
+
from fairchem.core.modules.normalization.normalizer import (
|
|
28
|
+
create_normalizer,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
from torch_geometric.data import Batch
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GGNNLAMMPSMLIAPWrapper(MLIAPUnified):
    """LAMMPS ML-IAP unified interface around a GGNN model checkpoint.

    The checkpoint is read in ``__init__``; the model itself is built lazily
    on the first ``compute_forces`` call, once the LAMMPS data object is
    available to pick the device.
    """

    def __init__(
        self,
        ckpt_path: str,
        **kwargs,
    ):
        """Load the checkpoint from ``ckpt_path`` onto the CPU.

        Args:
            ckpt_path: path of the checkpoint file.
            **kwargs: accepted but not used.
        """
        super().__init__()
        self.ckpt_path = ckpt_path
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        self.model = None
        self.device = None

        # to placate the interface
        self.nparams = 1
        self.ndescriptors = 1

        # === set model-dependent params ===
        # NOTE(review): hard-coded; the upstream wrapper derived this as half
        # of the model cutoff -- confirm 3.0 matches the checkpoint's cutoff.
        self.rcutfac = 3.0

        # Imported here as in the original; ase.data is named explicitly so
        # the submodule is guaranteed to be loaded.
        import ase.data

        self.element_types = ase.data.chemical_symbols
        self.ckpt = ckpt

    def _load_model(self, lmp_data):
        """Build the model, load its weights, normalizers and element refs."""
        self.device = (
            "cuda" if "kokkos" in lmp_data.__class__.__module__.lower() else "cpu"
        )
        self.ckpt = convert_compiled_ckpt(self.ckpt)

        model_config = self.ckpt["config"]["model"]
        model_name = model_config.pop("name")
        model_config.setdefault("conv_kwargs", {})["use_lammps_mliap"] = True

        self.model = registry.get_model_class(model_name)(
            model_config, lammps_mliap=True
        ).to(self.device)

        new_dict = match_state_dict(self.model.state_dict(), self.ckpt["state_dict"])
        self.model.load_state_dict(new_dict, strict=False)

        # int() instead of eval(): the value comes from the environment and
        # must never be executed as code.
        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
        if rank == 0:
            print(self.model)

        # Convert old normalizer keys to new target keys.
        legacy_keys = {"target": "energy", "grad_target": "forces"}
        self.normalizers = {}
        for key, state_dict in self.ckpt.get("normalizers", {}).items():
            target_key = legacy_keys.get(key, key)
            if target_key not in self.normalizers:
                self.normalizers[target_key] = create_normalizer(
                    state_dict=state_dict
                )
            self.normalizers[target_key].to(self.device)

        self.elementrefs = {}
        for key, state_dict in self.ckpt.get("elementrefs", {}).items():
            self.elementrefs[key] = create_element_references(state_dict=state_dict)
            self.elementrefs[key].to(self.device)

        # Forces are obtained below by differentiating the energy with
        # respect to the edge vectors, so the model's own heads are disabled.
        self.model.regress_forces = False
        self.model.regress_stress = False

    def compute_forces(self, lmp_data):
        """Evaluate the model and write energy and pair forces to ``lmp_data``.

        Returns early, without touching ``lmp_data``, when there are no local
        atoms or at most one pair.
        """
        # === lazily load model ===
        if self.model is None:
            self._load_model(lmp_data)

        if lmp_data.nlocal == 0 or lmp_data.npairs <= 1:
            return

        # === create input data ===

        # NOTE
        # This LAMMPS ML-IAP integration introduces a new dimension of having
        # `num_local` vs `num_local + num_ghost` number of nodes. There are
        # three crucial dimensions to be aware of: `num_edges`, `num_local`,
        # `num_local + num_ghost`. The input tensors have these shapes:
        # - `edge_vectors`: (num_edges, 3)
        # - `edge_idxs`: (2, num_edges)
        # - `atom_types`: (num_local + num_ghost)
        #
        # This wrapper can handle output `atomic_energy` having either shape
        # `num_local` or `num_local + num_ghost` based on the size check below.
        #
        # Models can perform optimizations based on an understanding of when
        # ghost atoms matter or not and are responsible for carefully handling
        # internal shape logic (edge -> node scatter operations, nodewise
        # operations that involve `atom_types`).

        # TODO: we have yet to exploit per-edge-type cutoffs by pruning the
        # edge vectors and neighborlist
        # make sure edge vectors `requires_grad`
        edge_vectors = (
            torch.as_tensor(lmp_data.rij, dtype=torch.float64).to(self.device).float()
        )
        edge_vectors.requires_grad_(True)
        edge_index = torch.vstack(
            [
                torch.as_tensor(lmp_data.pair_i, dtype=torch.int64).to(self.device),
                torch.as_tensor(lmp_data.pair_j, dtype=torch.int64).to(self.device),
            ],
        )

        atomic_type = torch.as_tensor(lmp_data.elems, dtype=torch.int64).to(self.device)

        # [number of local atoms, number of ghost atoms]
        num_atoms = torch.tensor(
            [lmp_data.nlocal, lmp_data.ntotal - lmp_data.nlocal], dtype=torch.int64
        ).to(self.device)
        natoms = torch.tensor(lmp_data.nlocal, dtype=torch.int64).to(self.device)

        # === run model ===
        # run model and backwards for edge forces
        batch = Batch(
            edge_vec=edge_vectors,
            atomic_numbers=(atomic_type[: lmp_data.nlocal].long()),
            edge_index=edge_index,
            num_atoms=num_atoms,
            natoms=natoms,
            lmp_data=lmp_data,
        )
        out = self.model(batch)
        # correct sign convention for consistency with LAMMPS
        edge_forces = torch.autograd.grad(
            [out["energy"].sum()],
            [edge_vectors],
        )[0]

        # === pass outputs to LAMMPS ===
        # handle ghosts
        atomic_energies = out["atomic_energy"].squeeze()
        if atomic_energies.dim() == 0:
            # A single atom squeezes down to a scalar; restore the atom axis
            # so the size check below does not fail.
            atomic_energies = atomic_energies.reshape(1)

        # shape-dependent control flow, but should be outside of compiled model
        if atomic_energies.size(0) != lmp_data.nlocal:
            atomic_energies = torch.narrow(atomic_energies, 0, 0, lmp_data.nlocal)
        total_energy = torch.sum(atomic_energies)

        if "forces" in self.normalizers:
            edge_forces = self.normalizers["forces"](edge_forces)
        if "energy" in self.normalizers:
            total_energy = self.normalizers["energy"](total_energy)
        if "energy" in self.elementrefs:
            # All local atoms belong to one structure (batch index 0).
            batch.batch = torch.zeros(natoms, device=self.device, dtype=torch.int)
            total_energy = self.elementrefs["energy"](
                total_energy.reshape(1, -1), batch
            )

        # update LAMMPS variables
        lmp_data.energy = total_energy
        lmp_data.update_pair_forces_gpu(edge_forces.double())

    def compute_descriptors(self, lmp_data):
        """No-op."""
        pass

    def compute_gradients(self, lmp_data):
        """No-op."""
        pass
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import logging
|
|
3
|
+
import torch
|
|
4
|
+
import yaml
|
|
5
|
+
import math
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from fairchem.core.common.logger import Logger
|
|
9
|
+
from fairchem.core.common.registry import registry
|
|
10
|
+
|
|
11
|
+
import datetime
|
|
12
|
+
|
|
13
|
+
def parse_logs(update_dict):
    """Format one metrics dictionary as a single timestamped log line.

    The line starts with the current local time, followed by the optional
    ``epoch`` and ``step`` entries and a colon, then every remaining
    ``key value`` pair in dictionary order.

    Keys containing ``"mse"`` are reported as ``"rmse"`` with the square root
    of the value (for tensors and floats). Keys that already contain
    ``"rmse"`` are left alone, so they are neither renamed to ``"rrmse"`` nor
    square-rooted a second time.

    Args:
        update_dict: mapping of metric name to a tensor, a float, or any
            other value (which is printed as is).

    Returns:
        The formatted line, without a trailing newline.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    parts = [f"{stamp} "]
    if "epoch" in update_dict:
        parts.append(f"epoch {update_dict['epoch']:.1f}")
    if "step" in update_dict:
        parts.append(f" (step {int(update_dict['step'])})")
    parts.append(":")

    for key, val in update_dict.items():
        if key in ("epoch", "step"):
            continue

        # "rmse" contains "mse" as a substring, so exclude it explicitly.
        is_mse = "mse" in key and "rmse" not in key
        label = key.replace("mse", "rmse") if is_mse else key

        if torch.is_tensor(val):
            if is_mse:
                val = torch.sqrt(val)
            parts.append(f" {label} {val.item():.5f}")
        elif isinstance(val, float):
            if label == "lr":
                parts.append(f" {label} {val:.2e}")
            else:
                if is_mse:
                    val = math.sqrt(val)
                parts.append(f" {label} {val:.5f}")
        else:
            parts.append(f" {label} {val}")
    return "".join(parts)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@registry.register_logger("files")
class FilesLogger(Logger):
    """Logger that writes metrics and run information to plain text files.

    All files are placed in ``config["cmd"]["logs_dir"]``.
    """

    def __init__(self, config):
        super().__init__(config)

        logdir = self.config["cmd"]["logs_dir"]
        # One append-only log file per split.
        self.log_path = {
            split: os.path.join(logdir, f"{split}.log")
            for split in ("train", "val", "test")
        }

    def watch(self, model):
        """Gradient logging is not available for this logger; returns False."""
        logging.warning(
            "Model gradient logging to files is not supported."
        )
        return False

    def log(self, update_dict, step=None, split=""):
        """Append one formatted line for ``update_dict`` to the split's log.

        Args:
            update_dict: metrics to format with ``parse_logs``.
            step: accepted for interface compatibility; not used here.
            split: one of ``"train"``, ``"val"`` or ``"test"``.
        """
        assert split in ["train", "val", "test"], f"Split {split} is not supported"
        line = parse_logs(update_dict)
        # The context manager closes the file even if the write fails.
        with open(self.log_path[split], 'a') as outfile:
            outfile.write(line + "\n")

    def log_plots(self, plots):
        """No-op."""
        pass

    def mark_preempting(self):
        """No-op."""
        pass

    def log_model_training_info(self, model=None):
        """Write the config as YAML (and the model, if given) to a file.

        The file ``model_training_info.yml`` is overwritten on each call.
        """
        model_log_path = os.path.join(self.config["cmd"]["logs_dir"], "model_training_info.yml")
        with open(model_log_path, 'w') as outfile:
            outfile.write(yaml.dump(self.config, default_flow_style=False))
            outfile.write("\n")

            if model:
                outfile.write(str(model) + "\n")
                outfile.write(f"model num of parameters: {model.num_params}\n")

    def log_final_metrics(self, table, time=None):
        """Write ``table`` and, if given, the elapsed time to ``final_metrics.log``."""
        log_path = os.path.join(self.config["cmd"]["logs_dir"], "final_metrics.log")
        with open(log_path, 'w') as outfile:
            outfile.write(str(table) + "\n")
            if time:
                outfile.write(f"train() elapsed time: {time:.1f} sec ({time/3600.0:.1f} h)\n")

    def log_summary(self, summary_dict: dict[str, Any]) -> None:
        """Not supported; only emits a warning."""
        logging.warning("log_summary for Files not supported")

    def log_artifact(self, name: str, type: str, file_location: str) -> None:
        """Not supported; only emits a warning."""
        logging.warning("log_artifact for Files not supported")
|