stjames 0.0.40__py3-none-any.whl → 0.0.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of stjames might be problematic. Click here for more details.

stjames/__init__.py CHANGED
@@ -1,6 +1,8 @@
1
1
  # ruff: noqa: I001
2
2
 
3
3
  from .calculation import *
4
+ from .atom import *
5
+ from .periodic_cell import *
4
6
  from .molecule import *
5
7
  from .workflows import *
6
8
 
@@ -21,3 +23,4 @@ from .mode import *
21
23
  from .status import *
22
24
  from .constraint import *
23
25
  from .message import *
26
+ from .types import *
stjames/atom.py ADDED
@@ -0,0 +1,66 @@
1
+ from typing import Self, Sequence
2
+
3
+ from pydantic import NonNegativeInt
4
+
5
+ from .base import Base
6
+ from .data import ELEMENT_SYMBOL, SYMBOL_ELEMENT
7
+ from .types import Vector3D
8
+
9
+
10
class Atom(Base):
    """An atom: its atomic number and Cartesian position (Å)."""

    atomic_number: NonNegativeInt
    position: Vector3D  # in Å

    def __repr__(self) -> str:
        """
        >>> Atom(atomic_number=2, position=[0, 1, 2])
        Atom(2, [0.00000, 1.00000, 2.00000])
        """
        x, y, z = self.position
        return f"Atom({self.atomic_number}, [{x:.5f}, {y:.5f}, {z:.5f}])"

    def __str__(self) -> str:
        """
        Format the atom as one XYZ-file line: symbol, then x, y, z with 10 decimals.

        >>> str(Atom(atomic_number=2, position=[0, 1, 2]))
        'He    0.0000000000    1.0000000000    2.0000000000'
        """
        x, y, z = self.position
        return f"{self.atomic_symbol:2} {x:15.10f} {y:15.10f} {z:15.10f}"

    @property
    def atomic_symbol(self) -> str:
        """
        Element symbol corresponding to ``atomic_number``.

        >>> Atom(atomic_number=2, position=[0, 1, 2]).atomic_symbol
        'He'
        """
        return ELEMENT_SYMBOL[self.atomic_number]

    def edited(self, atomic_number: int | None = None, position: Sequence[float] | None = None) -> Self:
        """
        Create a new Atom with the specified changes.

        Arguments left as ``None`` are copied from this atom; ``self`` is not modified.

        >>> a = Atom(atomic_number=2, position=[0, 1, 2])
        >>> a2 = a.edited(3)
        >>> a is a2
        False
        >>> a2
        Atom(3, [0.00000, 1.00000, 2.00000])

        :param atomic_number: new atomic number, or None to keep the current one
        :param position: new position (Å), or None to keep the current one
        :return: a new instance of the same class
        """
        if atomic_number is None:
            atomic_number = self.atomic_number
        if position is None:
            position = list(self.position)

        return self.__class__(atomic_number=atomic_number, position=position)

    @classmethod
    def from_xyz(cls: type[Self], xyz_line: str) -> Self:
        """
        Parse one atom line of an XYZ file.

        The first token is an element symbol (matched exactly, then case-insensitively)
        or an atomic number. It must be followed by exactly three coordinates (Å).

        >>> Atom.from_xyz("H 0 0 0")
        Atom(1, [0.00000, 0.00000, 0.00000])
        >>> Atom.from_xyz("h 0 0 0")
        Atom(1, [0.00000, 0.00000, 0.00000])
        >>> Atom.from_xyz("8 0 0 1")
        Atom(8, [0.00000, 0.00000, 1.00000])

        :param xyz_line: whitespace-separated ``symbol x y z`` line
        :return: the parsed Atom
        :raises ValueError: if the line is empty or does not have exactly three coordinates
        :raises KeyError: if the element symbol is not recognized
        """
        tokens = xyz_line.split()
        if not tokens:
            raise ValueError("XYZ line is empty, expected 'symbol x y z'")

        name, *xyz = tokens
        # Check the shape of the line before looking up the element, so a malformed
        # line is reported as such rather than as an unknown symbol.
        if len(xyz) != 3:
            raise ValueError("XYZ file should have 3 coordinates per atom")

        if name.isdigit():
            atomic_number = int(name)
        elif name in SYMBOL_ELEMENT:
            atomic_number = SYMBOL_ELEMENT[name]
        else:
            # Tolerate symbols written in the wrong case, e.g. "CL" or "cl" for chlorine.
            try:
                atomic_number = SYMBOL_ELEMENT[name.capitalize()]
            except KeyError:
                raise KeyError(f"Unknown element symbol: {name!r}") from None

        return cls(atomic_number=atomic_number, position=xyz)
