EquFlash 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. equflash-0.0.1/GGNN/__init__.py +0 -0
  2. equflash-0.0.1/GGNN/_cli.py +113 -0
  3. equflash-0.0.1/GGNN/common/__init__.py +0 -0
  4. equflash-0.0.1/GGNN/common/calculator.py +206 -0
  5. equflash-0.0.1/GGNN/common/lammps_mliap/__init__.py +0 -0
  6. equflash-0.0.1/GGNN/common/lammps_mliap/lmp_mliap_wrapper.py +211 -0
  7. equflash-0.0.1/GGNN/common/logger.py +102 -0
  8. equflash-0.0.1/GGNN/common/statistics.py +151 -0
  9. equflash-0.0.1/GGNN/common/torchsim_model.py +184 -0
  10. equflash-0.0.1/GGNN/common/utils.py +524 -0
  11. equflash-0.0.1/GGNN/datasets/__init__.py +0 -0
  12. equflash-0.0.1/GGNN/datasets/ase_arrays_dataset.py +56 -0
  13. equflash-0.0.1/GGNN/datasets/lmdb_dataset.py +72 -0
  14. equflash-0.0.1/GGNN/datasets/samplers.py +213 -0
  15. equflash-0.0.1/GGNN/main.py +7 -0
  16. equflash-0.0.1/GGNN/model/EquFlash/__init__.py +3 -0
  17. equflash-0.0.1/GGNN/model/EquFlash/equflash.py +218 -0
  18. equflash-0.0.1/GGNN/model/EquFlash/equflash_comp.py +264 -0
  19. equflash-0.0.1/GGNN/model/EquFlash/equflash_pol.py +476 -0
  20. equflash-0.0.1/GGNN/model/EquFlash/nn/__init__.py +0 -0
  21. equflash-0.0.1/GGNN/model/EquFlash/nn/_ghost_exchange_base.py +79 -0
  22. equflash-0.0.1/GGNN/model/EquFlash/nn/activation.py +24 -0
  23. equflash-0.0.1/GGNN/model/EquFlash/nn/atomic_reduce.py +227 -0
  24. equflash-0.0.1/GGNN/model/EquFlash/nn/convolution.py +224 -0
  25. equflash-0.0.1/GGNN/model/EquFlash/nn/edge_embedding.py +231 -0
  26. equflash-0.0.1/GGNN/model/EquFlash/nn/equivariant_gate.py +56 -0
  27. equflash-0.0.1/GGNN/model/EquFlash/nn/feed_foward.py +300 -0
  28. equflash-0.0.1/GGNN/model/EquFlash/nn/force_output.py +230 -0
  29. equflash-0.0.1/GGNN/model/EquFlash/nn/interaction_blocks.py +103 -0
  30. equflash-0.0.1/GGNN/model/EquFlash/nn/node_embedding.py +43 -0
  31. equflash-0.0.1/GGNN/model/EquFlash/nn/scale.py +37 -0
  32. equflash-0.0.1/GGNN/model/EquFlash/nn/symmetric_contraction.py +32 -0
  33. equflash-0.0.1/GGNN/model/EquFlash/nn/util.py +17 -0
  34. equflash-0.0.1/GGNN/model/EquFlash/statistics.py +151 -0
  35. equflash-0.0.1/GGNN/model/EquFlashV2/__init__.py +3 -0
  36. equflash-0.0.1/GGNN/model/EquFlashV2/equflash_v2.py +263 -0
  37. equflash-0.0.1/GGNN/model/EquFlashV2/equflash_v2_comp.py +298 -0
  38. equflash-0.0.1/GGNN/model/EquFlashV2/nn/__init__.py +0 -0
  39. equflash-0.0.1/GGNN/model/EquFlashV2/nn/_ghost_exchange_base.py +79 -0
  40. equflash-0.0.1/GGNN/model/EquFlashV2/nn/activation.py +24 -0
  41. equflash-0.0.1/GGNN/model/EquFlashV2/nn/atomic_reduce.py +60 -0
  42. equflash-0.0.1/GGNN/model/EquFlashV2/nn/convolution.py +229 -0
  43. equflash-0.0.1/GGNN/model/EquFlashV2/nn/edge_embedding.py +171 -0
  44. equflash-0.0.1/GGNN/model/EquFlashV2/nn/force_output.py +87 -0
  45. equflash-0.0.1/GGNN/model/EquFlashV2/nn/interaction_blocks.py +222 -0
  46. equflash-0.0.1/GGNN/model/EquFlashV2/nn/linear.py +209 -0
  47. equflash-0.0.1/GGNN/model/EquFlashV2/nn/node_embedding.py +45 -0
  48. equflash-0.0.1/GGNN/model/EquFlashV2/nn/nonlinears.py +305 -0
  49. equflash-0.0.1/GGNN/model/EquFlashV2/nn/normalization.py +136 -0
  50. equflash-0.0.1/GGNN/model/EquFlashV2/nn/scale.py +37 -0
  51. equflash-0.0.1/GGNN/model/EquFlashV2/nn/skip.py +77 -0
  52. equflash-0.0.1/GGNN/model/EquFlashV2/nn/util.py +30 -0
  53. equflash-0.0.1/GGNN/model/EquFlashV2/statistics.py +151 -0
  54. equflash-0.0.1/GGNN/model/__init__.py +0 -0
  55. equflash-0.0.1/GGNN/modules/__init__.py +0 -0
  56. equflash-0.0.1/GGNN/modules/evaluator_ext.py +610 -0
  57. equflash-0.0.1/GGNN/modules/loss.py +17 -0
  58. equflash-0.0.1/GGNN/modules/loss_ext.py +805 -0
  59. equflash-0.0.1/GGNN/preprocessing/__init__.py +0 -0
  60. equflash-0.0.1/GGNN/preprocessing/atoms_arrays_to_graphs.py +72 -0
  61. equflash-0.0.1/GGNN/preprocessing/atoms_to_graphs.py +105 -0
  62. equflash-0.0.1/GGNN/scripts/__init__.py +0 -0
  63. equflash-0.0.1/GGNN/scripts/create_lmp_mliap_file.py +44 -0
  64. equflash-0.0.1/GGNN/scripts/matbench_submission/equflashv2-45M-oam.yml +85 -0
  65. equflash-0.0.1/GGNN/scripts/matbench_submission/install.sh +5 -0
  66. equflash-0.0.1/GGNN/scripts/matbench_submission/merge_results.py +35 -0
  67. equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/ml_relaxation.py +139 -0
  68. equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/optimizable.py +712 -0
  69. equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/optimizers/base_optimizer.py +82 -0
  70. equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/optimizers/fire.py +157 -0
  71. equflash-0.0.1/GGNN/scripts/matbench_submission/relaxation/optimizers/lbfgs.py +142 -0
  72. equflash-0.0.1/GGNN/scripts/matbench_submission/requirements.txt +35 -0
  73. equflash-0.0.1/GGNN/scripts/matbench_submission/test_equflash_discovery.py +204 -0
  74. equflash-0.0.1/GGNN/scripts/matbench_submission/test_equflash_kappa.py +367 -0
  75. equflash-0.0.1/GGNN/scripts/matbench_submission/thermal_conductivity.py +80 -0
  76. equflash-0.0.1/GGNN/trainer/__init__.py +13 -0
  77. equflash-0.0.1/GGNN/trainer/trainer.py +1754 -0
  78. equflash-0.0.1/PKG-INFO +350 -0
  79. equflash-0.0.1/README.md +312 -0
  80. equflash-0.0.1/pyproject.toml +77 -0
