@mat3ra/made 2025.1.18-0 → 2025.4.4-0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. package/package.json +1 -1
  2. package/pyproject.toml +9 -1
  3. package/src/py/mat3ra/__init__.py +1 -0
  4. package/src/py/mat3ra/made/{basis.py → basis/__init__.py} +48 -56
  5. package/src/py/mat3ra/made/basis/coordinates.py +43 -0
  6. package/src/py/mat3ra/made/cell.py +19 -18
  7. package/src/py/mat3ra/made/debug_utils.py +73 -0
  8. package/src/py/mat3ra/made/lattice.py +72 -53
  9. package/src/py/mat3ra/made/material.py +34 -55
  10. package/src/py/mat3ra/made/tools/README.md +1 -1
  11. package/src/py/mat3ra/made/tools/analyze/lattice.py +2 -2
  12. package/src/py/mat3ra/made/tools/analyze/material.py +6 -6
  13. package/src/py/mat3ra/made/tools/analyze/other.py +47 -16
  14. package/src/py/mat3ra/made/tools/analyze/utils.py +19 -21
  15. package/src/py/mat3ra/made/tools/build/__init__.py +1 -1
  16. package/src/py/mat3ra/made/tools/build/defect/builders.py +4 -11
  17. package/src/py/mat3ra/made/tools/build/interface/builders.py +27 -27
  18. package/src/py/mat3ra/made/tools/build/nanoribbon/builders.py +19 -16
  19. package/src/py/mat3ra/made/tools/build/nanoribbon/configuration.py +1 -1
  20. package/src/py/mat3ra/made/tools/build/passivation/builders.py +6 -2
  21. package/src/py/mat3ra/made/tools/build/perturbation/builders.py +1 -1
  22. package/src/py/mat3ra/made/tools/build/slab/configuration.py +3 -3
  23. package/src/py/mat3ra/made/tools/build/supercell.py +1 -1
  24. package/src/py/mat3ra/made/tools/build/utils.py +5 -4
  25. package/src/py/mat3ra/made/tools/calculate/__init__.py +5 -3
  26. package/src/py/mat3ra/made/tools/calculate/ase/__init__.py +2 -2
  27. package/src/py/mat3ra/made/tools/convert/__init__.py +5 -4
  28. package/src/py/mat3ra/made/tools/modify.py +26 -31
  29. package/src/py/mat3ra/made/tools/site.py +1 -1
  30. package/src/py/mat3ra/made/tools/utils/coordinate.py +1 -1
  31. package/src/py/mat3ra/made/tools/utils/perturbation.py +1 -1
  32. package/src/py/mat3ra/made/utils.py +1 -121
  33. package/tests/py/conftest.py +34 -0
  34. package/tests/py/unit/fixtures/__init__.py +0 -0
  35. package/tests/py/unit/fixtures/cell.py +69 -0
  36. package/tests/py/unit/fixtures/generated/__init__.py +0 -0
  37. package/tests/py/unit/fixtures/generated/fixtures.py +83 -0
  38. package/tests/py/unit/fixtures/interface.py +121 -0
  39. package/tests/py/unit/fixtures/monolayer.py +20 -0
  40. package/tests/py/unit/fixtures/nanoribbon.py +226 -0
  41. package/tests/py/unit/fixtures/slab.py +198 -0
  42. package/tests/py/unit/fixtures/supercell.py +42 -0
  43. package/tests/py/unit/test_lattice.py +64 -4
  44. package/tests/py/unit/test_material.py +54 -14
  45. package/tests/py/unit/test_tools_analyze.py +3 -2
  46. package/tests/py/unit/test_tools_build.py +1 -1
  47. package/tests/py/unit/test_tools_build_defect.py +24 -15
  48. package/tests/py/unit/test_tools_build_grain_boundary.py +3 -3
  49. package/tests/py/unit/test_tools_build_interface.py +14 -9
  50. package/tests/py/unit/test_tools_build_nanoribbon.py +7 -6
  51. package/tests/py/unit/test_tools_build_passivation.py +10 -7
  52. package/tests/py/unit/test_tools_build_perturbation.py +3 -3
  53. package/tests/py/unit/test_tools_build_slab.py +4 -4
  54. package/tests/py/unit/test_tools_build_supercell.py +4 -6
  55. package/tests/py/unit/test_tools_calculate.py +4 -4
  56. package/tests/py/unit/test_tools_convert.py +6 -7
  57. package/tests/py/unit/test_tools_modify.py +42 -28
  58. package/tests/py/unit/utils.py +54 -1
  59. package/tests/py/unit/fixtures.py +0 -828
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@mat3ra/made",
3
- "version": "2025.1.18-0",
3
+ "version": "2025.4.4-0",
4
4
  "description": "MAterials DEsign library",