stjames/base.py CHANGED
@@ -1,42 +1,42 @@
1
1
  from enum import Enum
2
- from typing import Annotated, Hashable, TypeVar
2
+ from typing import Annotated, Any, Hashable, TypeVar
3
3
 
4
4
  import numpy as np
5
5
  import pydantic
6
6
 
7
_T = TypeVar("_T")


class Base(pydantic.BaseModel):
    """Shared pydantic base model; numpy array inputs are converted to lists."""

    @pydantic.field_validator("*", mode="before")
    @classmethod
    def coerce_numpy(cls, val: _T) -> _T | list[Any]:
        """
        Turn numpy arrays into (nested) Python lists ahead of field validation.

        Registered for every field (``"*"``) in ``before`` mode. Values that are not
        ``np.ndarray`` are handed back unchanged.
        """
        if not isinstance(val, np.ndarray):
            return val
        return val.tolist()  # type: ignore [no-any-return, unused-ignore]
18
18
 
19
19
 
20
20
  class LowercaseStrEnum(str, Enum):
21
21
  """Enum where hyphens, underscores, and case are ignored."""
22
22
 
23
23
  @classmethod
24
- def _missing_(cls, value: str) -> str | None: # type: ignore
25
- # Type note: technically breaking Liskov, value: object in Enum
24
+ def _missing_(cls, value: object) -> str | None:
26
25
  for member in cls:
27
- if member.lower().replace("-", "").replace("_", "") == value.lower().replace("-", "").replace("_", ""):
28
- return member
26
+ if isinstance(value, str):
27
+ if member.lower().replace("-", "").replace("_", "") == value.lower().replace("-", "").replace("_", ""):
28
+ return member
29
29
  return None
30
30
 
31
31
 
32
32
  # cf. https://github.com/pydantic/pydantic-core/pull/820#issuecomment-1670475909
33
- H = TypeVar("H", bound=Hashable)
33
+ _H = TypeVar("_H", bound=Hashable)
34
34
 
35
35
 
36
- def _validate_unique_list(v: list[H]) -> list[H]:
36
+ def _validate_unique_list(v: list[_H]) -> list[_H]:
37
37
  if len(v) != len(set(v)):
38
38
  raise ValueError("this list must be unique, and isn't!")
39
39
  return v
40
40
 
41
41
 
42
# List type whose elements must all be distinct: after pydantic's normal list
# validation, ``_validate_unique_list`` runs and raises on any repeat.
UniqueList = Annotated[
    list[_H],
    pydantic.AfterValidator(_validate_unique_list),
]
stjames/basis_set.py CHANGED
@@ -1,10 +1,6 @@
1
- import pydantic
2
- from pydantic import PositiveFloat, PositiveInt
1
+ from typing import Optional, Self
3
2
 
4
- try:
5
- from typing import Optional, Self
6
- except ImportError:
7
- from typing_extensions import Optional, Self
3
+ from pydantic import PositiveFloat, PositiveInt, model_validator
8
4
 
9
5
  from .base import Base
10
6
 
@@ -14,7 +10,7 @@ class BasisSetOverride(Base):
14
10
  atomic_numbers: Optional[list[PositiveInt]] = None
15
11
  atoms: Optional[list[PositiveInt]] = None # 1-indexed
16
12
 
17
- @pydantic.model_validator(mode="after")
13
+ @model_validator(mode="after")
18
14
  def check_override(self) -> Self:
19
15
  # ^ is xor
20
16
  assert (self.atomic_numbers is not None) ^ (self.atoms is not None), "Exactly one of ``atomic_numbers`` or ``atoms`` must be specified!"
stjames/calculation.py CHANGED
@@ -5,6 +5,7 @@ from .message import Message
5
5
  from .molecule import Molecule
6
6
  from .settings import Settings
7
7
  from .status import Status
8
+ from .types import UUID
8
9
 
