qupled 1.3.3__cp310-cp310-macosx_14_0_arm64.whl → 1.3.4__cp310-cp310-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
qupled/mpi.py CHANGED
@@ -1,69 +1,104 @@
1
- import functools
2
-
3
- from qupled import native
4
-
5
-
6
- class MPI:
7
- """Class to handle the calls to the MPI API"""
8
-
9
- def rank(self):
10
- """Get rank of the process"""
11
- return native.MPI.rank()
12
-
13
- def is_root(self):
14
- """Check if the current process is root (rank 0)"""
15
- return native.MPI.is_root()
16
-
17
- def barrier(self):
18
- """Setup an MPI barrier"""
19
- native.MPI.barrier()
20
-
21
- def timer(self):
22
- """Get wall time"""
23
- return native.MPI.timer()
24
-
25
- @staticmethod
26
- def run_only_on_root(func):
27
- """Python decorator for all methods that have to be run only by root"""
28
-
29
- @functools.wraps(func)
30
- def wrapper(*args, **kwargs):
31
- if native.MPI.is_root():
32
- return func(*args, **kwargs)
33
-
34
- return wrapper
35
-
36
- @staticmethod
37
- def synchronize_ranks(func):
38
- """Python decorator for all methods that need rank synchronization"""
39
-
40
- @functools.wraps(func)
41
- def wrapper(*args, **kwargs):
42
- func(*args, **kwargs)
43
- native.MPI.barrier()
44
-
45
- return wrapper
46
-
47
- @staticmethod
48
- def record_time(func):
49
- """Python decorator for all methods that have to be timed"""
50
-
51
- @functools.wraps(func)
52
- def wrapper(*args, **kwargs):
53
- mpi = native.MPI
54
- tic = mpi.timer()
55
- func(*args, **kwargs)
56
- toc = mpi.timer()
57
- dt = toc - tic
58
- hours = dt // 3600
59
- minutes = (dt % 3600) // 60
60
- seconds = dt % 60
61
- if mpi.is_root():
62
- if hours > 0:
63
- print("Elapsed time: %d h, %d m, %d s." % (hours, minutes, seconds))
64
- elif minutes > 0:
65
- print("Elapsed time: %d m, %d s." % (minutes, seconds))
66
- else:
67
- print("Elapsed time: %.1f s." % seconds)
68
-
69
- return wrapper
1
+ import json
2
+ import subprocess
3
+ import shutil
4
+
5
+ from pathlib import Path
6
+ from . import native
7
+
8
+ # MPI command
9
+ MPI_COMMAND = "mpiexec"
10
+
11
+ # Temporary files used for MPI executions
12
+ INPUT_FILE = Path("input.json")
13
+ RESULT_FILE = Path("results.json")
14
+ STATUS_FILE = Path("status.json")
15
+
16
+
17
+ def launch_mpi_execution(module, nproc):
18
+ """
19
+ Launches the execution of a Python module using MPI if available, otherwise defaults to serial execution.
20
+
21
+ Args:
22
+ module (str): The name of the Python module to execute (as used with the '-m' flag).
23
+ nproc (int): The number of processes to use for MPI execution.
24
+
25
+ Behavior:
26
+ - Checks if the MPI command is available and if native MPI usage is enabled.
27
+ - If MPI is available, runs the module with the specified number of processes using MPI.
28
+ - If MPI is not available, prints a warning and runs the module in serial mode.
29
+
30
+ Raises:
31
+ subprocess.CalledProcessError: If the subprocess execution fails.
32
+ """
33
+ call_mpi = shutil.which(MPI_COMMAND) is not None and native.uses_mpi
34
+ if call_mpi:
35
+ subprocess.run(
36
+ [MPI_COMMAND, "-n", str(nproc), "python", "-m", module], check=True
37
+ )
38
+ else:
39
+ print("WARNING: Could not call MPI, defaulting to serial execution.")
40
+ subprocess.run(["python", "-m", module], check=True)
41
+
42
+
43
+ def write_inputs(inputs):
44
+ """
45
+ Writes the input data to the INPUT_FILE in JSON format.
46
+ """
47
+ with INPUT_FILE.open("w") as f:
48
+ json.dump(inputs.to_dict(), f)
49
+
50
+
51
+ def read_inputs(InputCls):
52
+ """
53
+ Reads input data from a predefined input file and constructs an instance of the specified input class.
54
+ """
55
+ with INPUT_FILE.open() as f:
56
+ input_dict = json.load(f)
57
+ return InputCls.from_dict(input_dict)
58
+
59
+
60
+ def write_results(scheme, ResultCls):
61
+ """
62
+ Writes the results of a computation to a JSON file if the current process is the root.
63
+ """
64
+ if scheme.is_root:
65
+ results = ResultCls()
66
+ results.from_native(scheme)
67
+ with RESULT_FILE.open("w") as f:
68
+ json.dump(results.to_dict(), f)
69
+
70
+
71
+ def read_results(ResultsCls):
72
+ """
73
+ Reads results from a JSON file and returns an instance of the specified ResultsCls.
74
+ """
75
+ with RESULT_FILE.open() as f:
76
+ result_dict = json.load(f)
77
+ return ResultsCls.from_dict(result_dict)
78
+
79
+
80
+ def write_status(scheme, status):
81
+ """
82
+ Writes the status of a computation to a JSON file if the current process is the root.
83
+ """
84
+ if scheme.is_root:
85
+ with STATUS_FILE.open("w") as f:
86
+ json.dump(status, f)
87
+
88
+
89
+ def read_status():
90
+ """
91
+ Reads status from a JSON file and returns an instance of the specified ResultsCls.
92
+ """
93
+ with STATUS_FILE.open() as f:
94
+ status = json.load(f)
95
+ return status
96
+
97
+
98
+ def clean_files():
99
+ """
100
+ Removes the input and result files if they exist.
101
+ """
102
+ for file in [INPUT_FILE, RESULT_FILE, STATUS_FILE]:
103
+ if file.exists():
104
+ file.unlink()
Binary file
qupled/qstls.py CHANGED
@@ -2,20 +2,22 @@ from __future__ import annotations
2
2
 
