atst-tools 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. atst_tools/__init__.py +51 -0
  2. atst_tools/calculators/__init__.py +0 -0
  3. atst_tools/calculators/abacuslite_backend.py +81 -0
  4. atst_tools/calculators/base.py +10 -0
  5. atst_tools/calculators/dp.py +112 -0
  6. atst_tools/calculators/factory.py +126 -0
  7. atst_tools/external/ASE_interface/__init__.py +1 -0
  8. atst_tools/external/ASE_interface/abacuslite/__init__.py +13 -0
  9. atst_tools/external/ASE_interface/abacuslite/core.py +655 -0
  10. atst_tools/external/ASE_interface/abacuslite/io/__init__.py +0 -0
  11. atst_tools/external/ASE_interface/abacuslite/io/generalio.py +830 -0
  12. atst_tools/external/ASE_interface/abacuslite/io/latestio.py +677 -0
  13. atst_tools/external/ASE_interface/abacuslite/io/legacyio.py +1193 -0
  14. atst_tools/external/ASE_interface/abacuslite/utils/__init__.py +0 -0
  15. atst_tools/external/ASE_interface/abacuslite/utils/ksampling.py +260 -0
  16. atst_tools/external/ASE_interface/examples/bandstructure.py +82 -0
  17. atst_tools/external/ASE_interface/examples/cellrelax.py +76 -0
  18. atst_tools/external/ASE_interface/examples/constraintmd.py +101 -0
  19. atst_tools/external/ASE_interface/examples/dos.py +65 -0
  20. atst_tools/external/ASE_interface/examples/md.py +72 -0
  21. atst_tools/external/ASE_interface/examples/metadynamics.py +146 -0
  22. atst_tools/external/ASE_interface/examples/neb.py +111 -0
  23. atst_tools/external/ASE_interface/examples/relax.py +67 -0
  24. atst_tools/external/ASE_interface/examples/scf.py +56 -0
  25. atst_tools/external/ASE_interface/examples/soc.py +63 -0
  26. atst_tools/external/ASE_interface/tests/band.py +61 -0
  27. atst_tools/external/ASE_interface/tests/magnetic.py +129 -0
  28. atst_tools/external/ASE_interface/tests/md.py +50 -0
  29. atst_tools/external/ASE_interface/tests/relax.py +47 -0
  30. atst_tools/external/ASE_interface/tests/scf.py +39 -0
  31. atst_tools/external/__init__.py +1 -0
  32. atst_tools/mep/__init__.py +4 -0
  33. atst_tools/mep/autoneb.py +535 -0
  34. atst_tools/mep/dimer.py +241 -0
  35. atst_tools/mep/neb.py +169 -0
  36. atst_tools/mep/sella.py +105 -0
  37. atst_tools/scripts/__init__.py +0 -0
  38. atst_tools/scripts/cli.py +543 -0
  39. atst_tools/scripts/main.py +629 -0
  40. atst_tools/scripts/neb_make.py +65 -0
  41. atst_tools/scripts/neb_post.py +46 -0
  42. atst_tools/utils/__init__.py +0 -0
  43. atst_tools/utils/abacus_io.py +231 -0
  44. atst_tools/utils/analysis.py +57 -0
  45. atst_tools/utils/config.py +89 -0
  46. atst_tools/utils/config_docs.py +189 -0
  47. atst_tools/utils/config_schema.py +522 -0
  48. atst_tools/utils/idpp.py +422 -0
  49. atst_tools/utils/io.py +61 -0
  50. atst_tools/utils/neb_endpoints.py +125 -0
  51. atst_tools/utils/post.py +134 -0
  52. atst_tools/utils/restart_helpers.py +157 -0
  53. atst_tools/utils/thermochemistry.py +109 -0
  54. atst_tools/workflows/__init__.py +0 -0
  55. atst_tools/workflows/d2s.py +299 -0
  56. atst_tools/workflows/irc.py +163 -0
  57. atst_tools/workflows/relax.py +116 -0
  58. atst_tools/workflows/vibration.py +136 -0
  59. atst_tools-2.0.0.dist-info/METADATA +304 -0
  60. atst_tools-2.0.0.dist-info/RECORD +63 -0
  61. atst_tools-2.0.0.dist-info/WHEEL +5 -0
  62. atst_tools-2.0.0.dist-info/entry_points.txt +2 -0
  63. atst_tools-2.0.0.dist-info/top_level.txt +1 -0