9
10
 
10
11
  class StJamesVersion(LowercaseStrEnum):
@@ -29,6 +30,7 @@ class Calculation(Base):
29
30
  messages: list[Message] = []
30
31
 
31
32
  engine: Optional[str] = "peregrine"
33
+ uuids: list[UUID | None] | None = None
32
34
 
33
35
  # not to be changed by end users, diff. versions will have diff. defaults
34
36
  json_format: str = StJamesVersion.V0
stjames/constraint.py CHANGED
@@ -1,3 +1,5 @@
1
+ from pydantic import PositiveFloat, PositiveInt
2
+
1
3
  from .base import Base, LowercaseStrEnum
2
4
 
3
5
 
@@ -10,7 +12,25 @@ class ConstraintType(LowercaseStrEnum):
10
12
 
11
13
 
12
14
class Constraint(Base):
    """
    Represents a single (absolute) constraint.

    ``constraint_type`` gives the kind of constraint (see ``ConstraintType``) and
    ``atoms`` lists the atoms it involves.
    """

    constraint_type: ConstraintType
    atoms: list[PositiveInt]  # 1-indexed
19
+
20
+
21
class PairwiseHarmonicConstraint(Base):
    """
    Represents a harmonic constraint, with a characteristic spring constant.

    Acts between the pair of atoms in ``atoms``; ``spring_constant`` sets its stiffness.
    """

    atoms: tuple[PositiveInt, PositiveInt]  # 1-indexed
    spring_constant: PositiveFloat  # kcal/mol / Å**2
28
+
29
+
30
class SphericalHarmonicConstraint(Base):
    """
    Represents a spherical harmonic constraint to keep a system near the origin.

    ``confining_radius`` is the size of the confining sphere (presumably in Å — TODO
    confirm). ``confining_force_constant`` is the force constant of the harmonic
    confinement.
    """

    confining_radius: PositiveFloat
    # Float literal on purpose: pydantic does not validate defaults, so a bare ``10``
    # would stay an int on a PositiveFloat field.
    confining_force_constant: PositiveFloat = 10.0  # kcal/mol / Å**2
@@ -0,0 +1 @@
1
+ from .elements import *
@@ -0,0 +1,27 @@
1
+ """Read elemental data from files."""
2
+
3
+ import json
4
+ from collections import namedtuple
5
+ from importlib import resources
6
+
7
data_dir = resources.files("stjames").joinpath("data")

# Element symbol -> atomic number (e.g. "He" -> 2), and the inverse mapping.
with data_dir.joinpath("symbol_element.json").open(encoding="utf-8") as f:
    SYMBOL_ELEMENT: dict[str, int] = json.load(f)

ELEMENT_SYMBOL = {v: k for k, v in SYMBOL_ELEMENT.items()}

Isotope = namedtuple("Isotope", ["relative_atomic_mass", "isotopic_composition", "standard_atomic_weight"])
with data_dir.joinpath("nist_isotopes.json").open(encoding="utf-8") as f:
    d = json.load(f)

# JSON object keys are always strings; convert them back to ints.
# {atomic number: {mass number: Isotope}}
ISOTOPES: dict[int, dict[int, Isotope]] = {
    int(k): {
        int(kk): Isotope(*vv)
        for kk, vv in v.items()  # stay open
    }
    for k, v in d.items()
}

with data_dir.joinpath("bragg_radii.json").open(encoding="utf-8") as f:
    # Same string-key issue as ISOTOPES: without int() a lookup such as
    # BRAGG_RADII[6] would raise KeyError despite the dict[int, float] annotation.
    BRAGG_RADII: dict[int, float] = {int(k): v for k, v in json.load(f).items()}
@@ -0,0 +1,116 @@
1
+ """
2
+ Read the NIST isotopes data file and write it to a JSON file.
3
+
4
+ NIST Isotopes data from:
5
+ https://physics.nist.gov/cgi-bin/Compositions/stand_alone.pl?ele=&all=all&ascii=ascii2
6
+ """
7
+
8
+ import json
9
+ from collections import defaultdict
10
+ from importlib import resources
11
+ from typing import Callable, TypeVar
12
+
13
# Package data directory (stjames/data); nist_isotopes.txt is read from here.
data_dir = resources.files("stjames") / "data"
14
+
15
_T = TypeVar("_T")