3
3
  from . import database
4
4
  from . import native
5
+ from . import serialize
5
6
  from . import stls
6
7
 
7
8
 
8
- class Qstls(stls.Stls):
9
+ class Solver(stls.Solver):
9
10
  """
10
11
  Class used to solve the Qstls scheme.
11
12
  """
12
13
 
14
+ # Native classes used to solve the scheme
15
+ native_scheme_cls = native.Qstls
16
+ native_inputs_cls = native.QstlsInput
17
+
13
18
  def __init__(self):
14
19
  super().__init__()
15
20
  self.results: stls.Result = stls.Result()
16
- # Undocumented properties
17
- self.native_scheme_cls = native.Qstls
18
- self.native_inputs_cls = native.QstlsInput
19
21
 
20
22
  def compute(self, inputs: Input):
21
23
  self.find_fixed_adr_in_database(inputs)
@@ -55,14 +57,15 @@ class Qstls(stls.Stls):
55
57
  return
56
58
 
57
59
 
58
- # Input class
60
+ @serialize.serializable_dataclass
59
61
  class Input(stls.Input):
60
62
  """
61
63
  Class used to manage the input for the :obj:`qupled.qstls.Qstls` class.
62
64
  """
63
65
 
64
- def __init__(self, coupling: float, degeneracy: float):
65
- super().__init__(coupling, degeneracy)
66
- # Undocumented default values
67
- self.fixed_run_id: int | None = None
68
- self.theory = "QSTLS"
66
+ fixed_run_id: int | None = None
67
+ theory: str = "QSTLS"
68
+
69
+
70
+ if __name__ == "__main__":
71
+ Solver.run_mpi_worker(Input, stls.Result)
qupled/qstlsiet.py CHANGED
@@ -2,36 +2,39 @@ from __future__ import annotations
2
2
 
3
3
  from . import native
4
4
  from . import qstls
5
+ from . import serialize
5
6
  from . import stlsiet
6
7
 
7
8
 
8
- class QstlsIet(qstls.Qstls):
9
+ class Solver(qstls.Solver):
9
10
  """
10
11
  Class used to solve the Qstls-IET schemes.
11
12
  """