atst_tools/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ """ATST-Tools package metadata helpers."""
2
+
3
+ from importlib.metadata import PackageNotFoundError, version as _metadata_version
4
+ from pathlib import Path
5
+
6
+
7
+ def _source_tree_version() -> str | None:
8
+ """Return the project version from a local source-tree ``pyproject.toml``."""
9
+ for parent in Path(__file__).resolve().parents:
10
+ pyproject = parent / "pyproject.toml"
11
+ if not pyproject.exists():
12
+ continue
13
+
14
+ in_project = False
15
+ for raw_line in pyproject.read_text(encoding="utf-8").splitlines():
16
+ line = raw_line.strip()
17
+ if line == "[project]":
18
+ in_project = True
19
+ continue
20
+ if in_project and line.startswith("["):
21
+ return None
22
+ if in_project and line.startswith("version"):
23
+ _, value = line.split("=", 1)
24
+ return value.strip().strip('"').strip("'")
25
+ return None
26
+
27
+
28
def package_version() -> str:
    """Return the version string of the ATST-Tools package.

    ``pyproject.toml`` is the single source of the version. When the package
    runs from a source tree the file is read directly; when it runs from an
    installation, the distribution metadata generated from the same project
    version is used instead.

    Returns:
        The version found in the source tree if there is one, otherwise the
        installed distribution version, otherwise ``"unknown"``.
    """
    source_version = _source_tree_version()
    if source_version:
        return source_version
    # No (non-empty) source-tree version: fall back to installed metadata.
    return _installed_version()
40
+
41
+
42
def _installed_version() -> str:
    """Return the installed ``atst-tools`` distribution version.

    Returns:
        The version reported by the distribution metadata, or ``"unknown"``
        when no ``atst-tools`` distribution is installed.
    """
    try:
        installed = _metadata_version("atst-tools")
    except PackageNotFoundError:
        installed = "unknown"
    return installed
47
+
48
+
49
# Computed once, when the package is imported: the source-tree version if one
# is found, otherwise the installed distribution version, otherwise "unknown".
__version__ = package_version()