def process_line(line: str, fmt: Callable[[str], _T] = str) -> _T:  # type: ignore[assignment]
    """
    Extract and convert the value from a ``key = value`` line of the NIST data file.

    The text after the last ``=`` is stripped of surrounding whitespace and handed to ``fmt``.

    :param line: line to process
    :param fmt: function converting the raw value string (``str`` by default)
    :return: the converted value

    >>> process_line("Atomic Number = 1", int)
    1
    """
    *_, raw_value = line.split("=")
    return fmt(raw_value.strip())
28
+
29
+
30
def fmt_float(val: str) -> float:
    """
    Parse a NIST numeric value, dropping any parenthesized uncertainty.

    >>> fmt_float(" 1.00784(7)")
    1.00784
    """
    number, _, _ = val.strip().partition("(")
    return float(number)
38
+
39
+
40
def fmt_maybe_list(val: str) -> float:
    """
    Parse a value that is either a single float or a bracketed list of floats.

    For a list such as ``[lower,upper]`` only the first entry is used.

    >>> fmt_maybe_list("1.00784(7)")
    1.00784
    >>> fmt_maybe_list(" [1.00784,1.00811]")
    1.00784
    >>> fmt_maybe_list(" [98]")
    98.0
    """
    cleaned = val.strip()
    if not cleaned.startswith("["):
        return fmt_float(cleaned)

    first_entry, *_ = cleaned[1:-1].split(",")
    return fmt_float(first_entry)
57
+
58
+
59
def process_chunk(chunk: str) -> tuple[int, int, tuple[float, float, float]]:
    r"""
    Parse one isotope record of the NIST data file.

    Returns Atomic Number, Mass Number, (Relative Atomic Mass, Isotopic Composition, Standard Atomic Weight).
    If no Isotopic Composition is given, ``0.0`` is used. If no Standard Atomic Weight is
    given, the Relative Atomic Mass is used.

    >>> process_chunk('''\
    ... Atomic Number = 1
    ... Atomic Symbol = H
    ... Mass Number = 1
    ... Relative Atomic Mass = 1.00784(7)
    ... Isotopic Composition = 0.999885(70)
    ... Standard Atomic Weight = [1.00784,1.00811]
    ... Notes = m
    ... ''')
    (1, 1, (1.00784, 0.999885, 1.00784))
    >>> process_chunk('''\
    ... Atomic Number = 43
    ... Atomic Symbol = Tc
    ... Mass Number = 98
    ... Relative Atomic Mass = 97.9072124(36)
    ... Isotopic Composition =
    ... Standard Atomic Weight = [98]
    ... Notes =
    ... ''')
    (43, 98, (97.9072124, 0.0, 98.0))
    """
    lines = chunk.splitlines()

    # Fixed record layout: atomic number, symbol (unused here), mass number,
    # relative atomic mass, isotopic composition, standard atomic weight, notes.
    atomic_number = process_line(lines[0], int)
    mass_number = process_line(lines[2], int)
    relative_atomic_mass = process_line(lines[3], fmt_float)
    try:
        isotopic_composition = process_line(lines[4], fmt_float)
    except ValueError:
        # Blank value: use a float so the tuple matches its declared float type.
        isotopic_composition = 0.0
    try:
        standard_atomic_weight = process_line(lines[5], fmt_maybe_list)
    except ValueError:
        standard_atomic_weight = relative_atomic_mass

    return atomic_number, mass_number, (relative_atomic_mass, isotopic_composition, standard_atomic_weight)
90
+
91
+
92
def read_nist_isotopes(out_file: str = "nist_isotopes.json") -> dict[int, dict[int, tuple[float, float, float]]]:
    """
    Read the NIST data file and write it to a JSON file.

    Reads ``nist_isotopes.txt`` from the package data directory, skipping its first two
    lines. Each blank-line-separated record is parsed, and the resulting table is
    written as JSON:

    {Atomic Number: {Mass Number: (Relative Atomic Mass, Isotopic Composition, Standard Atomic Weight)}}

    :param out_file: path of the JSON output; by default ``nist_isotopes.json`` in the
        current working directory (not the package data directory)
    :return: the parsed isotope table
    """
    with data_dir.joinpath("nist_isotopes.txt").open(encoding="utf-8") as f:
        # Skip the first two lines
        f.readline()
        f.readline()
        nist_isotopes = f.read()

    isotopes: dict[int, dict[int, tuple[float, float, float]]] = defaultdict(dict)
    for chunk in nist_isotopes.split("\n\n"):
        # A trailing blank line (or extra blank separators) yields empty or
        # newline-padded chunks; process_chunk needs the record to start at line 0.
        chunk = chunk.strip("\n")
        if not chunk.strip():
            continue
        atomic_number, mass_number, values = process_chunk(chunk)
        isotopes[atomic_number][mass_number] = values

    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(isotopes, f)

    return isotopes