12
13
 
14
+ # Native classes used to solve the scheme
15
+ native_scheme_cls = native.QstlsIet
16
+ native_inputs_cls = native.QstlsIetInput
17
+
13
18
  def __init__(self):
14
19
  super().__init__()
15
20
  self.results: stlsiet.Result = stlsiet.Result()
16
- self.native_scheme_cls = native.QstlsIet
17
- self.native_inputs_cls = native.QstlsIetInput
18
21
 
19
22
  @staticmethod
20
23
  def get_initial_guess(
21
24
  run_id: str, database_name: str | None = None
22
25
  ) -> stlsiet.Guess:
23
- return stlsiet.StlsIet.get_initial_guess(run_id, database_name)
26
+ return stlsiet.Solver.get_initial_guess(run_id, database_name)
24
27
 
25
28
 
26
- # Input class
29
+ @serialize.serializable_dataclass
27
30
  class Input(stlsiet.Input, qstls.Input):
28
31
  """
29
32
  Class used to manage the input for the :obj:`qupled.qstlsiet.QStlsIet` class.
30
33
  Accepted theories: ``QSTLS-HNC``, ``QSTLS-IOI`` and ``QSTLS-LCT``.
31
34
  """
32
35
 
33
- def __init__(self, coupling: float, degeneracy: float, theory: str):
34
- super().__init__(coupling, degeneracy, "STLS-HNC")
35
- if theory not in {"QSTLS-HNC", "QSTLS-IOI", "QSTLS-LCT"}:
36
- raise ValueError("Invalid dielectric theory")
37
- self.theory = theory
36
+ allowed_theories = {"QSTLS-HNC", "QSTLS-IOI", "QSTLS-LCT"}
37
+
38
+
39
+ if __name__ == "__main__":
40
+ Solver.run_mpi_worker(Input, stlsiet.Result)
qupled/qvsstls.py CHANGED
@@ -3,19 +3,21 @@ from __future__ import annotations
3
3
  from . import native
4
4
  from . import qstls
5
5
  from . import vsstls
6
+ from . import serialize
6
7
 
7
8
 
8
- class QVSStls(vsstls.VSStls):
9
+ class Solver(vsstls.Solver):
9
10
  """
10
11
  Class used to solve the QVStls scheme.
11
12
  """
12
13
 
14
+ # Native classes used to solve the scheme
15
+ native_scheme_cls = native.QVSStls
16
+ native_inputs_cls = native.QVSStlsInput
17
+
13
18
  def __init__(self):
14
19
  super().__init__()
15
20
  self.results: vsstls.Result = vsstls.Result()
16
- # Undocumented properties
17
- self.native_scheme_cls = native.QVSStls
18
- self.native_inputs_cls = native.QVSStlsInput
19
21
 
20
22
  def compute(self, inputs: Input):
21
23
  """
@@ -24,7 +26,7 @@ class QVSStls(vsstls.VSStls):
24
26
  Args:
25
27
  inputs: Input parameters.
26
28
  """
27
- qstls.Qstls.find_fixed_adr_in_database(self, inputs)
29
+ qstls.Solver.find_fixed_adr_in_database(self, inputs)
28
30
  super().compute(inputs)
29
31
 
30
32
  def _update_input_data(self, inputs: Input):
@@ -46,14 +48,14 @@ class QVSStls(vsstls.VSStls):
46
48
  inputs.fixed_run_id = self.run_id
47
49
 
48
50
 
49
- # Input class
51
+ @serialize.serializable_dataclass
50
52
  class Input(vsstls.Input, qstls.Input):
51
53
  """
52
54
  Class used to manage the input for the :obj:`qupled.qvsstls.QVSStls` class.
53
55
  """
54
56
 
55
- def __init__(self, coupling: float, degeneracy: float):
56
- vsstls.Input.__init__(self, coupling, degeneracy)
57
- qstls.Input.__init__(self, coupling, degeneracy)
58
- # Undocumented default values
59
- self.theory: str = "QVSSTLS"
57
+ theory: str = "QVSSTLS"
58
+
59
+
60
+ if __name__ == "__main__":
61
+ Solver.run_mpi_worker(Input, vsstls.Result)
qupled/rpa.py CHANGED
@@ -2,26 +2,30 @@ from __future__ import annotations
2
2
 
