llg3d 2.0.1__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. llg3d/__init__.py +2 -4
  2. llg3d/benchmarks/__init__.py +1 -0
  3. llg3d/benchmarks/compare_commits.py +321 -0
  4. llg3d/benchmarks/efficiency.py +451 -0
  5. llg3d/benchmarks/utils.py +25 -0
  6. llg3d/element.py +98 -17
  7. llg3d/grid.py +48 -58
  8. llg3d/io.py +395 -0
  9. llg3d/main.py +32 -35
  10. llg3d/parameters.py +159 -49
  11. llg3d/post/__init__.py +1 -1
  12. llg3d/post/extract.py +112 -0
  13. llg3d/post/info.py +192 -0
  14. llg3d/post/m1_vs_T.py +107 -0
  15. llg3d/post/m1_vs_time.py +81 -0
  16. llg3d/post/process.py +87 -85
  17. llg3d/post/utils.py +38 -0
  18. llg3d/post/x_profiles.py +161 -0
  19. llg3d/py.typed +1 -0
  20. llg3d/solvers/__init__.py +153 -0
  21. llg3d/solvers/base.py +345 -0
  22. llg3d/solvers/experimental/__init__.py +9 -0
  23. llg3d/{solver → solvers/experimental}/jax.py +117 -143
  24. llg3d/solvers/math_utils.py +41 -0
  25. llg3d/solvers/mpi.py +370 -0
  26. llg3d/solvers/numpy.py +126 -0
  27. llg3d/solvers/opencl.py +439 -0
  28. llg3d/solvers/profiling.py +38 -0
  29. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/METADATA +5 -2
  30. llg3d-3.1.0.dist-info/RECORD +36 -0
  31. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/WHEEL +1 -1
  32. llg3d-3.1.0.dist-info/entry_points.txt +9 -0
  33. llg3d/output.py +0 -107
  34. llg3d/post/plot_results.py +0 -61
  35. llg3d/post/temperature.py +0 -76
  36. llg3d/simulation.py +0 -95
  37. llg3d/solver/__init__.py +0 -45
  38. llg3d/solver/mpi.py +0 -450
  39. llg3d/solver/numpy.py +0 -207
  40. llg3d/solver/opencl.py +0 -330
  41. llg3d/solver/solver.py +0 -89
  42. llg3d-2.0.1.dist-info/RECORD +0 -25
  43. llg3d-2.0.1.dist-info/entry_points.txt +0 -4
  44. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/licenses/AUTHORS +0 -0
  45. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/licenses/LICENSE +0 -0
  46. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,81 @@
1
+ """
2
+ Plot m1 vs time from one or more result files.
3
+
4
+ Use the ``llg3d.m1_vs_time`` command line tool to plot the average magnetization versus time from multiple result files:
5
+
6
+ .. command-output:: llg3d.m1_vs_time -h
7
+ :cwd: ../execute/temperatures
8
+
9
+ When calling the tool on a single result file:
10
+
11
+ .. command-output:: llg3d.m1_vs_time run_1100K.npz -i m1_vs_time.png
12
+ :cwd: ../execute/temperatures
13
+ :shell:
14
+
15
+ .. image:: ../execute/temperatures/m1_vs_time.png
16
+ :alt: Magnetization versus time for a single file
17
+
18
+ Now when calling the tool on a selection of result files:
19
+
20
+ .. command-output:: llg3d.m1_vs_time run_*.npz -i m1_vs_time_multiple.png
21
+ :cwd: ../execute/temperatures
22
+ :shell:
23
+
24
+ .. image:: ../execute/temperatures/m1_vs_time_multiple.png
25
+ :alt: Magnetization versus time for multiple files
26
+ """
27
+
28
+ from pathlib import Path
29
+
30
+ import numpy as np
31
+ from matplotlib import pyplot as plt
32
+
33
+ from ..io import RunResults, load_results
34
+ from .utils import get_cli_args
35
+
36
+
37
def plot_m1_vs_time(
    *files: Path, image_filepath: Path | str | None = None, show: bool = False
):
    """
    Draw the space-averaged m1 against time, one curve per result file.

    Args:
        *files: Paths to the result files.
        image_filepath: Where to write the figure.
            if None, the image will not be saved.
        show: Whether to open the figure in a graphical window.
    """
    fig, ax = plt.subplots()
    for path in files:
        print(f"Processing file: {path}")
        run: RunResults = load_results(path)
        record: np.ndarray = np.array(run.get_record("xyz_average"))
        # Column 0 is plotted as "time", column 1 as <m_1> (see the axis labels)
        times, m1_avg = record[:, 0], record[:, 1]
        ax.plot(times, m1_avg, label=path)

    ax.set(
        xlabel="time",
        ylabel=r"$<m_1>$",
        title=r"Space average of $m_1$ according to time",
    )
    ax.grid()
    ax.legend()

    if show:
        plt.show()

    if image_filepath is None:
        return
    fig.savefig(image_filepath, dpi=300)
    print(f"Written to {image_filepath}")