111
+
112
+
113
if __name__ == "__main__":
    # Regenerate the JSON file and show the parsed table.
    import pprint

    pprint.pprint(read_nist_isotopes())
stjames/method.py CHANGED
@@ -29,10 +29,43 @@ class Method(LowercaseStrEnum):
29
29
 
30
30
  AIMNET2_WB97MD3 = "aimnet2_wb97md3"
31
31
 
32
+ GFN_FF = "gfn_ff"
32
33
  GFN0_XTB = "gfn0_xtb"
33
34
  GFN1_XTB = "gfn1_xtb"
34
35
  GFN2_XTB = "gfn2_xtb"
35
- GFN_FF = "gfn_ff"
36
36
 
37
37
  # this was going to be removed, but Jonathon wrote such a nice basis set test... it's off the front end.
38
38
  BP86 = "bp86"
39
+
40
+
41
# Presumably machine-learned force fields (inferred from the name; confirm).
MLFF = [Method.AIMNET2_WB97MD3]

# GFN-FF plus the GFN0/1/2-xTB methods.
XTB_METHODS = [Method.GFN_FF, Method.GFN0_XTB, Method.GFN1_XTB, Method.GFN2_XTB]

# The "3c" composite methods.
COMPOSITE_METHODS = [Method.HF3C, Method.B973C, Method.R2SCAN3C, Method.WB97X3C]

# Union of the three groups above (a new list, in the same order).
PREPACKAGED_METHODS = MLFF + XTB_METHODS + COMPOSITE_METHODS

# NOTE(review): presumably methods carrying a built-in dispersion correction
# (D3 / D3BJ / V suffixes) — confirm what "correction" refers to.
METHODS_WITH_CORRECTION = [Method.WB97XD3, Method.WB97XV, Method.WB97MV, Method.WB97MD3BJ, Method.DSDBLYPD3BJ]
stjames/molecule.py CHANGED
@@ -1,24 +1,26 @@
1
+ from pathlib import Path
2
+ from typing import Iterable, Optional, Self
3
+
1
4
  import pydantic
2
5
  from pydantic import NonNegativeInt, PositiveInt
3
6
 
4
- try:
5
- from typing import Optional, Self
6
- except ImportError:
7
- from typing_extensions import Optional, Self
8
-
7
+ from .atom import Atom
9
8
  from .base import Base
9
+ from .periodic_cell import PeriodicCell
10
+ from .types import Matrix3x3, Vector3D, Vector3DPerAtom
11
+
12
+
13
class MoleculeReadError(RuntimeError):
    """Raised when a molecule cannot be parsed from its textual (e.g. XYZ) form."""
10
15
 
11
16
 
12
17
class VibrationalMode(Base):
    """One vibrational mode: frequency, reduced mass, force constant and per-atom displacements."""

    frequency: float  # in cm-1
    reduced_mass: float  # amu

    # todo - check units here?
    force_constant: float
    displacements: Vector3DPerAtom
22
24
 
23
25
 
24
26
  class Molecule(Base):
@@ -26,6 +28,9 @@ class Molecule(Base):
26
28
  multiplicity: PositiveInt
27
29
  atoms: list[Atom]
28
30
 
31
+ # for periodic boundary conditions
32
+ cell: Optional[PeriodicCell] = None
33
+
29
34
  energy: Optional[float] = None # in Hartree
30
35
  scf_iterations: Optional[NonNegativeInt] = None
31
36
  scf_completed: Optional[bool] = None
@@ -33,11 +38,14 @@ class Molecule(Base):
33
38
 