5
5
  "scripts": {
6
6
  "lint": "eslint --cache src/js tests/js && prettier --write src/js tests/js",
package/pyproject.toml CHANGED
@@ -6,7 +6,7 @@ readme = "README.md"
6
6
  requires-python = ">=3.8"
7
7
  license = {file = "LICENSE.md"}
8
8
  authors = [
9
- {name = "Exabyte Inc.", email = "info@mat3ra.com"}
9
+ { name = "Exabyte Inc.", email = "info@mat3ra.com" }
10
10
  ]
11
11
  classifiers = [
12
12
  "Programming Language :: Python",
@@ -99,3 +99,11 @@ target-version = "py38"
99
99
  profile = "black"
100
100
  multi_line_output = 3
101
101
  include_trailing_comma = true
102
+
103
+ [tool.pytest.ini_options]
104
+ pythonpath = [
105
+ "src/py",
106
+ ]
107
+ testpaths = [
108
+ "tests/py"
109
+ ]
@@ -0,0 +1 @@
1
+ __path__ = __import__("pkgutil").extend_path(__path__, __name__)
@@ -1,21 +1,36 @@
1
- import json
2
- from typing import Dict, List, Optional, Union
3
-
4
- from mat3ra.code.constants import AtomicCoordinateUnits
5
- from mat3ra.utils.mixins import RoundNumericValuesMixin
6
- from pydantic import BaseModel
7
-
8
- from .cell import Cell
9
- from .utils import ArrayWithIds, get_overlapping_coordinates
10
-
11
-
12
- class Basis(RoundNumericValuesMixin, BaseModel):
13
- elements: ArrayWithIds = ArrayWithIds(values=["Si"])
14
- coordinates: ArrayWithIds = ArrayWithIds(values=[0, 0, 0])
15
- units: str = AtomicCoordinateUnits.crystal
16
- cell: Cell = Cell()
17
- labels: Optional[ArrayWithIds] = ArrayWithIds(values=[])
18
- constraints: Optional[ArrayWithIds] = ArrayWithIds(values=[])
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ from mat3ra.code.array_with_ids import ArrayWithIds
4
+ from mat3ra.code.entity import InMemoryEntityPydantic
5
+ from mat3ra.esse.models.material import BasisSchema
6
+ from mat3ra.esse.models.material import Units as UnitsEnum
7
+ from mat3ra.made.basis.coordinates import Coordinates
8
+ from mat3ra.made.cell import Cell
9
+ from mat3ra.made.utils import get_overlapping_coordinates
10
+ from pydantic import Field
11
+
12
+
13
+ class Basis(BasisSchema, InMemoryEntityPydantic):
14
+ elements: ArrayWithIds
15
+ coordinates: Coordinates
16
+ cell: Cell = Field(Cell(), exclude=True)
17
+ labels: ArrayWithIds = Field(ArrayWithIds.from_values([]))
18
+ constraints: ArrayWithIds = Field(ArrayWithIds.from_values([]))
19
+
20
+ def __convert_kwargs__(self, **kwargs: Any) -> Dict[str, Any]:
21
+ if isinstance(kwargs.get("elements"), list):
22
+ kwargs["elements"] = ArrayWithIds.from_list_of_dicts(kwargs["elements"])
23
+ if isinstance(kwargs.get("coordinates"), list):
24
+ kwargs["coordinates"] = Coordinates.from_list_of_dicts(kwargs["coordinates"])
25
+ if isinstance(kwargs.get("labels"), list):
26
+ kwargs["labels"] = ArrayWithIds.from_list_of_dicts(kwargs["labels"])
27
+ if isinstance(kwargs.get("constraints"), list):
28
+ kwargs["constraints"] = ArrayWithIds.from_list_of_dicts(kwargs["constraints"])
29
+ return kwargs
30
+
31
+ def __init__(self, *args: Any, **kwargs: Any):
32
+ kwargs = self.__convert_kwargs__(**kwargs)
33
+ super().__init__(*args, **kwargs)
19
34
 
20
35
  @classmethod
21
36
  def from_dict(
@@ -23,57 +38,38 @@ class Basis(RoundNumericValuesMixin, BaseModel):
23
38
  elements: List[Dict],
24
39
  coordinates: List[Dict],
25
40
  units: str,
26
- labels: Optional[List[Dict]] = None,
27
- cell: Optional[List[List[float]]] = None,
28
- constraints: Optional[List[Dict]] = None,
41
+ cell: List[List[float]],
42
+ labels: Optional[List[Dict]] = ArrayWithIds.from_list_of_dicts([]),
43
+ constraints: Optional[List[Dict]] = ArrayWithIds.from_list_of_dicts([]),
29
44
  ) -> "Basis":
30
45
  return Basis(
31
46
  elements=ArrayWithIds.from_list_of_dicts(elements),
32
- coordinates=ArrayWithIds.from_list_of_dicts(coordinates),
47
+ coordinates=Coordinates.from_list_of_dicts(coordinates),
33
48
  units=units,
34
49
  cell=Cell.from_vectors_array(cell),
35
- labels=ArrayWithIds.from_list_of_dicts(labels) if labels else ArrayWithIds(values=[]),
36
- constraints=ArrayWithIds.from_list_of_dicts(constraints) if constraints else ArrayWithIds(values=[]),
37
- )
38
-
39
- def to_json(self, skip_rounding=False):
40
- json_value = {
41
- "elements": self.elements.to_json(),
42
- "coordinates": self.coordinates.to_json(skip_rounding=skip_rounding),
43
- "units": self.units,
44
- "labels": self.labels.to_json(),
45
- }
46
- return json.loads(json.dumps(json_value))
47
-
48
- def clone(self):
49
- return Basis(
50
- elements=self.elements,
51
- coordinates=self.coordinates,
52
- units=self.units,
53
- cell=self.cell,
54
- isEmpty=False,
55
- labels=self.labels,
50
+ labels=ArrayWithIds.from_list_of_dicts(labels),
51
+ constraints=ArrayWithIds.from_list_of_dicts(constraints),
56
52
  )
57
53
 
58
54
  @property
59
55
  def is_in_crystal_units(self):
60
- return self.units == AtomicCoordinateUnits.crystal
56
+ return self.units == UnitsEnum.crystal
61
57
 
62
58
  @property
63
59
  def is_in_cartesian_units(self):
64
- return self.units == AtomicCoordinateUnits.cartesian
60
+ return self.units == UnitsEnum.cartesian
65
61
 
66
62
  def to_cartesian(self):
67
63
  if self.is_in_cartesian_units:
68
64
  return
69
65
  self.coordinates.map_array_in_place(self.cell.convert_point_to_cartesian)
70
- self.units = AtomicCoordinateUnits.cartesian
66
+ self.units = UnitsEnum.cartesian
71
67
 
72
68
  def to_crystal(self):
73
69
  if self.is_in_crystal_units:
74
70
  return
75
71
  self.coordinates.map_array_in_place(self.cell.convert_point_to_crystal)
76
- self.units = AtomicCoordinateUnits.crystal
72
+ self.units = UnitsEnum.crystal
77
73
 
78
74
  def add_atom(
79
75
  self,
@@ -117,19 +113,15 @@ class Basis(RoundNumericValuesMixin, BaseModel):
117
113
  def remove_atom_by_id(self, id: int):
118
114
  self.elements.remove_item(id)
119
115
  self.coordinates.remove_item(id)
120
- if self.labels is not None:
121
- self.labels.remove_item(id)
116
+ self.labels.remove_item(id)
122
117
 
123
- def filter_atoms_by_ids(self, ids: Union[List[int], int]) -> "Basis":
124
- self.elements.filter_by_ids(ids)
125
- self.coordinates.filter_by_ids(ids)
126
- if self.labels is not None:
127
- self.labels.filter_by_ids(ids)
118
+ def filter_atoms_by_ids(self, ids: Union[List[int], int], invert: bool = False) -> "Basis":
119
+ self.elements.filter_by_ids(ids, invert)
120
+ self.coordinates.filter_by_ids(ids, invert)
121
+ self.labels.filter_by_ids(ids, invert)
128
122
  return self
129
123
 
130
124
  def filter_atoms_by_labels(self, labels: Union[List[str], str]) -> "Basis":
131
- if self.labels is None:
132
- return self
133
125
  self.labels.filter_by_values(labels)
134
126
  ids = self.labels.ids
135
127
  self.elements.filter_by_ids(ids)
@@ -0,0 +1,43 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ from mat3ra.code.array_with_ids import RoundedArrayWithIds
5
+ from mat3ra.code.value_with_id import RoundedValueWithId
6
+ from mat3ra.code.vector import RoundedVector3D as RoundedPoint3D
7
+
8
+
9
+ class Coordinate(RoundedValueWithId):
10
+ value: RoundedPoint3D
11
+
12
+ def get_value_along_axis(self, axis: Literal["x", "y", "z"] = "z"):
13
+ return self.value.root[{"x": 0, "y": 1, "z": 2}[axis]]
14
+
15
+
16
+ class Coordinates(RoundedArrayWithIds):
17
+ def get_values_along_axis(
18
+ self,
19
+ axis: Literal["x", "y", "z"] = "z",
20
+ ):
21
+ values_along_axis = [Coordinate(value=coord).get_value_along_axis(axis) for coord in self.values]
22
+ return values_along_axis
23
+
24
+ def get_max_value_along_axis(
25
+ self,
26
+ axis: Literal["x", "y", "z"] = "z",
27
+ ):
28
+ return np.max(self.get_values_along_axis(axis))
29
+
30
+ def get_min_value_along_axis(
31
+ self,
32
+ axis: Literal["x", "y", "z"] = "z",
33
+ ):
34
+ return np.min(self.get_values_along_axis(axis))
35
+
36
+ def get_extremum_value_along_axis(
37
+ self,
38
+ extremum: Literal["max", "min"] = "max",
39
+ axis: Literal["x", "y", "z"] = "z",
40
+ ):
41
+ if extremum == "max":
42
+ return self.get_max_value_along_axis(axis)
43
+ return self.get_min_value_along_axis(axis)
@@ -7,16 +7,23 @@ from pydantic import BaseModel, Field
7
7
 
8
8
  class Cell(RoundNumericValuesMixin, BaseModel):
9
9
  # TODO: figure out how to use ArrayOf3NumberElementsSchema
10
- vector1: List[float] = Field(default_factory=lambda: [1, 0, 0])
11
- vector2: List[float] = Field(default_factory=lambda: [0, 1, 0])
12
- vector3: List[float] = Field(default_factory=lambda: [0, 0, 1])
10
+ vector1: List[float] = Field(default_factory=lambda: [1.0, 0.0, 0.0])
11
+ vector2: List[float] = Field(default_factory=lambda: [0.0, 1.0, 0.0])
12
+ vector3: List[float] = Field(default_factory=lambda: [0.0, 0.0, 1.0])
13
13
  __round_precision__ = 6
14
14
 
15
15
  @classmethod
16
16
  def from_vectors_array(cls, vectors_array: Optional[List[List[float]]] = None) -> "Cell":
17
17
  if vectors_array is None:
18
- vectors_array = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
19
- return cls(vector1=vectors_array[0], vector2=vectors_array[1], vector3=vectors_array[2])
18
+ vectors_array = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
19
+
20
+ # Ensure vectors are properly converted to lists of floats
21
+ processed_vectors = []
22
+ for vector in vectors_array:
23
+ processed_vector = [float(v) for v in vector]
24
+ processed_vectors.append(processed_vector)
25
+
26
+ return cls(vector1=processed_vectors[0], vector2=processed_vectors[1], vector3=processed_vectors[2])
20
27
 
21
28
  @property
22
29
  def vectors_as_array(self, skip_rounding=False) -> List[List[float]]:
@@ -24,7 +31,7 @@ class Cell(RoundNumericValuesMixin, BaseModel):
24
31
  return [self.vector1, self.vector2, self.vector3]
25
32
  return self.round_array_or_number([self.vector1, self.vector2, self.vector3])
26
33
 
27
- def to_json(self, skip_rounding=False):
34
+ def to_list(self, skip_rounding=False) -> List[List[float]]:
28
35
  _ = self.round_array_or_number
29
36
  return [
30
37
  self.vector1 if skip_rounding else _(self.vector1),
@@ -35,23 +42,17 @@ class Cell(RoundNumericValuesMixin, BaseModel):
35
42
  def clone(self) -> "Cell":
36
43
  return self.from_vectors_array(self.vectors_as_array)
37
44
 
38
- def clone_and_scale_by_matrix(self, matrix: List[List[float]]) -> "Cell":
39
- new_cell = self.clone()
40
- new_cell.scale_by_matrix(matrix)
41
- return new_cell
42
-
43
45
  def convert_point_to_cartesian(self, point: List[float]) -> List[float]:
44
46
  np_vector = np.array(self.vectors_as_array)
45
- return np.dot(point, np_vector)
47
+ result_list = np.dot(point, np_vector).tolist()
48
+ return self.round_array_or_number(result_list)
46
49
 
47
50
  def convert_point_to_crystal(self, point: List[float]) -> List[float]:
48
51
  np_vector = np.array(self.vectors_as_array)
49
- return np.dot(point, np.linalg.inv(np_vector))
50
-
51
- def scale_by_matrix(self, matrix: List[List[float]]):
52
- np_vector = np.array(self.vectors_as_array)
53
- self.vector1, self.vector2, self.vector3 = np.dot(np.array(matrix), np_vector).tolist()
52
+ result_list = np.dot(point, np.linalg.inv(np_vector)).tolist()
53
+ return self.round_array_or_number(result_list)
54
54
 
55
55
  @property
56
56
  def volume(self) -> float:
57
- return np.linalg.det(np.array(self.vectors_as_array))
57
+ volume = np.linalg.det(np.array(self.vectors_as_array))
58
+ return self.round_array_or_number(volume)
@@ -0,0 +1,73 @@
1
+ import os
2
+ import tempfile
3
+ import time
4
+ import webbrowser
5
+
6
+ # Use a default div id
7
+ default_div_id = "wave-div"
8
+
9
+
10
+ def get_wave_html(div_id=default_div_id, width=600, height=600, title="Material"):
11
+ size = min(width, height) # Make it square using the smaller dimension
12
+ return f"""
13
+ <h2>{title}</h2>
14
+ <div id="{div_id}" style="width:{size}px; height:{size}px; border:1px solid #333;"></div>
15
+ """
16
+
17
+
18
+ def get_wave_js(material_json, div_id=default_div_id):
19
+ return (
20
+ f"""
21
+ const materialConfig = {material_json};
22
+ const container = document.getElementById('{div_id}');
23
+ """
24
+ + """
25
+ (async function() {
26
+ const module = await import('https://exabyte-io.github.io/wave.js/main.js');
27
+ window.renderThreeDEditor(materialConfig, container);
28
+ })();
29
+ document.head.insertAdjacentHTML(
30
+ 'beforeend',
31
+ '<link rel="stylesheet" href="https://exabyte-io.github.io/wave.js/main.css"/>'
32
+ );
33
+ """
34
+ )
35
+
36
+
37
+ def debug_visualize_material(material, width=600, height=600, title="Material"):
38
+ """
39
+ Generates a temporary HTML file that uses Wave.js to visualize the material,
40
+ and opens it in the default browser.
41
+
42
+ Call this function from the PyCharm debugger (e.g., via Evaluate Expression).
43
+ """
44
+ # Convert your material to JSON.
45
+ # (Assuming material.to_json() returns a JSON-serializable object)
46
+ material_json = material.to_json()
47
+
48
+ # Generate a unique div id so multiple calls don't conflict
49
+ div_id = f"wave-{int(time.time())}"
50
+
51
+ # Create HTML content that includes our working code
52
+ html_content = f"""
53
+ <!DOCTYPE html>
54
+ <html>
55
+ <head>
56
+ <meta charset="UTF-8">
57
+ <title>Wave.js Debug Viewer</title>
58
+ </head>
59
+ <body>
60
+ {get_wave_html(div_id, width, height, title)}
61
+ <script type="module">
62
+ {get_wave_js(material_json, div_id)}
63
+ </script>
64
+ </body>
65
+ </html>
66
+ """
67
+
68
+ # Write the HTML to a temporary file and open it in the default browser
69
+ fd, file_path = tempfile.mkstemp(suffix=".html", prefix="wave_debug_")
70
+ with os.fdopen(fd, "w", encoding="utf-8") as f:
71
+ f.write(html_content)
72
+
73
+ webbrowser.open("file://" + file_path)
@@ -1,39 +1,63 @@
1
1
  import math
2
- from typing import Any, Dict, List, Optional
2
+ from typing import List, Optional
3
3
 
4
4
  import numpy as np
5
+ from mat3ra.code.entity import InMemoryEntityPydantic
6
+ from mat3ra.code.vector import RoundedVector3D
7
+ from mat3ra.esse.models.properties_directory.structural.lattice.lattice_bravais import (
8
+ LatticeImplicitSchema as LatticeBravaisSchema,
9
+ )
10
+ from mat3ra.esse.models.properties_directory.structural.lattice.lattice_bravais import (
11
+ LatticeTypeEnum,
12
+ LatticeUnitsSchema,
13
+ )
14
+ from mat3ra.esse.models.properties_directory.structural.lattice.lattice_vectors import (
15
+ LatticeExplicitUnit as LatticeVectorsSchema,
16
+ )
5
17
  from mat3ra.utils.mixins import RoundNumericValuesMixin
6
- from pydantic import BaseModel
18
+ from pydantic import Field
7
19
 
8
20
  from .cell import Cell
9
21
 
10
- HASH_TOLERANCE = 3
11
- DEFAULT_UNITS = {"length": "angstrom", "angle": "degree"}
12
- DEFAULT_TYPE = "TRI"
22
+ COORDINATE_TOLERANCE = 6
13
23
 
14
24
 
15
- class LatticeVectors(BaseModel):
25
+ class LatticeVector(RoundedVector3D):
26
+ pass
27
+
28
+
29
+ class LatticeVectors(RoundNumericValuesMixin, LatticeVectorsSchema):
16
30
  """
17
31
  A class to represent the lattice vectors.
18
32
  """
19
33
 
20
- a: List[float] = [1.0, 0.0, 0.0]
21
- b: List[float] = [0.0, 1.0, 0.0]
22
- c: List[float] = [0.0, 0.0, 1.0]
34
+ a: LatticeVector = Field(default_factory=lambda: LatticeVector(root=[1.0, 0.0, 0.0]))
35
+ b: LatticeVector = Field(default_factory=lambda: LatticeVector(root=[0.0, 1.0, 0.0]))
36
+ c: LatticeVector = Field(default_factory=lambda: LatticeVector(root=[0.0, 0.0, 1.0]))
37
+
38
+ @classmethod
39
+ def from_vectors_array(cls, vectors: List[List[float]]) -> "LatticeVectors":
40
+ return cls(a=LatticeVector(root=vectors[0]), b=LatticeVector(root=vectors[1]), c=LatticeVector(root=vectors[2]))
41
+
23
42
 
43
+ class Lattice(RoundNumericValuesMixin, LatticeBravaisSchema, InMemoryEntityPydantic):
44
+ __types__ = LatticeTypeEnum
45
+ __type_default__ = LatticeBravaisSchema.model_fields["type"].default
46
+ __units_default__ = LatticeBravaisSchema.model_fields["units"].default_factory()
24
47
 
25
- class Lattice(RoundNumericValuesMixin, BaseModel):
26
48
  a: float = 1.0
27
49
  b: float = a
28
50
  c: float = a
29
51
  alpha: float = 90.0
30
52
  beta: float = 90.0
31
53
  gamma: float = 90.0
32
- units: Dict[str, str] = DEFAULT_UNITS
33
- type: str = DEFAULT_TYPE
34
54
 
35
55
  @property
36
56
  def vectors(self) -> LatticeVectors:
57
+ vectors = self.calculate_vectors()
58
+ return LatticeVectors.from_vectors_array(vectors)
59
+
60
+ def calculate_vectors(self):
37
61
  a = self.a
38
62
  b = self.b
39
63
  c = self.c
@@ -59,61 +83,56 @@ class Lattice(RoundNumericValuesMixin, BaseModel):
59
83
  vector_b = [-b * sin_alpha * cos_gamma_star, b * sin_alpha * sin_gamma_star, b * cos_alpha]
60
84
  vector_c = [0.0, 0.0, c]
61
85
 
62
- return LatticeVectors(a=vector_a, b=vector_b, c=vector_c)
86
+ return [vector_a, vector_b, vector_c]
63
87
 
64
88
  @classmethod
65
89
  def from_vectors_array(
66
- cls, vectors: List[List[float]], units: Optional[Dict[str, str]] = None, type: Optional[str] = None
90
+ cls,
91
+ vectors: List[List[float]],
92
+ units: Optional[LatticeUnitsSchema] = __units_default__,
93
+ type: Optional[LatticeTypeEnum] = __type_default__,
67
94
  ) -> "Lattice":
68
- """
69
- Create a Lattice object from a nested array of vectors.
70
- Args:
71
- vectors (List[List[float]]): A nested array of vectors.
72
- Returns:
73
- Lattice: A Lattice object.
74
- """
75
95
  a = np.linalg.norm(vectors[0])
76
96
  b = np.linalg.norm(vectors[1])
77
97
  c = np.linalg.norm(vectors[2])
78
98
  alpha = np.degrees(np.arccos(np.dot(vectors[1], vectors[2]) / (b * c)))
79
99
  beta = np.degrees(np.arccos(np.dot(vectors[0], vectors[2]) / (a * c)))
80
100
  gamma = np.degrees(np.arccos(np.dot(vectors[0], vectors[1]) / (a * b)))
81
- if units is None:
82
- units = DEFAULT_UNITS
83
- if type is None:
84
- type = DEFAULT_TYPE
85
- return cls(a=float(a), b=float(b), c=float(c), alpha=alpha, beta=beta, gamma=gamma, units=units, type=type)
86
-
87
- def to_json(self, skip_rounding: bool = False) -> Dict[str, Any]:
88
- __round__ = RoundNumericValuesMixin.round_array_or_number
89
- round_func = __round__ if not skip_rounding else lambda x: x
90
- return {
91
- "a": round_func(self.a),
92
- "b": round_func(self.b),
93
- "c": round_func(self.c),
94
- "alpha": round_func(self.alpha),
95
- "beta": round_func(self.beta),
96
- "gamma": round_func(self.gamma),
97
- "units": self.units,
98
- "type": self.type,
99
- "vectors": self.vectors,
100
- }
101
-
102
- def clone(self, extra_context: Optional[Dict[str, Any]] = None) -> "Lattice":
103
- if extra_context is None:
104
- extra_context = {}
105
- return Lattice(**{**self.to_json(), **extra_context})
101
+
102
+ return cls(
103
+ a=float(a),
104
+ b=float(b),
105
+ c=float(c),
106
+ alpha=alpha,
107
+ beta=beta,
108
+ gamma=gamma,
109
+ units=units,
110
+ type=type,
111
+ )
106
112
 
107
113
  @property
108
- def vector_arrays(self) -> List[List[float]]:
109
- """Returns lattice vectors as a nested array"""
110
- v = self.vectors
111
- return [v.a, v.b, v.c]
114
+ def vector_arrays(self, skip_rounding=False) -> List[List[float]]:
115
+ _ = [self.vectors.a, self.vectors.b, self.vectors.c]
116
+ if not skip_rounding:
117
+ return list(map(lambda vector: vector.value_rounded, _))
118
+ return list(map(lambda vector: vector.root, _))
112
119
 
113
120
  @property
114
121
  def cell(self) -> Cell:
115
122
  return Cell.from_vectors_array(self.vector_arrays)
116
123
 
117
- def volume(self) -> float:
118
- np_vector = np.array(self.vector_arrays)
119
- return abs(np.linalg.det(np_vector))
124
+ @property
125
+ def cell_volume(self) -> float:
126
+ return self.cell.volume
127
+
128
+ def get_scaled_by_matrix(self, matrix: List[List[float]]):
129
+ """
130
+ Scale the lattice by a matrix.
131
+ Args:
132
+ matrix (List[List[float]]): A 3x3 matrix.
133
+ """
134
+ np_vectors = np.array(self.vector_arrays)
135
+ np_matrix = np.array(matrix)
136
+ scaled_vectors = np.dot(np_matrix, np_vectors).tolist()
137
+ new_lattice = self.from_vectors_array(vectors=scaled_vectors, units=self.units, type=self.type)
138
+ return new_lattice