# Names exported by ``from atst_tools import *``.
__all__ = ["__version__", "package_version"]
File without changes
@@ -0,0 +1,81 @@
1
+ """Resolve the ABACUS ASE backend used by ATST-Tools."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ import shlex
7
+ from typing import Literal
8
+
9
+ from ase.calculators.genericfileio import read_stdout
10
+
11
+
12
# Origin of the abacuslite classes in use: "external" when the top-level
# ``abacuslite`` package could be imported, "vendored" when the copy under
# ``atst_tools.external.ASE_interface`` was used instead.
BackendSource = Literal["external", "vendored"]
13
+
14
+
15
def _load_abacuslite_backend():
    """Import the abacuslite calculator classes, preferring an external install.

    The top-level ``abacuslite`` package is tried first. If that import fails
    with ``ImportError``, the copy bundled under
    ``atst_tools.external.ASE_interface`` is imported instead.

    Returns:
        A ``(Abacus, AbacusProfile, source)`` tuple, where ``source`` is
        ``"external"`` or ``"vendored"`` depending on which import succeeded.
    """
    try:
        from abacuslite import Abacus, AbacusProfile
    except ImportError:
        from atst_tools.external.ASE_interface.abacuslite import Abacus, AbacusProfile

        source = "vendored"
    else:
        source = "external"
    return Abacus, AbacusProfile, source
24
+
25
+
26
# Resolved when this module is imported; BACKEND_SOURCE records which of the
# two abacuslite imports succeeded.
Abacus, AbacusProfile, BACKEND_SOURCE = _load_abacuslite_backend()
27
+
28
+
29
+ def _default_version_command(command: str) -> list[str]:
30
+ """Return the default bare ABACUS version-probe command."""
31
+ parts = shlex.split(command)
32
+ if not parts:
33
+ return ["abacus", "--version"]
34
+ executable = parts[-1] if parts[0] in {"mpirun", "mpiexec", "srun"} else parts[0]
35
+ return [executable, "--version"]
36
+
37
+
38
class ATSTAbacusProfile(AbacusProfile):
    """ABACUS profile whose version probe is controlled by ATST-Tools.

    Calculations may use an MPI-wrapped run command, but asking for the
    version is only a lightweight environment check, so by default it calls
    the bare executable.
    """

    def __init__(self, *args, version_command: str | None = None, **kwargs):
        """Create the profile and remember the optional version-probe command.

        Args:
            *args: Positional arguments forwarded to abacuslite's profile.
            version_command: Full command line to run for version probing;
                when omitted, a default is derived from the run command.
            **kwargs: Keyword arguments forwarded to abacuslite's profile.
        """
        super().__init__(*args, **kwargs)
        self.version_command = version_command

    @staticmethod
    def parse_version(stdout: str) -> str:
        """Extract the ABACUS version from legacy or LTS banner output.

        Args:
            stdout: Text printed by the version-probe command.

        Returns:
            The first captured version token.

        Raises:
            RuntimeError: If neither banner format is found in ``stdout``.
        """
        # Patterns are tried lazily, in this order; the first hit wins.
        attempts = (
            re.search(pattern, stdout)
            for pattern in (r"ABACUS version (\S+)", r"ABACUS\s+(v\S+)")
        )
        found = next((hit for hit in attempts if hit is not None), None)
        if found is None:
            raise RuntimeError(f"Could not parse ABACUS version from output: {stdout!r}")
        return found.group(1)

    def version(self) -> str:
        """Run the version probe and return the parsed ABACUS version."""
        if self.version_command:
            probe = shlex.split(self.version_command)
        else:
            probe = _default_version_command(self.command)
        return self.parse_version(read_stdout(probe))
73
+
74
+
75
# Public names of this module: the resolved abacuslite classes, the ATST
# profile subclass, and the backend-origin marker with its type alias.
__all__ = [
    "Abacus",
    "AbacusProfile",
    "ATSTAbacusProfile",
    "BACKEND_SOURCE",
    "BackendSource",
]
@@ -0,0 +1,10 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict
3
+
4
class BaseCalculator(ABC):
    """Interface that every ATST calculator wrapper has to implement."""

    @abstractmethod
    def get_calculator(self, **kwargs) -> Any:
        """Build and return a configured ASE calculator instance.

        Args:
            **kwargs: Implementation-specific construction options.

        Returns:
            The calculator object; concrete subclasses define its type.
        """
@@ -0,0 +1,112 @@
1
+ """DeepMD-kit ASE calculator adapter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from typing import Any, Dict, Hashable
8
+
9
+ from ase.calculators.calculator import Calculator
10
+
11
+
12
def is_dp_calculator(name: str) -> bool:
    """Tell whether ``name`` selects the DeepMD-kit adapter.

    The comparison ignores case; the accepted names are ``dp`` and ``deepmd``.
    """
    lowered = name.lower()
    return lowered == "dp" or lowered == "deepmd"