69
+
70
def main():  # pragma: no cover
    """Read the command line, then plot m1 against time for the given files."""
    cli = get_cli_args(
        description="Plot m1 vs time from one or more result files.",
        default_image_filepath=Path("m1_vs_time.png"),
    )
    options = {"image_filepath": cli.image_filepath, "show": cli.show}
    plot_m1_vs_time(*cli.files, **options)
llg3d/post/process.py CHANGED
@@ -1,110 +1,112 @@
1
- #!/usr/bin/env python3
2
1
  """
3
2
  Post-processes a set of runs.
4
3
 
5
- Runs are grouped into a `run.json` file or into a set of SLURM job arrays:
6
-
7
4
  1. Extracts result data,
8
5
  2. Plots the computed average magnetization against temperature,
9
- 3. Interpolates the computed points using cubic splines,
10
- 4. Determines the Curie temperature as the value corresponding to the minimal (negative)
11
- slope of the interpolated curve.
6
+ 3. Interpolates the computed points using a PCHIP interpolator,
7
+ 4. Determines the Curie temperature as the value below which the magnetization
8
+ drops under 0.1.
12
9
  """
13
10
 
14
- import json
15
11
  from pathlib import Path
16
12
 
17
- import numpy as np
18
- from scipy.interpolate import interp1d
19
-
20
-
21
- class MagData:
22
- """Class to handle magnetization data and interpolation according to temperature."""
23
-
24
- n_interp = 200
25
13
 
26
- def __init__(self, job_dir: Path = None, run_file: Path = Path("run.json")) -> None:
27
- if job_dir:
28
- self.parentpath = job_dir
29
- data, self.run = self.process_slurm_jobs()
30
- elif run_file:
31
- self.parentpath = run_file.parent
32
- data, self.run = self.process_json(run_file)
33
-
34
- self.temperature = data[:, 0]
35
- self.m1_mean = data[:, 1]
36
- self.interp = interp1d(self.temperature, self.m1_mean, kind="cubic")
37
- self.T = np.linspace(
14
+ import numpy as np
15
+ from scipy.interpolate import PchipInterpolator
16
+ from scipy.optimize import brentq
17
+
18
+ from llg3d.io import load_results
19
+
20
+
21
+ class MagTempData:
22
+ """
23
+ Handle magnetization data using the npz format.
24
+
25
+ - Extracts result data,
26
+ - Interpolates the computed points using a PCHIP interpolator,
27
+ - Determines the Curie temperature as the value below which the magnetization
28
+ drops under 0.1.
29
+
30
+ Args:
31
+ *files: Paths to the result .npz files
32
+ """
33
+
34
+ n_interp = 200 #: number of interpolation points
35
+
36
+ def __init__(self, *files: Path | str) -> None:
37
+ #: list of result files
38
+ self.files: list[Path] = [Path(file) for file in files]
39
+ self.params: dict = {} #: common parameters of the runs
40
+ # Extract data from the runs
41
+ data = self._process_jobs()
42
+ self.temperature = data[:, 0] #: temperatures from the runs
43
+ self.m1_mean = data[:, 1] #: mean magnetization
44
+ # Use PCHIP interpolator which preserves positivity and monotonicity
45
+ #: interpolated magnetization function
46
+ self.interp = PchipInterpolator(self.temperature, self.m1_mean)
47
+ self.temperature_interp = np.linspace(
38
48
  self.temperature.min(), self.temperature.max(), self.n_interp
39
- )
49
+ ) #: finer temperature grid for interpolation
40
50
 