34
39
  homo_lumo_gap: Optional[float] = None # in eV
35
40
 
36
- gradient: Optional[list[list[float]]] = None # Hartree/Bohr
41
+ gradient: Optional[Vector3DPerAtom] = None # Hartree
42
+ stress: Optional[Matrix3x3] = None # Hartree/Å
43
+
44
+ velocities: Optional[Vector3DPerAtom] = None # Å/fs
37
45
 
38
46
  mulliken_charges: Optional[list[float]] = None
39
47
  mulliken_spin_densities: Optional[list[float]] = None
40
- dipole: Optional[list[float]] = None # in Debye
48
+ dipole: Optional[Vector3D] = None # in Debye
41
49
 
42
50
  vibrational_modes: Optional[list[VibrationalMode]] = None
43
51
 
@@ -49,8 +57,18 @@ class Molecule(Base):
49
57
  def __len__(self) -> int:
50
58
  return len(self.atoms)
51
59
 
60
+ def distance(self, atom1: PositiveInt, atom2: PositiveInt) -> float:
61
+ r"""
62
+ Get the distance between atoms.
63
+
64
+ >>> mol = Molecule.from_xyz("H 0 1 0\nH 0 0 1")
65
+ >>> mol.distance(1, 2)
66
+ 1.4142135623730951
67
+ """
68
+ return sum((q2 - q1) ** 2 for q1, q2 in zip(self.atoms[atom1 - 1].position, self.atoms[atom2 - 1].position)) ** 0.5 # type: ignore [no-any-return,unused-ignore]
69
+
52
70
  @property
53
- def coordinates(self) -> list[list[float]]:
71
+ def coordinates(self) -> Vector3DPerAtom:
54
72
  return [a.position for a in self.atoms]
55
73
 
56
74
  @property
@@ -92,3 +110,83 @@ class Molecule(Base):
92
110
  )
93
111
 
94
112
  return self
113
+
114
+ @classmethod
115
+ def from_file(cls: type[Self], filename: Path | str, format: str | None = None, charge: int = 0, multiplicity: PositiveInt = 1) -> Self:
116
+ r"""
117
+ Read a molecule from a file.
118
+
119
+ >>> import tempfile
120
+ >>> with tempfile.NamedTemporaryFile("w+", suffix=".xyz") as f:
121
+ ... _ = f.write("2\nComment\nH 0 0 0\nF 0 0 1")
122
+ ... _ = f.seek(0)
123
+ ... mol = Molecule.from_file(f.name)
124
+ >>> print(mol.to_xyz())
125
+ 2
126
+ <BLANKLINE>
127
+ H 0.0000000000 0.0000000000 0.0000000000
128
+ F 0.0000000000 0.0000000000 1.0000000000
129
+ """
130
+ filename = Path(filename)
131
+ if not format:
132
+ format = filename.suffix[1:]
133
+
134
+ with open(filename) as f:
135
+ match format:
136
+ case "xyz":
137
+ return cls.from_xyz_lines(f.readlines(), charge=charge, multiplicity=multiplicity)
138
+ case _:
139
+ raise ValueError(f"Unsupported {format=}")
140
+
141
+ @classmethod
142
+ def from_xyz(cls: type[Self], xyz: str, charge: int = 0, multiplicity: PositiveInt = 1) -> Self:
143
+ r"""
144
+ Generate a Molecule from an XYZ string.
145
+
146
+ Note: only supports single molecule inputs.
147
+
148
+ >>> len(Molecule.from_xyz("2\nComment\nH 0 0 0\nH 0 0 1"))
149
+ 2
150
+ """
151
+ return cls.from_xyz_lines(xyz.strip().splitlines(), charge=charge, multiplicity=multiplicity)
152
+
153
+ @classmethod
154
+ def from_xyz_lines(cls: type[Self], lines: Iterable[str], charge: int = 0, multiplicity: PositiveInt = 1) -> Self:
155
+ lines = list(lines)
156
+ if len(lines[0].split()) == 1:
157
+ natoms = lines[0].strip()
158
+ if not natoms.isdigit() or (int(lines[0]) != len(lines) - 2):
159
+ raise MoleculeReadError(f"First line of XYZ file should be the number of atoms, got: {lines[0]} != {len(lines) - 2}")
160
+ lines = lines[2:]
161
+
162
+ try:
163
+ return cls(atoms=[Atom.from_xyz(line) for line in lines], charge=charge, multiplicity=multiplicity)
164
+ except Exception as e:
165
+ raise MoleculeReadError("Error reading molecule from xyz") from e
166
+
167
+ def to_xyz(self, comment: str = "", out_file: Path | str | None = None) -> str:
168
+ r"""
169
+ Generate an XYZ string.
170
+
171
+ >>> mol = Molecule.from_xyz("2\nComment\nH 0 1 2\nF 1 2 3")
172
+ >>> print(mol.to_xyz(comment="HF"))
173
+ 2
174
+ HF
175
+ H 0.0000000000 1.0000000000 2.0000000000
176
+ F 1.0000000000 2.0000000000 3.0000000000
177
+ >>> import tempfile
178
+ >>> with tempfile.TemporaryDirectory() as directory:
179
+ ... file = Path(directory) / "mol.xyz"
180
+ ... out = mol.to_xyz(comment="HF", out_file=file)
181
+ ... with file.open() as f:
182
+ ... Molecule.from_xyz(f.read()).to_xyz("HF") == out
183
+ True
184
+ """
185
+ geom = "\n".join(map(str, self.atoms))
186
+ out = f"{len(self)}\n{comment}\n{geom}"
187
+
188
+ if out_file:
189
+ with Path(out_file).open("w") as f:
190
+ f.write(out)
191
+
192
+ return out
stjames/opt_settings.py CHANGED
@@ -1,4 +1,6 @@
1
- from pydantic import Field, PositiveFloat, PositiveInt
1
+ from typing import Sequence
2
+
3
+ from pydantic import PositiveFloat, PositiveInt
2
4
 