File without changes
@@ -0,0 +1,113 @@
1
+ """
2
+ This file is derived from fairchem-core.
3
+ Original source: https://github.com/facebookresearch/fairchem
4
+
5
+ Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ Licensed under the MIT License. See LICENSE for details.
7
+
8
+ Modifications copyright (c) 2026 Samsung Electronics.
9
+ Licensed under CC BY-NC-SA 4.0. See LICENSE for details.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import copy
15
+ import logging
16
+ from typing import TYPE_CHECKING
17
+
18
+ from submitit import AutoExecutor
19
+ from submitit.helpers import Checkpointable, DelayedSubmission
20
+ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
21
+
22
+ from fairchem.core.common import distutils
23
+ from fairchem.core.common.flags import flags
24
+ from fairchem.core.common.utils import (
25
+ build_config,
26
+ create_grid,
27
+ save_experiment_log,
28
+ setup_logging,
29
+ )
30
+ from GGNN.common.utils import new_trainer_context
31
+
32
+ if TYPE_CHECKING:
33
+ import argparse
34
+
35
+
36
class Runner(Checkpointable):
    """Submitit-checkpointable job that builds a trainer context and runs its task."""

    def __init__(self) -> None:
        # Filled in by __call__; remains None until a job has been started.
        self.config = None

    def __call__(self, config: dict) -> None:
        """Open a trainer context for ``config``, set the task up and run it."""
        with new_trainer_context(config=config) as context:
            self.config = context.config
            self.task = context.task
            self.trainer = context.trainer
            self.task.setup(self.trainer)
            self.task.run()

    def checkpoint(self, *args, **kwargs):
        """Save the training state and return a resubmission of a fresh Runner.

        The saved checkpoint path and the trainer's timestamp id are written
        into the config so the resubmitted job receives them.
        """
        trainer = self.trainer
        trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
        self.config["checkpoint"] = self.task.chkpt_path
        self.config["timestamp_id"] = trainer.timestamp_id
        if trainer.logger is not None:
            trainer.logger.mark_preempting()
        logging.info(
            f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}'
        )
        return DelayedSubmission(Runner(), self.config)
59
+
60
+
61
def runner_wrapper(config: dict):
    """Create a fresh ``Runner`` and execute it with ``config``."""
    runner = Runner()
    runner(config)
63
+
64
+
65
def main(
    args: argparse.Namespace | None = None, override_args: list[str] | None = None
):
    """Run the main fairchem program.

    When ``args`` is None the command line is parsed with the fairchem flag
    parser. With ``args.submit`` set, one job per config is submitted through
    submitit; otherwise the single config is run in the current process.
    """
    setup_logging()

    if args is None:
        args, override_args = flags.get_parser().parse_known_args()

    assert (
        args.num_gpus > 0
    ), "num_gpus is used to determine number ranks, so it must be at least 1"
    config = build_config(args, override_args)

    if not args.submit:
        # Run locally on a single node, n-processes
        runner_wrapper(config)
        return

    # Run on cluster
    extra_slurm_params = config.get("slurm", None)  # additional slurm arguments
    if args.sweep_yml:
        configs = create_grid(config, args.sweep_yml)
    else:
        configs = [config]

    logging.info(f"Submitting {len(configs)} jobs")
    executor = AutoExecutor(folder=args.logdir / "%j", slurm_max_num_timeout=3)
    executor.update_parameters(
        name=args.identifier,
        mem_gb=args.slurm_mem,
        timeout_min=args.slurm_timeout * 60,
        slurm_partition=args.slurm_partition,
        gpus_per_node=args.num_gpus,
        cpus_per_task=(config["optim"]["num_workers"] + 1),
        tasks_per_node=args.num_gpus,
        nodes=args.num_nodes,
        slurm_additional_parameters=extra_slurm_params,
        slurm_qos=args.slurm_qos,
        slurm_account=args.slurm_account,
    )
    # Record the effective executor settings in every config that is submitted.
    for job_config in configs:
        job_config["slurm"] = copy.deepcopy(executor.parameters)
        job_config["slurm"]["folder"] = str(executor.folder)

    jobs = executor.map_array(Runner(), configs)
    logging.info(f"Submitted jobs: {', '.join([job.job_id for job in jobs])}")
    log_file = save_experiment_log(args, jobs, configs)
    logging.info(f"Experiment log saved to: {log_file}")
109
+
110
+
111
+
112
+ if __name__ == "__main__":
113
+ main()
File without changes
@@ -0,0 +1,206 @@
1
+ """
2
+ This file is derived from fairchem-core.
3
+ Original source: https://github.com/facebookresearch/fairchem
4
+
5
+ Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ Licensed under the MIT License. See LICENSE for details.
7
+
8
+ Modifications copyright (c) 2026 Samsung Electronics.
9
+ Licensed under CC BY-NC-SA 4.0. See LICENSE for details.
10
+ """
11
+
12
+ from fairchem.core.common.relaxation.ase_utils import OCPCalculator
13
+ from ..trainer.trainer import Trainer
14
+
15
+ import copy
16
+ import logging
17
+ from types import MappingProxyType
18
+ from typing import TYPE_CHECKING
19
+
20
+ import torch
21
+ from ase import Atoms
22
+ from ase.calculators.calculator import Calculator
23
+ from ase.calculators.singlepoint import SinglePointCalculator
24
+ from ase.constraints import FixAtoms
25
+ from ase.geometry import wrap_positions
26
+
27
+ from fairchem.core.common.registry import registry
28
+ from fairchem.core.common.utils import (
29
+ load_config,
30
+ setup_imports,
31
+ setup_logging,
32
+ update_config,
33
+ )
34
+ from fairchem.core.datasets import data_list_collater
35
+ from fairchem.core.models.model_registry import model_name_to_local_file
36
+ from fairchem.core.preprocessing import AtomsToGraphs
37
+ from collections import OrderedDict
38
+
39
+
40
def convert_compiled_ckpt(ckpt):
    """Rewrite a checkpoint saved from a ``*_comp`` model, in place.

    For a model name ending in ``_comp`` the ``_comp`` text is removed from
    the name, the state dict is rebuilt from the ``module.model*`` entries
    (skipping any key containing ``z_to_onehot_tensor``) with ``.model.``
    collapsed to ``.``, and the ``task`` config entry is dropped. Other
    checkpoints are returned untouched.

    Returns the same ``ckpt`` object that was passed in.
    """
    # NOTE(review): the nesting below was reconstructed from a flattened dump;
    # confirm that dropping "task" belongs inside the "_comp" branch.
    model_cfg = ckpt["config"]["model"]
    if not model_cfg["name"].endswith("_comp"):
        return ckpt

    model_cfg["name"] = model_cfg["name"].replace("_comp", "")

    converted = OrderedDict()
    for name, tensor in ckpt["state_dict"].items():
        if not name.startswith("module.model"):
            continue
        if "z_to_onehot_tensor" in name:
            continue
        converted[name.replace(".model.", ".")] = tensor
    ckpt["state_dict"] = converted

    ckpt["config"].pop("task", None)
    return ckpt
53
+
54
+
55
+ if TYPE_CHECKING:
56
+ from pathlib import Path
57
+
58
+ from torch_geometric.data import Batch
59
+
60
+
61
class UCalculator(OCPCalculator):
    """ASE based calculator using an OCP model"""

    def __init__(
        self,
        config_yml: str | None = None,
        checkpoint_path: str | None = None,
        model_name: str | None = None,
        local_cache: str | None = None,
        trainer: str | None = None,
        cpu: bool = True,
        seed: int | None = None,
        only_output: list[str] | None = None,
    ) -> None:
        """Build an inference-only trainer and the atoms-to-graph converter.

        Args:
            config_yml: Path to a config file, or an already loaded config
                mapping. When omitted, the config stored in the checkpoint
                is used.
            checkpoint_path: Checkpoint to load. Mutually exclusive with
                ``model_name``.
            model_name: Name resolved to a local checkpoint file through
                ``model_name_to_local_file``; requires ``local_cache``.
            local_cache: Cache location passed along with ``model_name``.
            trainer: Value stored under ``config["trainer"]``; falls back to
                the config's own value, then ``"default"``.
            cpu: Forwarded to the trainer.
            seed: Seed to set on the trainer; falls back to the trainer's
                ``config["cmd"]["seed"]``.
            only_output: If given, every model output not in this list is
                removed from the config before the trainer is built.

        Raises:
            RuntimeError: If both ``model_name`` and ``checkpoint_path`` are given.
            NotImplementedError: If ``model_name`` is given without ``local_cache``.
            ValueError: If the config file reports conflicting duplicate parameters.
        """
        setup_imports()
        setup_logging()
        Calculator.__init__(self)

        if model_name is not None:
            if checkpoint_path is not None:
                raise RuntimeError(
                    "model_name and checkpoint_path were both specified, please use only one at a time"
                )
            if local_cache is None:
                raise NotImplementedError(
                    "Local cache must be set when specifying a model name"
                )
            checkpoint_path = model_name_to_local_file(
                model_name=model_name, local_cache=local_cache
            )

        # Either the config path or the checkpoint path needs to be provided
        assert config_yml or checkpoint_path is not None

        checkpoint = None
        if config_yml is None:
            # Loads the config from the checkpoint directly (always on CPU).
            checkpoint = torch.load(
                checkpoint_path, map_location=torch.device("cpu"), weights_only=False
            )
            checkpoint = convert_compiled_ckpt(checkpoint)
            config = checkpoint["config"]
        else:
            if isinstance(config_yml, str):
                config, duplicates_warning, duplicates_error = load_config(config_yml)
                if len(duplicates_warning) > 0:
                    logging.warning(
                        f"Overwritten config parameters from included configs "
                        f"(non-included parameters take precedence): {duplicates_warning}"
                    )
                if len(duplicates_error) > 0:
                    raise ValueError(
                        f"Conflicting (duplicate) parameters in simultaneously "
                        f"included configs: {duplicates_error}"
                    )
            else:
                config = config_yml

            # NOTE(review): nesting reconstructed from a flattened dump; this
            # prefix strip is assumed to apply to both str and mapping configs.
            model_cfg = config["model"]
            if model_cfg["name"].startswith("compiled"):
                model_cfg["name"] = model_cfg["name"].replace("compiled_", "")

            # Only keeps the train data that might have normalizer values
            dataset_cfg = config["dataset"]
            if isinstance(dataset_cfg, list):
                config["dataset"] = dataset_cfg[0]
            elif isinstance(dataset_cfg, dict):
                config["dataset"] = dataset_cfg.get("train", None)

        config["trainer"] = (
            trainer if trainer is not None else config.get("trainer", "default")
        )

        # Older layout: "model" holds the name and "model_attributes" the settings.
        if "model_attributes" in config:
            config["model_attributes"]["name"] = config.pop("model")
            config["model"] = config["model_attributes"]

        # Calculate the edge indices on the fly
        config["model"]["otf_graph"] = True

        # Backwards compatibility with OCP v<2.0
        config = update_config(config)

        # Keep a full copy (with the checkpoint path) before trimming the
        # working config that is handed to the trainer.
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = str(checkpoint_path)
        del config["dataset"]["src"]

        # some models that are published have configs that include tasks
        # which are not output by the model
        if only_output is not None:
            assert isinstance(
                only_output, list
            ), "only output must be a list of targets to output"
            for key in only_output:
                assert (
                    key in config["outputs"]
                ), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}"
            for key in set(config["outputs"].keys()) - set(only_output):
                config["outputs"].pop(key)

        self.trainer = Trainer(
            task=config.get("task", {}),
            model=config["model"],
            dataset=[config["dataset"]],
            outputs=config["outputs"],
            loss_functions=config["loss_functions"],
            evaluation_metrics=config["evaluation_metrics"],
            optimizer=config["optim"],
            identifier="",
            slurm=config.get("slurm", {}),
            local_rank=config.get("local_rank", 0),
            is_debug=config.get("is_debug", True),
            cpu=cpu,
            amp=config.get("amp", False),
            inference_only=True,
        )

        if checkpoint_path is not None:
            self.load_checkpoint(checkpoint_path=checkpoint_path, checkpoint=checkpoint)

        if seed is None:
            seed = self.trainer.config["cmd"]["seed"]
        if seed is None:
            logging.warning(
                "No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run"
            )
        else:
            self.trainer.set_seed(seed)

        # Temporarily implemented: prefer the config flag over the model attribute.
        r_edges = (
            not config["model"]["otf_graph"]
            if "otf_graph" in config["model"]
            else not self.trainer.model.otf_graph
        )

        self.a2g = AtomsToGraphs(
            r_energy=False,
            r_forces=False,
            r_distances=False,
            r_pbc=True,
            r_edges=r_edges,
        )
        self.implemented_properties = list(self.config["outputs"].keys())
File without changes
@@ -0,0 +1,211 @@
1
+ # This file is a part of the `nequip` package. Please see LICENSE and README at the root for information on using it.
2
+
3
+ import torch
4
+ import os
5
+
6
+
7
+ try:
8
+ from lammps.mliap.mliap_unified_abc import MLIAPUnified
9
+ except ModuleNotFoundError:
10
+ raise ImportError(
11
+ "LAMMPS ML-IAP has to be installed in the Python environment for NequIP's ML-IAP integration. "
12
+ "See https://nequip.readthedocs.io/en/latest/integrations/lammps/mliap.html for installation instructions."
13
+ )
14
+
15
+ from typing import List
16
+ from GGNN.common.calculator import UCalculator,convert_compiled_ckpt
17
+ from GGNN.model.SevenNet.sevennet import SevenNet, convert_compiled_to_original_sevennet
18
+
19
+ from GGNN.model.EquFlash import EquFlash
20
+ from fairchem.core.common.utils import (
21
+ match_state_dict,
22
+ )
23
+ from fairchem.core.common.registry import registry
24
+ from fairchem.core.modules.normalization.element_references import (
25
+ create_element_references,
26
+ )
27
+ from fairchem.core.modules.normalization.normalizer import (
28
+ create_normalizer,
29
+ )
30
+
31
+ from torch_geometric.data import Batch
32
+
33
+
34
class GGNNLAMMPSMLIAPWrapper(MLIAPUnified):
    """LAMMPS ML-IAP unified interface around a registry model from a checkpoint.

    The checkpoint is read on construction; the model, normalizers and element
    references are built lazily on the first ``compute_forces`` call, once the
    target device can be inferred from the LAMMPS data object.
    """

    def __init__(
        self,
        ckpt_path: str,
        **kwargs,
    ):
        """Load the checkpoint at ``ckpt_path`` on CPU; ``kwargs`` are ignored."""
        super().__init__()
        self.ckpt_path = ckpt_path
        self.ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        self.model = None
        self.device = None

        # to placate the interface
        self.nparams = 1
        self.ndescriptors = 1

        # === set model-dependent params ===
        # NOTE(review): hard-coded; the upstream code this derives from used
        # half of the model cutoff radius here -- confirm 3.0 suits the model.
        self.rcutfac = 3.0

        # Import the submodule explicitly rather than relying on `import ase`
        # making `ase.data` available as a side effect.
        from ase.data import chemical_symbols

        self.element_types = chemical_symbols

    def _load_model(self, lmp_data):
        """Instantiate the model, normalizers and element references once."""
        self.device = (
            "cuda" if "kokkos" in lmp_data.__class__.__module__.lower() else "cpu"
        )
        self.ckpt = convert_compiled_ckpt(self.ckpt)
        model_config = self.ckpt["config"]["model"]
        model_name = model_config.pop("name")
        model_config.setdefault("conv_kwargs", {})["use_lammps_mliap"] = True

        self.model = registry.get_model_class(model_name)(
            model_config, lammps_mliap=True
        ).to(self.device)
        new_dict = match_state_dict(self.model.state_dict(), self.ckpt["state_dict"])
        self.model.load_state_dict(new_dict, strict=False)

        # Parse the MPI rank as an integer; never eval() environment content.
        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
        if rank == 0:
            print(self.model)

        self.normalizers = {}
        for key, state_dict in self.ckpt.get("normalizers", {}).items():
            # Convert old normalizer keys to new target keys
            if key == "target":
                target_key = "energy"
            elif key == "grad_target":
                target_key = "forces"
            else:
                target_key = key

            if target_key not in self.normalizers:
                self.normalizers[target_key] = create_normalizer(
                    state_dict=state_dict
                )
            self.normalizers[target_key].to(self.device)

        # Keys of a dict are unique, so every entry is created exactly once.
        self.elementrefs = {}
        for key, state_dict in self.ckpt.get("elementrefs", {}).items():
            self.elementrefs[key] = create_element_references(state_dict=state_dict)
            self.elementrefs[key].to(self.device)

        # Forces are obtained below by differentiating w.r.t. edge vectors.
        self.model.regress_forces = False
        self.model.regress_stress = False

    def compute_forces(self, lmp_data):
        """Evaluate the model on ``lmp_data`` and write energy and pair forces back.

        Returns early, leaving ``lmp_data`` untouched, when there are no local
        atoms or at most one pair.
        """
        # === lazily load model ===
        if self.model is None:
            self._load_model(lmp_data)

        if lmp_data.nlocal == 0 or lmp_data.npairs <= 1:
            return

        # === create input data ===

        # NOTE
        # This LAMMPS ML-IAP integration introduces a new dimension of having `num_local` vs `num_local + num_ghost` number of nodes.
        # There are three crucial dimensions to be aware of `num_edges`, `num_local`, `num_local + num_ghost`.
        # The following input tensors have the following shapes.
        # - `edge_vectors`: (num_edges, 3)
        # - `edge_idxs`: (2, num_edges)
        # - `atom_types`: (num_local + num_ghost)

        # This LAMMPS ML-IAP wrapper can handle output `atomic_energy` having either shape `num_local` or `num_local + num_ghost` based on the (uncompiled) size check.

        # Models can perform optimizations based on an understanding of when ghost atoms matter or not and are responsible for carefully handling internal shape logic.

        # TODO: we have yet to exploit per-edge-type cutoffs by pruning the edge vectors and neighborlist
        # make sure edge vectors `requires_grad`
        edge_vectors = (
            torch.as_tensor(lmp_data.rij, dtype=torch.float64).to(self.device).float()
        )
        edge_vectors.requires_grad_(True)
        edge_index = torch.vstack(
            [
                torch.as_tensor(lmp_data.pair_i, dtype=torch.int64).to(self.device),
                torch.as_tensor(lmp_data.pair_j, dtype=torch.int64).to(self.device),
            ],
        )

        atomic_type = torch.as_tensor(lmp_data.elems, dtype=torch.int64).to(self.device)

        # [number of local atoms, number of ghost atoms]
        num_atoms = torch.tensor(
            [lmp_data.nlocal, lmp_data.ntotal - lmp_data.nlocal], dtype=torch.int64
        ).to(self.device)
        natoms = torch.tensor(lmp_data.nlocal, dtype=torch.int64).to(self.device)

        # === run model ===
        # run model and backwards for edge forces
        batch = Batch(
            edge_vec=edge_vectors,
            atomic_numbers=(atomic_type[: lmp_data.nlocal].long()),
            edge_index=edge_index,
            num_atoms=num_atoms,
            natoms=natoms,
            lmp_data=lmp_data,
        )
        out = self.model(batch)
        # correct sign convention for consistency with LAMMPS
        edge_forces = torch.autograd.grad(
            [out["energy"].sum()],
            [edge_vectors],
        )[0]

        # === pass outputs to LAMMPS ===
        # Flatten rather than squeeze so a single-atom output stays 1-D and
        # the size check below remains valid.
        atomic_energies = out["atomic_energy"].reshape(-1)

        # shape-dependent control flow, but should be outside of compiled model
        if atomic_energies.size(0) != lmp_data.nlocal:
            # handle ghosts: keep only the leading local-atom entries
            atomic_energies = torch.narrow(atomic_energies, 0, 0, lmp_data.nlocal)
        total_energy = torch.sum(atomic_energies)

        if "forces" in self.normalizers:
            edge_forces = self.normalizers["forces"](edge_forces)
        if "energy" in self.normalizers:
            total_energy = self.normalizers["energy"](total_energy)
        if "energy" in self.elementrefs:
            # All local atoms belong to a single system (batch index 0).
            batch.batch = torch.zeros(natoms, device=self.device, dtype=torch.int)
            total_energy = self.elementrefs["energy"](
                total_energy.reshape(1, -1), batch
            )

        # update LAMMPS variables
        lmp_data.energy = total_energy
        lmp_data.update_pair_forces_gpu(edge_forces.double())

    def compute_descriptors(self, lmp_data):
        """No-op; present to satisfy the ML-IAP interface."""
        pass

    def compute_gradients(self, lmp_data):
        """No-op; present to satisfy the ML-IAP interface."""
        pass
@@ -0,0 +1,102 @@
1
+ import os
2
+ import logging
3
+ import torch
4
+ import yaml
5
+ import math
6
+ from typing import Any
7
+
8
+ from fairchem.core.common.logger import Logger
9
+ from fairchem.core.common.registry import registry
10
+
11
+ import datetime
12
+
13
def parse_logs(update_dict):
    """Format a metrics dictionary as one timestamped log line.

    The line starts with the local time, followed by the optional ``epoch``
    and ``step`` entries and a colon, then every remaining ``key value`` pair.
    Keys containing ``mse`` (but not already ``rmse``) are renamed to ``rmse``
    and their tensor or float value is replaced by its square root.

    Args:
        update_dict: Mapping from metric name to a tensor, float or any other
            printable value.

    Returns:
        The formatted line, without a trailing newline.
    """
    parts = [datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), " "]
    if "epoch" in update_dict:
        parts.append(f"epoch {update_dict['epoch']:.1f}")
    if "step" in update_dict:
        parts.append(f" (step {int(update_dict['step'])})")
    # The prefix always contains the timestamp, so the separator is unconditional.
    parts.append(":")

    for key, val in update_dict.items():
        if key in ("epoch", "step"):
            continue

        # Keys that are already RMSE also contain "mse"; leave those untouched
        # so they are not renamed to "rrmse" and square-rooted a second time.
        is_mse = "mse" in key and "rmse" not in key
        if is_mse:
            key = key.replace("mse", "rmse")

        if torch.is_tensor(val):
            if is_mse:
                val = torch.sqrt(val)
            parts.append(f" {key} {val.item():.5f}")
        elif isinstance(val, float):
            if key == "lr":
                # Learning rates are small; print them in scientific notation.
                parts.append(f" {key} {val:.2e}")
            else:
                if is_mse:
                    val = math.sqrt(val)
                parts.append(f" {key} {val:.5f}")
        else:
            parts.append(f" {key} {val}")
    return "".join(parts)
48
+
49
+
50
@registry.register_logger("files")
class FilesLogger(Logger):
    """Logger that writes metrics to plain-text files in the run's logs directory.

    One append-only file is kept per split (``train.log``, ``val.log``,
    ``test.log``) under ``config["cmd"]["logs_dir"]``.
    """

    def __init__(self, config):
        super().__init__(config)

        logdir = self.config["cmd"]["logs_dir"]
        self.log_path = {
            "train": os.path.join(logdir, "train.log"),
            "val": os.path.join(logdir, "val.log"),
            "test": os.path.join(logdir, "test.log"),
        }

    def watch(self, model):
        """Gradient watching is unsupported; warn and return False."""
        logging.warning(
            "Model gradient logging to files is not supported."
        )
        return False

    def log(self, update_dict, step=None, split=""):
        """Append one formatted line for ``update_dict`` to the split's log file.

        ``step`` is accepted for interface compatibility and is not used here.
        """
        assert split in ["train", "val", "test"], f"Split {split} is not supported"
        # Format first so a formatting error cannot leave a file handle open.
        line = parse_logs(update_dict)
        with open(self.log_path[split], 'a') as outfile:
            outfile.write(line + "\n")

    def log_plots(self, plots):
        """No-op: plots are not written to files."""
        pass

    def mark_preempting(self):
        """No-op for file logging."""
        pass

    def log_model_training_info(self, model=None):
        """Write the config (as YAML) and, if given, the model summary to a file."""
        model_log_path = os.path.join(self.config["cmd"]["logs_dir"], "model_training_info.yml")
        with open(model_log_path, 'w') as outfile:
            outfile.write(yaml.dump(self.config, default_flow_style=False))
            outfile.write("\n")

            if model:
                outfile.write(str(model)+ "\n")
                outfile.write(f"model num of parameters: {model.num_params}\n")

    def log_final_metrics(self, table, time=None):
        """Write the final metrics table and, if given, the elapsed training time."""
        log_path = os.path.join(self.config["cmd"]["logs_dir"], "final_metrics.log")
        with open(log_path, 'w') as outfile:
            outfile.write(str(table)+"\n")
            if time:
                outfile.write(f"train() elapsed time: {time:.1f} sec ({time/3600.0:.1f} h)\n")

    def log_summary(self, summary_dict: dict[str, Any]) -> None:
        """Summaries are unsupported; only emits a warning."""
        logging.warning("log_summary for Files not supported")

    def log_artifact(self, name: str, type: str, file_location: str) -> None:
        """Artifacts are unsupported; only emits a warning."""
        logging.warning("log_artifact for Files not supported")