41
- def process_slurm_jobs(self) -> tuple[np.array, dict]:
51
+ @property
52
+ def T_Curie(self) -> float:
42
53
  """
43
- Iterates through calculation directories to assemble data.
54
+ Return the Curie temperature.
44
55
 
45
- Args:
46
- parentdir (str): path to the directory containing the runs
56
+ It is defined as the temperature at which the magnetization equals 0.1,
57
+ found using a root-finding algorithm for precision.
47
58
 
48
59
  Returns:
49
- tuple: (data, run) where data is a numpy array (T, <m>) and run
50
- is a descriptive dictionary of the run
51
- """
52
- json_filename = "run.json"
53
-
54
- # List of run directories
55
- jobdirs = [f for f in self.parentpath.iterdir() if f.is_dir()]
56
- if len(jobdirs) == 0:
57
- exit(f"No job directories found in {self.parentpath}")
58
- data = []
59
- # Iterating through run directories
60
- for jobdir in jobdirs:
61
- try:
62
- # Reading the JSON file
63
- with open(jobdir / json_filename) as f:
64
- run = json.load(f)
65
- # Adding temperature and averaging value to the data list
66
- data.extend(
67
- [
68
- [float(T), res["m1_mean"]]
69
- for T, res in run["results"].items()
70
- ]
71
- )
72
- except FileNotFoundError:
73
- print(f"Warning: {json_filename} file not found in {jobdir.as_posix()}")
74
-
75
- data.sort() # Sorting by increasing temperatures
76
-
77
- return np.array(data), run
60
+ float: Curie temperature
78
61
 
79
- def process_json(json_filepath: Path) -> tuple[np.array, dict]:
62
+ Raises:
63
+ ValueError: If the magnetization never crosses 0.1 in the dataset
80
64
  """
81
- Reads the run.json file and extracts result data.
82
-
83
- Args:
84
- json_filepath: path to the run.json file
65
+ # Check if magnetization ever crosses 0.1
66
+ T_min = self.temperature.min()
67
+ T_max = self.temperature.max()
68
+ m1_min = self.interp(T_min)
69
+ m1_max = self.interp(T_max)
70
+
71
+ # If 0.1 is never reached
72
+ if (m1_min < 0.1 and m1_max < 0.1) or (m1_min > 0.1 and m1_max > 0.1):
73
+ raise ValueError(
74
+ f"Magnetization never crosses 0.1 in the dataset. "
75
+ f"Range: [{m1_min:.4f}, {m1_max:.4f}]"
76
+ )
77
+
78
+ # Find the exact temperature where m = 0.1 using Brent's method
79
+ T_curie = brentq(lambda T: float(self.interp(T)) - 0.1, T_min, T_max)
80
+ return float(T_curie)
81
+
82
+ def _process_jobs(self) -> np.ndarray:
83
+ """
84
+ Iterates through calculation directories to assemble data.
85
85
 
86
86
  Returns:
87
- tuple: (data, run) where data is a numpy array (T, <m>) and run
88
- is a descriptive dictionary of the run
87
+ data a numpy array (T, <m>)
88
+
89
+ Raises:
90
+ ValueError: If any file does not end with .npz
89
91
  """
90
- with open(json_filepath) as f:
91
- run = json.load(f)
92
+ for file in self.files:
93
+ if not file.name.endswith(".npz"):
94
+ raise ValueError(f"File {file} should end with .npz")
95
+ # Get parameters from the first run file
92
96
 
93
- data = [[int(T), res["m1_mean"]] for T, res in run["results"].items()]
97
+ first_results = load_results(self.files[0])
98
+ self.params = first_results.params.as_dict() # Store common parameters
99
+ data = []
100
+ # Iterating through run directories
101
+ for file in self.files:
102
+ print(f"Processing file: {file}")
103
+ run_results = load_results(file)
104
+ params = run_results.params.as_dict()
105
+ m1_mean = np.nan
106
+ if "observables" in run_results.results:
107
+ m1_mean = run_results.results["observables"].get("m1_mean", np.nan)
108
+ data.append([params["T"], m1_mean])
94
109
 
95
110
  data.sort() # Sorting by increasing temperatures
96
111
 
97
- return np.array(data), run
98
-
99
- @property
100
- def T_Curie(self) -> float:
101
- """
102
- Return the Curie temperature.
103
-
104
- It is defined as the temperature at which the magnetization is below 0.1.
105
-
106
- Returns:
107
- float: Curie temperature
108
- """
109
- i_max = np.where(0.1 - self.interp(self.T) > 0)[0].min()
110
- return self.T[i_max]
112
+ return np.array(data)
llg3d/post/utils.py ADDED
@@ -0,0 +1,38 @@
1
+ """Command-line argument parsing for post-processing scripts."""
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+
7
+ def get_cli_args(
8
+ description: str | None, default_image_filepath: Path
9
+ ) -> argparse.Namespace:
10
+ """
11
+ Parse command-line arguments for post-processing scripts.
12
+
13
+ Args:
14
+ description: Description of the script for the help message.
15
+ default_image_filepath: Default path to save the output image.
16
+
17
+ Returns:
18
+ Parsed command-line arguments.
19
+ """
20
+ parser = argparse.ArgumentParser(
21
+ description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
22
+ )
23
+ parser.add_argument("files", nargs="+", type=Path, help="Path to the result files.")
24
+ parser.add_argument(
25
+ "-i",
26
+ "--image_filepath",
27
+ type=Path,
28
+ default=default_image_filepath,
29
+ help="Path to save the image",
30
+ )
31
+ parser.add_argument(
32
+ "-s",
33
+ "--show",
34
+ action="store_true",
35
+ default=False,
36
+ help="Display the plot (omit to disable display in non-interactive runs).",
37
+ )
38
+ return parser.parse_args()
@@ -0,0 +1,161 @@
1
+ """
2
+ Plot x-profiles of magnetization from a result file.
3
+
4
+ Use the ``llg3d.x_profiles`` command line tool to plot the x-profiles of magnetization
5
+ from a result file:
6
+
7
+ .. command-output:: llg3d.x_profiles -h
8
+ :cwd: ../execute/domain_wall
9
+
10
+ See :doc:`here </execute/domain_wall/index>` for how to generate the ``run.npz`` file
11
+ used in this example.
12
+
13
+ When calling the tool on a result file:
14
+
15
+ .. command-output:: llg3d.x_profiles run.npz -m 1 -t ::4 -i x_profiles_4.png
16
+ :cwd: ../execute/domain_wall
17
+ :shell:
18
+
19
+ .. image:: ../execute/domain_wall/x_profiles_4.png
20
+ :alt: Longitudinal profiles of magnetization
21
+ """
22
+
23
+ import argparse
24
+ from pathlib import Path
25
+
26
+ from matplotlib import pyplot as plt
27
+
28
+ from ..grid import Grid
29
+ from ..io import RunResults, load_results
30
+ from .extract import extract_values
31
+
32
+
33
def parse_slice(slice_str: str) -> slice:
    """
    Parse a string representation of a slice into a slice object.

    Handles formats like 'start:stop:step', 'start:stop', ':stop', '::step', etc.
    Empty fields (surrounding whitespace is ignored) map to ``None``. A string
    without any colon, e.g. ``"5"``, gives ``slice(5)``, i.e. the same as ``":5"``.

    Args:
        slice_str: String representation of the slice.

    Returns:
        A slice object corresponding to the input string.

    Raises:
        ValueError: If the string has more than three fields or if a field is not
            an integer.
    """
    parts = [part.strip() for part in slice_str.split(":")]
    if len(parts) > 3:
        # slice() itself would fail with an opaque TypeError here
        raise ValueError(
            f"Invalid slice {slice_str!r}: expected at most 'start:stop:step'"
        )
    # Convert parts to int or None
    args = [int(part) if part else None for part in parts]
    return slice(*args)
49
+
50
+
51
def plot_x_profiles(
    file: Path,
    image_filepath: Path | str | None = None,
    show: bool = False,
    time_slice: str = ":",
    m_index: int = 1,
):
    """
    Plot the x-profiles of one magnetization component from a result file.

    Args:
        file: Path to the result file.
        image_filepath: Path to the output image file.
            if None, the image will not be saved.
        show: display the graph in a graphical window.
        time_slice: String representing the time slice to plot (e.g. ":-5", "::10").
        m_index: Index of the magnetization component to plot (1, 2, or 3), the
            same numbering as the ``-m`` command line option. It selects the
            ``f"m{m_index}"`` entry of the ``x_profiles`` record.
    """
    fig, ax = plt.subplots()
    result: RunResults = load_results(file)
    grid_params = extract_values(
        file, "params/Jx", "params/Jy", "params/Jz", "params/dx"
    )
    grid = Grid(*grid_params)  # type: ignore
    x_coords = grid.get_x_coords(local=False)
    x_profiles = result.get_record("x_profiles")

    # Parse and apply slice
    sl = parse_slice(time_slice)
    times = x_profiles["t"][sl]
    m_profiles = x_profiles[f"m{m_index}"][sl]

    for time, profile in zip(times, m_profiles):
        ax.plot(x_coords, profile, label=f"t = {time:.3e} s")
    ax.set_xlabel("x")
    ax.set_ylabel(rf"$m_{{{m_index}}}$")
    ax.grid()
    ax.legend()
    # m_index is used as-is for the record key and the y-label, so the title
    # must not shift it either (it previously displayed m_{m_index + 1}).
    ax.set_title(rf"Longitudinal profiles of magnetization $m_{{{m_index}}}$")

    if show:
        plt.show()

    if image_filepath is not None:
        fig.savefig(image_filepath, dpi=300)
        print(f"Written to {image_filepath}")
97
+
98
+
99
+ def get_cli_args(
100
+ description: str | None, default_image_filepath: Path
101
+ ) -> argparse.Namespace:
102
+ """
103
+ Parse command-line arguments for post-processing scripts.
104
+
105
+ Args:
106
+ description: Description of the script for the help message.
107
+ default_image_filepath: Default path to save the output image.
108
+
109
+ Returns:
110
+ Parsed command-line arguments.
111
+ """
112
+ parser = argparse.ArgumentParser(
113
+ description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
114
+ )
115
+ parser.add_argument("file", type=Path, help="Path to the result file.")
116
+ parser.add_argument(
117
+ "-i",
118
+ "--image_filepath",
119
+ type=Path,
120
+ default=default_image_filepath,
121
+ help="Path to save the image",
122
+ )
123
+ parser.add_argument(
124
+ "-s",
125
+ "--show",
126
+ action="store_true",
127
+ default=False,
128
+ help="Display the plot (omit to disable display in non-interactive runs).",
129
+ )
130
+ parser.add_argument(
131
+ "-t",
132
+ "--time",
133
+ type=str,
134
+ default=":",
135
+ help="Slice for time steps (e.g., '-5:' for last 5, '::10' for every 10th).",
136
+ )
137
+ parser.add_argument(
138
+ "-m",
139
+ "--m_index",
140
+ type=int,
141
+ choices=[1, 2, 3],
142
+ default=1,
143
+ help="Index of the magnetization component to plot (1, 2, or 3).",
144
+ )
145
+ return parser.parse_args()
146
+
147
+
148
def main():  # pragma: no cover
    """Read the command line, then plot the requested x-profiles."""
    cli = get_cli_args(
        description="Plot x-profiles of magnetization from a result file.",
        default_image_filepath=Path("x_profile.png"),
    )
    options = {
        "image_filepath": cli.image_filepath,
        "show": cli.show,
        "time_slice": cli.time,
        "m_index": cli.m_index,
    }
    plot_x_profiles(cli.file, **options)
llg3d/py.typed ADDED
@@ -0,0 +1 @@
1
+ # Marker file to indicate llg3d is PEP 561 type-annotated
@@ -0,0 +1,153 @@
1
+ """
2
+ Define various types of solvers.
3
+
4
+ Example:
5
+ To initialize one of the solver classes:
6
+
7
+ >>> from llg3d.parameters import RunParameters
8
+ >>> from llg3d.solvers.numpy import NumpySolver
9
+ >>> solver = NumpySolver(**RunParameters(solver="numpy").as_dict())
10
+ """
11
+
12
+ from typing import TYPE_CHECKING
13
+
14
+ import os
15
+ import importlib.util
16
+
17
+ if TYPE_CHECKING:
18
+ from .base import BaseSolver
19
+
20
+
21
def get_size() -> int:
    """
    Give the number of MPI processes taking part in the run.

    Only environment variables exported by common launchers are inspected, so
    this query never initializes MPI.

    Returns:
        The process count advertised by the launcher, or 1 when none is found.
    """
    # Priority order: Open MPI, then MPICH / Intel MPI, then SLURM.
    # NOTE(review): SLURM_NTASKS may also be visible to a batch script that was
    # not started through srun; confirm a single-process run cannot be misreported.
    for env_var in ("OMPI_COMM_WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS"):
        value = os.environ.get(env_var)
        if value is not None:
            return int(value)
    return 1
43
+
44
+
45
def get_rank() -> int:
    """
    Give the rank of the current MPI process.

    Only launcher environment variables are read, so MPI is never initialized
    by this call.

    Returns:
        The rank advertised by the launcher, or 0 outside an MPI environment.
    """
    # Priority order: PMIx, Open MPI, MPICH / Intel MPI, SLURM
    rank_vars = ("PMIX_RANK", "OMPI_COMM_WORLD_RANK", "PMI_RANK", "SLURM_PROCID")
    found = next((os.environ[name] for name in rank_vars if name in os.environ), None)
    return 0 if found is None else int(found)
71
+
72
+
73
# Public names exported by ``from llg3d.solvers import *``
__all__ = [
    "rank",
    "size",
    "comm",
    "status",
    "mpi_initialized",
    "get_size",
    "get_rank",
    "get_solver_class",
    "LIB_AVAILABLE",
]

# Availability of the optional backends, probed with find_spec (no import happens)
LIB_AVAILABLE: dict[str, bool] = {}
for lib in ("pyopencl", "jax", "mpi4py"):
    LIB_AVAILABLE[lib] = importlib.util.find_spec(lib, package=__package__) is not None


class _DummyComm:
    """Placeholder communicator, so that MPI is not initialized if not needed."""


class _DummyStatus:
    """Placeholder MPI status, so that MPI is not initialized if not needed."""


# Cheap defaults: rank and size come from launcher environment variables, and the
# communicator/status stay placeholders so that MPI is not initialized if not needed.
comm, status = _DummyComm(), _DummyStatus()
rank, size = get_rank(), get_size()
mpi_initialized = False
107
+
108
+
109
def get_solver_class(solver_name: str) -> "type[BaseSolver]":
    """
    Get the solver class based on the solver name.

    Solver modules are imported lazily, only for the requested solver.

    Args:
        solver_name: Name of the solver ("mpi", "numpy", "opencl", "jax")

    Returns:
        The solver class

    Raises:
        ValueError: If the solver name is unknown, or if the selected solver
            is not compatible with MPI (only "mpi" runs on several processes)

    Example:
        >>> Solver = get_solver_class("numpy")
        >>> Solver.__name__
        'NumpySolver'
    """
    # Validate the name first so a typo is reported as such, even under MPI
    if solver_name not in ("mpi", "numpy", "opencl", "jax"):
        raise ValueError(f"Unknown solver method '{solver_name}'.")
    if size > 1 and solver_name != "mpi":
        raise ValueError(f"Solver method '{solver_name}' is not compatible with MPI.")

    Solver: type[BaseSolver]
    if solver_name == "mpi":
        from .mpi import MPISolver

        Solver = MPISolver
    elif solver_name == "numpy":
        from .numpy import NumpySolver

        Solver = NumpySolver
    elif solver_name == "opencl":
        from .opencl import OpenCLSolver

        Solver = OpenCLSolver
    else:  # "jax", the only remaining valid name
        from .experimental.jax import JaxSolver

        Solver = JaxSolver

    return Solver