3
3
  from . import hf
4
4
  from . import native
5
+ from . import serialize
5
6
 
6
7
 
7
- class Rpa(hf.HF):
8
+ class Solver(hf.Solver):
8
9
  """
9
10
  Class used to solve the RPA scheme.
10
11
  """
11
12
 
13
+ # Native classes used to solve the scheme
14
+ native_scheme_cls = native.Rpa
15
+
12
16
  def __init__(self):
13
17
  super().__init__()
14
18
  self.results: hf.Result = hf.Result()
15
- # Undocumented properties
16
- self.native_scheme_cls = native.Rpa
17
19
 
18
20
 
21
+ @serialize.serializable_dataclass
19
22
  class Input(hf.Input):
20
23
  """
21
24
  Class used to manage the input for the :obj:`qupled.rpa.Rpa` class.
22
25
  """
23
26
 
24
- def __init__(self, coupling: float, degeneracy: float):
25
- super().__init__(coupling, degeneracy)
26
- # Undocumented default values
27
- self.theory = "RPA"
27
+ theory: str = "RPA"
28
+
29
+
30
+ if __name__ == "__main__":
31
+ Solver.run_mpi_worker(Input, hf.Result)
qupled/serialize.py ADDED
@@ -0,0 +1,43 @@
1
+ import numpy as np
2
+
3
+ from dataclasses import dataclass
4
+ from typing import get_type_hints
5
+
6
+
7
+ def serializable_dataclass(cls):
8
+
9
+ cls = dataclass(cls)
10
+
11
+ def to_dict(self):
12
+ result = {}
13
+ for key, value in self.__dict__.items():
14
+ if hasattr(value, "to_dict") and callable(value.to_dict):
15
+ result[key] = value.to_dict()
16
+ elif isinstance(value, np.ndarray):
17
+ result[key] = value.tolist()
18
+ else:
19
+ result[key] = value
20
+ return result
21
+
22
+ @classmethod
23
+ def from_dict(cls, d):
24
+ obj = cls.__new__(cls)
25
+ annotations = get_type_hints(cls)
26
+ for key, value in d.items():
27
+ expected_type = annotations.get(key)
28
+ from_dict_fn = getattr(expected_type, "from_dict", None)
29
+ convert_to_np_array = expected_type is np.ndarray and isinstance(
30
+ value, list
31
+ )
32
+ call_from_dict = callable(from_dict_fn) and isinstance(value, dict)
33
+ if convert_to_np_array:
34
+ setattr(obj, key, np.array(value))
35
+ elif call_from_dict:
36
+ setattr(obj, key, from_dict_fn(value))
37
+ else:
38
+ setattr(obj, key, value)
39
+ return obj
40
+
41
+ cls.to_dict = to_dict
42
+ cls.from_dict = from_dict
43
+ return cls
qupled/stls.py CHANGED
@@ -1,24 +1,27 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from dataclasses import field
3
4
  import numpy as np
4
5
 
5
6
  from . import hf
6
7
  from . import native
7
8
  from . import output
8
9
  from . import rpa
10
+ from . import serialize
9
11
 
10
12
 
11
- class Stls(rpa.Rpa):
13
+ class Solver(rpa.Solver):
12
14
  """
13
15
  Class used to solve the Stls scheme.
14
16
  """
15
17
 
18
+ # Native classes used to solve the scheme
19
+ native_scheme_cls = native.Stls
20
+ native_inputs_cls = native.StlsInput
21
+
16
22
  def __init__(self):
17
23
  super().__init__()
18
24
  self.results: Result = Result()
19
- # Undocumented properties
20
- self.native_scheme_cls = native.Stls
21
- self.native_inputs_cls = native.StlsInput
22
25
 
23
26
  @staticmethod
24
27
  def get_initial_guess(run_id: int, database_name: str | None = None) -> Guess:
@@ -37,43 +40,39 @@ class Stls(rpa.Rpa):
37
40
  return Guess(data[names[0]], data[names[1]])