15
+
16
+
17
def dp_section(config: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the DeepMD-kit settings from any supported config layout.

    Layouts are checked in this order:

    1. ``config["calculator"]["dp"]``
    2. ``config["calculator"]["deepmd"]`` when ``calculator.name`` is ``"deepmd"``
    3. ``config["parameters"]``

    Args:
        config: ATST-Tools configuration dictionary.

    Returns:
        A new dictionary with the settings; empty when no layout matches or
        the matched nested section is empty or ``None``.
    """
    calculator_cfg = config.get("calculator", {})

    nested_key = None
    if isinstance(calculator_cfg, dict):
        if "dp" in calculator_cfg:
            nested_key = "dp"
        elif calculator_cfg.get("name") == "deepmd":
            nested_key = "deepmd"

    if nested_key is not None:
        return dict(calculator_cfg.get(nested_key) or {})
    if "parameters" in config:
        return dict(config["parameters"])
    return {}
28
+
29
+
30
def dp_share_calculator(config: Dict[str, Any], default: bool = True) -> bool:
    """Read the ``share_calculator`` policy from the DP settings.

    Args:
        config: ATST-Tools configuration dictionary.
        default: Value used when the DP settings do not set ``share_calculator``.

    Returns:
        The configured (or default) policy, coerced to ``bool``.
    """
    settings = dp_section(config)
    policy = settings.get("share_calculator", default)
    return bool(policy)
33
+
34
+
35
def should_share_calculator(name: str, config: Dict[str, Any], parallel: bool = False) -> bool:
    """Decide whether workflow images should reuse one calculator instance.

    Sharing applies only to the DP calculator, only when the run is not
    parallel, and only when the DP sharing policy allows it.

    Args:
        name: Calculator name.
        config: ATST-Tools configuration dictionary.
        parallel: Whether the workflow runs its images in parallel.

    Returns:
        ``True`` when a single calculator should be shared.
    """
    if not is_dp_calculator(name):
        return False
    if parallel:
        return False
    return dp_share_calculator(config)
38
+
39
+
40
+ def _normalize_type_dict(dp_params: Dict[str, Any]) -> dict[str, int] | None:
41
+ type_map = dp_params.pop("type_map", None)
42
+ type_dict = dp_params.pop("type_dict", None)
43
+ if type_map is not None and type_dict is not None:
44
+ raise ValueError("calculator.dp.type_map and calculator.dp.type_dict are mutually exclusive")
45
+ if type_dict is not None:
46
+ return {str(symbol): int(index) for symbol, index in dict(type_dict).items()}
47
+ if type_map is None:
48
+ return None
49
+ return {str(symbol): index for index, symbol in enumerate(type_map)}
50
+
51
+
52
+ def _cache_key(model: str, params: Dict[str, Any]) -> tuple[Hashable, ...]:
53
+ serializable = json.dumps(params, sort_keys=True, default=str)
54
+ return (os.path.abspath(os.path.expanduser(model)), serializable)
55
+
56
+
57
class DeepPotentialFactory:
    """Builds DeepMD-kit ASE calculators and optionally reuses them."""

    # Calculators created with sharing enabled, keyed by ``_cache_key``.
    _instances: Dict[tuple[Hashable, ...], Calculator] = {}

    @staticmethod
    def get_calculator(
        config: Dict[str, Any],
        shared: bool | None = None,
        **kwargs: Any,
    ) -> Calculator:
        """Return a DeepMD-kit ASE calculator for ``config``.

        Args:
            config: ATST-Tools configuration dictionary.
            shared: Explicit sharing choice; ``None`` defers to the
                configured sharing policy.
            **kwargs: Workflow-local construction hints, merged over the
                configured DP settings.

        Returns:
            A ``deepmd.calculator.DP`` instance; a previously created one is
            returned when sharing is on and the model and parameters match.

        Raises:
            ImportError: If ``deepmd-kit`` cannot be imported.
            ValueError: If no model file is configured.
        """
        try:
            from deepmd.calculator import DP
        except ImportError as exc:
            raise ImportError(
                "deepmd-kit is not installed. Install it to use the DP calculator."
            ) from exc

        settings = dp_section(config)
        settings.update(kwargs)

        model_path = settings.pop("model", None)
        if not model_path:
            raise ValueError("Missing required field calculator.dp.model")

        threads = settings.pop("omp", None)
        if threads is not None:
            os.environ["OMP_NUM_THREADS"] = str(int(threads))

        if shared is None:
            use_cache = dp_share_calculator(config)
        else:
            use_cache = bool(shared)

        # These keys steer the factory only; they are not DP constructor arguments.
        for factory_only in ("share_calculator", "directory"):
            settings.pop(factory_only, None)

        type_dict = _normalize_type_dict(settings)
        dp_kwargs: Dict[str, Any] = dict(settings)
        if type_dict is not None:
            dp_kwargs["type_dict"] = type_dict

        cache = DeepPotentialFactory._instances
        key = _cache_key(model_path, dp_kwargs)
        if use_cache and key in cache:
            return cache[key]

        calc = DP(model=model_path, **dp_kwargs)
        if use_cache:
            cache[key] = calc
        return calc
@@ -0,0 +1,126 @@
1
+ """Calculator factories for ATST-Tools workflows."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import shlex
7
+ import logging
8
+ from typing import Any, Dict
9
+
10
+ from ase.calculators.calculator import Calculator
11
+
12
+ from atst_tools.calculators.abacuslite_backend import Abacus, ATSTAbacusProfile, BACKEND_SOURCE
13
+ from atst_tools.calculators.dp import DeepPotentialFactory
14
+
15
+
16
# Keys of the ABACUS config section that steer the factory itself; they are
# filtered out before the remaining keys are used as calculator parameters.
_ABACUS_CONTROL_KEYS = {"command", "mpi", "omp", "directory", "parameters", "version_command"}
LOGGER = logging.getLogger(__name__)
# Set to True after the backend origin has been logged, so it is logged once.
_ABACUS_BACKEND_LOGGED = False
19
+
20
+
21
+ def _abacus_section(config: Dict[str, Any]) -> Dict[str, Any]:
22
+ """Return ABACUS calculator settings from the supported config layouts."""
23
+ if "calculator" in config:
24
+ return dict(config.get("calculator", {}).get("abacus", {}))
25
+ if "abacus" in config:
26
+ return dict(config.get("abacus", {}))
27
+ return dict(config)
28
+
29
+
30
+ def _as_mp_kpts(kpts: Any) -> Any:
31
+ if isinstance(kpts, list) and len(kpts) == 3:
32
+ return {
33
+ "mode": "mp-sampling",
34
+ "nk": kpts,
35
+ "gamma-centered": True,
36
+ "kshift": [0, 0, 0],
37
+ }
38
+ return kpts
39
+
40
+
41
+ def _build_abacus_command(command: str, mpi: int) -> str:
42
+ if "{mpi}" in command:
43
+ return command.format(mpi=mpi)
44
+
45
+ executable = shlex.split(command)[0] if command.strip() else "abacus"
46
+ if mpi > 1 and executable not in {"mpirun", "mpiexec", "srun"}:
47
+ return f"mpirun -np {mpi} {command}"
48
+ return command or "abacus"
49
+
50
+
51
+ def _resolve_directory(path: str | None) -> str | None:
52
+ if path is None:
53
+ return None
54
+ return os.path.abspath(os.path.expanduser(path))
55
+
56
+
57
class AbacusFactory:
    """Builds ABACUS ASE calculators on top of the abacuslite backend."""

    @staticmethod
    def _log_backend_source_once() -> None:
        """Log which abacuslite backend is in use, the first time only."""
        global _ABACUS_BACKEND_LOGGED
        if _ABACUS_BACKEND_LOGGED:
            return
        LOGGER.info("Using %s abacuslite backend for ABACUS calculator.", BACKEND_SOURCE)
        _ABACUS_BACKEND_LOGGED = True

    @staticmethod
    def get_calculator(
        config: Dict[str, Any],
        directory: str | None = None,
        mpi: int | None = None,
        omp: int | None = None,
        **kwargs: Any,
    ) -> Calculator:
        """Create an ABACUS calculator from the ABACUS config section.

        Calculator parameters are collected from the non-control keys of the
        section, then its ``parameters`` mapping, then ``kwargs``; later
        sources override earlier ones. ``OMP_NUM_THREADS`` is set in the
        process environment as a side effect.

        Args:
            config: ATST-Tools configuration dictionary.
            directory: Working directory; falls back to the configured
                ``directory`` and then to ``"."``.
            mpi: MPI process count; falls back to the configured ``mpi``, then 1.
            omp: OpenMP thread count; falls back to the configured ``omp``, then 1.
            **kwargs: Extra calculator parameters.

        Returns:
            The configured ``Abacus`` calculator.
        """
        AbacusFactory._log_backend_source_once()
        section = _abacus_section(config)
        explicit_parameters = dict(section.get("parameters", {}))

        parameters = {
            option: setting
            for option, setting in section.items()
            if option not in _ABACUS_CONTROL_KEYS
        }
        parameters.update(explicit_parameters)
        parameters.update(kwargs)

        # Short config names are renamed to the names handed to the backend.
        for short_name, backend_name in (
            ("pp", "pseudopotentials"),
            ("basis", "basissets"),
            ("basis_dir", "orbital_dir"),
        ):
            if short_name in parameters:
                parameters[backend_name] = parameters.pop(short_name)
        if "kpts" in parameters:
            parameters["kpts"] = _as_mp_kpts(parameters["kpts"])

        # The two directories go to the profile, not to the calculator parameters.
        pseudo_dir = _resolve_directory(parameters.pop("pseudo_dir", None))
        orbital_dir = _resolve_directory(parameters.pop("orbital_dir", None))

        if mpi is None:
            mpi = section.get("mpi", 1)
        mpi = int(mpi)
        if omp is None:
            omp = section.get("omp", 1)
        omp = int(omp)
        if not directory:
            directory = section.get("directory", ".")
        command = _build_abacus_command(section.get("command", "abacus"), mpi)
        version_command = section.get("version_command")

        os.environ["OMP_NUM_THREADS"] = str(omp)
        profile = ATSTAbacusProfile(
            command=command,
            pseudo_dir=pseudo_dir,
            orbital_dir=orbital_dir,
            omp_num_threads=omp,
            version_command=version_command,
        )
        return Abacus(directory=directory, profile=profile, **parameters)
114
+
115
+
116
class CalculatorFactory:
    """Single entry point that dispatches to the calculator-specific factories."""

    @staticmethod
    def get_calculator(name: str, config: Dict[str, Any], **kwargs: Any) -> Calculator:
        """Build the calculator selected by ``name``.

        Args:
            name: Calculator name, compared case-insensitively: ``abacus``,
                ``dp`` or ``deepmd``.
            config: ATST-Tools configuration dictionary.
            **kwargs: Options forwarded to the selected factory.

        Returns:
            The calculator created by the selected factory.

        Raises:
            ValueError: If ``name`` is not a supported calculator.
        """
        name = name.lower()
        if name == "abacus":
            factory = AbacusFactory
        elif name in {"dp", "deepmd"}:
            factory = DeepPotentialFactory
        else:
            raise ValueError(f"Unsupported calculator: {name}. Supported: 'abacus', 'dp'")
        return factory.get_calculator(config, **kwargs)
@@ -0,0 +1 @@
1
+ """Vendored ABACUS ASE interface snapshot."""
@@ -0,0 +1,13 @@
1
+ '''
2
+ interfaces to Atomic-orbital Based Ab-initio Computation at UStc (ABACUS),
3
+ for more information about this DFT calculator,
4
+ please refer to the Github official repository:
5
+ https://github.com/deepmodeling/abacus-develop
6
+ and online-manual:
7
+ https://abacus.deepmodeling.com/en/latest/index.html
8
+
9
+ For a more complete ABACUS pre-/post-processing workflow package,
10
+ please refer to the ABACUSTest package:
11
+ https://github.com/pxlxingliang/abacus-test
12
+ '''
13
+ from .core import *