3
5
  from .base import Base
4
6
  from .constraint import Constraint
@@ -8,9 +10,12 @@ class OptimizationSettings(Base):
8
10
  max_steps: PositiveInt = 250
9
11
  transition_state: bool = False
10
12
 
11
- # when are we converged?
12
- max_gradient_threshold: PositiveFloat = 4.5e-4
13
- rms_gradient_threshold: PositiveFloat = 3.0e-4
13
+ # when are we converged? (Hartree and Hartree/Å)
14
+ max_gradient_threshold: PositiveFloat = 7e-4
15
+ rms_gradient_threshold: PositiveFloat = 6e-4
14
16
  energy_threshold: PositiveFloat = 1e-6
15
17
 
16
- constraints: list[Constraint] = Field(default_factory=list)
18
+ # for periodic systems only
19
+ optimize_cell: bool = False
20
+
21
+ constraints: Sequence[Constraint] = tuple()
@@ -0,0 +1,34 @@
1
+ from typing import TypeAlias
2
+
3
+ import numpy as np
4
+ import pydantic
5
+
6
+ from .base import Base
7
+ from .types import Matrix3x3
8
+
9
Bool3: TypeAlias = tuple[bool, bool, bool]


class PeriodicCell(Base):
    """
    Simulation cell for periodic boundary conditions.

    ``lattice_vectors`` is the 3x3 lattice matrix. ``is_periodic`` flags which of the
    three directions are periodic; at least one must be.
    """

    lattice_vectors: Matrix3x3
    is_periodic: Bool3 = (True, True, True)

    @pydantic.field_validator("lattice_vectors")
    @classmethod
    def check_tensor_3D(cls, v: Matrix3x3) -> Matrix3x3:
        """Reject lattice matrices that are not exactly 3x3."""
        has_three_rows = len(v) == 3
        if has_three_rows and all(len(row) == 3 for row in v):
            return v
        raise ValueError("Cell tensor must be a 3x3 list of floats")

    @pydantic.field_validator("is_periodic")
    @classmethod
    def check_pbc(cls, v: Bool3) -> Bool3:
        """Require at least one periodic direction."""
        if any(v):
            return v
        raise ValueError("For periodic boundary conditions, at least one dimension must be periodic!")

    @pydantic.computed_field  # type: ignore[misc, prop-decorator, unused-ignore]
    @property
    def volume(self) -> float:
        """Cell volume: absolute value of the lattice matrix determinant."""
        determinant = np.linalg.det(np.array(self.lattice_vectors))
        return float(np.abs(determinant))