38
41
 
39
42
 
43
+ @serialize.serializable_dataclass
40
44
  class Input(rpa.Input):
41
45
  """
42
46
  Class used to manage the input for the :obj:`qupled.stls.Stls` class.
43
47
  """
44
48
 
45
- def __init__(self, coupling: float, degeneracy: float):
46
- super().__init__(coupling, degeneracy)
47
- self.error: float = 1.0e-5
48
- """Minimum error for convergence. Default = ``1.0e-5``"""
49
- self.mixing: float = 1.0
50
- """Mixing parameter. Default = ``1.0``"""
51
- self.iterations: int = 1000
52
- """Maximum number of iterations. Default = ``1000``"""
53
- self.guess: Guess = Guess()
54
- """Initial guess. Default = ``stls.Guess()``"""
55
- # Undocumented default values
56
- self.theory: str = "STLS"
49
+ error: float = 1.0e-5
50
+ """Minimum error for convergence. Default = ``1.0e-5``"""
51
+ mixing: float = 1.0
52
+ """Mixing parameter. Default = ``1.0``"""
53
+ iterations: int = 1000
54
+ """Maximum number of iterations. Default = ``1000``"""
55
+ guess: Guess = field(default_factory=lambda: Guess())
56
+ """Initial guess. Default = ``stls.Guess()``"""
57
+ theory: str = "STLS"
57
58
 
58
59
 
60
+ @serialize.serializable_dataclass
59
61
  class Result(hf.Result):
60
62
  """
61
63
  Class used to store the results for the :obj:`qupled.stls.Stls` class.
62
64
  """
63
65
 
64
- def __init__(self):
65
- super().__init__()
66
- self.error: float = None
67
- """Residual error in the solution"""
66
+ error: float = None
67
+ """Final error of the scheme. Default = ``None``"""
68
68
 
69
69
 
70
+ @serialize.serializable_dataclass
70
71
  class Guess:
71
-
72
- def __init__(self, wvg: np.ndarray = None, ssf: np.ndarray = None):
73
- self.wvg = wvg
74
- """ Wave-vector grid. Default = ``None``"""
75
- self.ssf = ssf
76
- """ Static structure factor. Default = ``None``"""
72
+ wvg: np.ndarray = None
73
+ """Wave-vector grid. Default = ``None``"""
74
+ ssf: np.ndarray = None
75
+ """Static structure factor. Default = ``None``"""
77
76
 
78
77
  def to_native(self) -> native.Guess:
79
78
  """
@@ -92,3 +91,7 @@ class Guess:
92
91
  if value is not None:
93
92
  setattr(native_guess, attr, value)
94
93
  return native_guess
94
+
95
+
96
+ if __name__ == "__main__":
97
+ Solver.run_mpi_worker(Input, Result)
qupled/stlsiet.py CHANGED
@@ -1,22 +1,26 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from dataclasses import field
3
4
  import numpy as np
4
5
 
5
6
  from . import native
6
7
  from . import output
7
8
  from . import stls
9
+ from . import serialize
8
10
 
9
11
 
10
- class StlsIet(stls.Stls):
12
+ class Solver(stls.Solver):
11
13
  """
12
14
  Class used to solve the StlsIet schemes.
13
15
  """
14
16
 
17
+ # Native classes used to solve the scheme
18
+ native_scheme_cls = native.StlsIet
19
+ native_inputs_cls = native.StlsIetInput
20
+
15
21
  def __init__(self):
16
22
  super().__init__()
17
23
  self.results: Result = Result()
18
- self.native_scheme_cls = native.StlsIet
19
- self.native_inputs_cls = native.StlsIetInput
20
24
 
21
25
  @staticmethod
22
26
  def get_initial_guess(run_id: str, database_name: str | None = None) -> Guess:
@@ -42,19 +46,15 @@ class StlsIet(stls.Stls):
42
46
  )
43
47
 
44
48
 
49
+ @serialize.serializable_dataclass
45
50
  class Input(stls.Input):
46
51
  """
47
52
  Class used to manage the input for the :obj:`qupled.stlsiet.StlsIet` class.
48
53
  Accepted theories: ``STLS-HNC``, ``STLS-IOI`` and ``STLS-LCT``.
49
54
  """
50
55
 
51
- def __init__(self, coupling: float, degeneracy: float, theory: str):
52
- super().__init__(coupling, degeneracy)
53
- if theory not in {"STLS-HNC", "STLS-IOI", "STLS-LCT"}:
54
- raise ValueError("Invalid dielectric theory")
55
- self.theory = theory
56
- self.mapping = "standard"
57
- r"""
56
+ mapping: str = "standard"
57
+ r"""
58
58
  Mapping for the classical-to-quantum coupling parameter
59
59
  :math:`\Gamma` used in the iet schemes. Allowed options include:
60
60
 
@@ -69,29 +69,38 @@ class Input(stls.Input):
69
69
  the ground state they can differ significantly (the standard
70
70
  mapping diverges). Default = ``standard``.
71
71
  """
72
- self.guess: Guess = Guess()
73
- """Initial guess. Default = ``stlsiet.Guess()``"""
72
+ guess: Guess = field(default_factory=lambda: Guess())
73
+ allowed_theories = {"STLS-HNC", "STLS-IOI", "STLS-LCT"}
74
+
75
+ def __post_init__(self):
76
+ if self.is_default_theory():
77
+ raise ValueError(
78
+ f"Missing dielectric theory, choose among {self.allowed_theories} "
79
+ )
80
+ if self.theory not in self.allowed_theories:
81
+ raise ValueError(
82
+ f"Invalid dielectric theory {self.theory}, choose among {self.allowed_theories}"
83
+ )
84
+
85
+ def is_default_theory(self) -> bool:
86
+ return self.theory == Input.__dataclass_fields__["theory"].default
74
87
 
75
88
 
89
+ @serialize.serializable_dataclass
76
90
  class Result(stls.Result):
77
91
  """
78
92
  Class used to store the results for the :obj:`qupled.stlsiet.StlsIet` class.
79
93
  """
80
94
 
81
- def __init__(self):
82
- super().__init__()
83
- self.bf: np.ndarray = None
84
- """Bridge function adder"""
95
+ bf: np.ndarray = None
96
+ """Bridge function adder"""
85
97
 
86
98
 
99
+ @serialize.serializable_dataclass
87
100
  class Guess(stls.Guess):
101
+ lfc: np.ndarray = None
102
+ """ Local field correction. Default = ``None``"""
103
+
88
104
 
89
- def __init__(
90
- self,
91
- wvg: np.ndarray = None,
92
- ssf: np.ndarray = None,
93
- lfc: np.ndarray = None,
94
- ):
95
- super().__init__(wvg, ssf)
96
- self.lfc = lfc
97
- """ Local field correction. Default = ``None``"""
105
+ if __name__ == "__main__":
106
+ Solver.run_mpi_worker(Input, Result)
qupled/timer.py ADDED
@@ -0,0 +1,33 @@
1
+ import time
2
+
3
+
4
+ def timer(func):
5
+ """
6
+ A decorator that measures and prints the execution time of the decorated function.
7
+
8
+ The elapsed time is displayed in hours, minutes, and seconds as appropriate.
9
+
10
+ Args:
11
+ func (callable): The function whose execution time is to be measured.
12
+
13
+ Returns:
14
+ callable: A wrapper function that executes the original function and prints the elapsed time.
15
+ """
16
+
17
+ def wrapper(*args, **kwargs):
18
+ tic = time.perf_counter()
19
+ result = func(*args, **kwargs)
20
+ toc = time.perf_counter()
21
+ dt = toc - tic
22
+ hours = int(dt // 3600)
23
+ minutes = int((dt % 3600) // 60)
24
+ seconds = dt % 60
25
+ if hours > 0:
26
+ print(f"Elapsed time: {hours} h, {minutes} m, {seconds:.1f} s.")
27
+ elif minutes > 0:
28
+ print(f"Elapsed time: {minutes} m, {seconds:.1f} s.")
29
+ else:
30
+ print(f"Elapsed time: {seconds:.1f} s.")
31
+ return result
32
+
33
+ return wrapper