llg3d 2.0.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. llg3d/__init__.py +3 -3
  2. llg3d/__main__.py +2 -2
  3. llg3d/benchmarks/__init__.py +1 -0
  4. llg3d/benchmarks/compare_commits.py +321 -0
  5. llg3d/benchmarks/efficiency.py +451 -0
  6. llg3d/benchmarks/utils.py +25 -0
  7. llg3d/element.py +118 -31
  8. llg3d/grid.py +51 -64
  9. llg3d/io.py +395 -0
  10. llg3d/main.py +36 -38
  11. llg3d/parameters.py +159 -49
  12. llg3d/post/__init__.py +1 -1
  13. llg3d/post/extract.py +105 -0
  14. llg3d/post/info.py +178 -0
  15. llg3d/post/m1_vs_T.py +90 -0
  16. llg3d/post/m1_vs_time.py +56 -0
  17. llg3d/post/process.py +82 -75
  18. llg3d/post/utils.py +38 -0
  19. llg3d/post/x_profiles.py +141 -0
  20. llg3d/py.typed +1 -0
  21. llg3d/solvers/__init__.py +153 -0
  22. llg3d/solvers/base.py +345 -0
  23. llg3d/solvers/experimental/__init__.py +9 -0
  24. llg3d/solvers/experimental/jax.py +361 -0
  25. llg3d/solvers/math_utils.py +41 -0
  26. llg3d/solvers/mpi.py +370 -0
  27. llg3d/solvers/numpy.py +126 -0
  28. llg3d/solvers/opencl.py +439 -0
  29. llg3d/solvers/profiling.py +38 -0
  30. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/METADATA +6 -3
  31. llg3d-3.0.0.dist-info/RECORD +36 -0
  32. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/WHEEL +1 -1
  33. llg3d-3.0.0.dist-info/entry_points.txt +9 -0
  34. llg3d/output.py +0 -108
  35. llg3d/post/plot_results.py +0 -65
  36. llg3d/post/temperature.py +0 -83
  37. llg3d/simulation.py +0 -104
  38. llg3d/solver/__init__.py +0 -45
  39. llg3d/solver/jax.py +0 -383
  40. llg3d/solver/mpi.py +0 -449
  41. llg3d/solver/numpy.py +0 -210
  42. llg3d/solver/opencl.py +0 -329
  43. llg3d/solver/solver.py +0 -93
  44. llg3d-2.0.0.dist-info/RECORD +0 -25
  45. llg3d-2.0.0.dist-info/entry_points.txt +0 -4
  46. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/licenses/AUTHORS +0 -0
  47. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/licenses/LICENSE +0 -0
  48. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/top_level.txt +0 -0
llg3d/post/plot_results.py DELETED
@@ -1,65 +0,0 @@
1
- """
2
- Plot 1D curves from several files
3
-
4
- Usage:
5
-
6
- python plot_results.py file1.txt
7
- or
8
- python plot_results.py file1.txt file2.txt file3.txt
9
-
10
- """
11
-
12
- import argparse
13
- from matplotlib import pyplot as plt
14
- import numpy as np
15
-
16
-
17
# Image path used by plot() / the CLI when no output file is requested.
DEFAULT_OUTPUT_FILE: str = "results.png"
18
-
19
-
20
def plot(*files: str, output_file: str = DEFAULT_OUTPUT_FILE):
    """
    Plot the 1D curves stored in the given text files on a single figure.

    Each file is read with ``numpy.loadtxt``. Its first column is used as the
    x axis ("time") and its second column as the y axis (``<m_1>``). The figure
    is saved to ``output_file`` and then displayed.

    Args:
        files: Paths to the result files; each must end with ``.txt``.
        output_file: Path to the output image file.

    Raises:
        ValueError: If one of the paths does not end with ``.txt``.
    """
    # Validate every path before creating a figure, so a bad argument does
    # not leave an orphan figure behind.
    for file in files:
        if not file.endswith(".txt"):
            raise ValueError(f"File {file} does not end with .txt")

    fig, ax = plt.subplots()
    for file in files:
        # ndmin=2 keeps a single-row file two-dimensional, so that the
        # column indexing below also works when only one sample was written.
        data = np.loadtxt(file, ndmin=2)
        ax.plot(data[:, 0], data[:, 1], label=file)

    ax.set_xlabel("time")
    ax.set_ylabel(r"$<m_1>$")
    ax.legend()
    ax.set_title(r"Space average of $m_1$ according to time")
    fig.savefig(output_file)
    print(f"Written to {output_file}")
    plt.show()
43
-
44
-
45
def main():
    """Command-line entry point: read result file paths and plot them."""
    cli = argparse.ArgumentParser(
        description="Plot results from one or more files."
    )
    cli.add_argument("files", nargs="+", type=str, help="Path to the result files.")
    cli.add_argument(
        "--output",
        "-o",
        type=str,
        default=DEFAULT_OUTPUT_FILE,
        help=f"Path to the output image file (default: {DEFAULT_OUTPUT_FILE}).",
    )
    options = cli.parse_args()

    plot(*options.files, output_file=options.output)


if __name__ == "__main__":
    main()
llg3d/post/temperature.py DELETED
@@ -1,83 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Plot the magnetization vs temperature and determine the Curie temperature.
4
- """
5
-
6
- import argparse
7
- from pathlib import Path
8
-
9
- import matplotlib.pyplot as plt
10
-
11
- from .process import MagData
12
-
13
-
14
def plot_m_vs_T(m: MagData, show: bool):
    """
    Plot the average magnetization against temperature and mark T_Curie.

    The computed points ``(m.temperature, m.m1_mean)`` are drawn together with
    the interpolant ``m.interp`` evaluated on ``m.T``. A vertical line and an
    annotation mark ``m.T_Curie``. The figure is saved as ``m1_mean.png`` in
    ``m.parentpath``.

    Args:
        m: Processed magnetization data. It provides the run description
            (``m.run["params"]``), the computed and interpolated curves, and
            the Curie temperature.
        show: Display the graph in a graphical window.
    """
    print(f"T_Curie = {m.T_Curie:.0f} K")

    fig, ax = plt.subplots()
    fig.suptitle("Average magnetization vs Temperature")
    params = m.run["params"]
    ax.set_title(
        params["element"]
        + rf", ${params['Jx']}\times{params['Jy']}\times{params['Jz']}$"
        rf" ($dx = ${params['dx']})",
        fontdict={"size": 10},
    )
    ax.plot(m.temperature, m.m1_mean, "o", label="computed")
    ax.plot(m.T, m.interp(m.T), label="interpolated (cubic)")
    m_at_curie = m.interp(m.T_Curie)  # evaluated once, used twice below
    ax.annotate(
        "$T_{{Curie}} = {:.0f} K$".format(m.T_Curie),
        xy=(m.T_Curie, m_at_curie),
        xytext=(m.T_Curie + 20, m_at_curie + 0.01),
    )
    ax.axvline(x=m.T_Curie, color="k")
    ax.set_xlabel("Temperature [K]")
    ax.set_ylabel("Magnetization")
    ax.legend()

    # Save before showing. Matplotlib recommends calling savefig before a
    # blocking show(), because the window is closed (and the figure released
    # by pyplot) when show() returns.
    image_filename = m.parentpath / "m1_mean.png"
    fig.savefig(image_filename)
    print(f"Image saved in {image_filename}")

    if show:
        plt.show()
56
-
57
-
58
def main():
    """Command-line entry point: load the run data and plot <m_1> vs T."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--job_dir", type=Path, help="Slurm main job directory")
    cli.add_argument(
        "--run_file", type=Path, default="run.json", help="Path to the run.json file"
    )
    cli.add_argument(
        "-s",
        "--show",
        action="store_true",
        default=False,
        help="Display the graph in a graphical window",
    )
    options = cli.parse_args()

    # A job directory takes precedence over the single run file.
    data = (
        MagData(job_dir=options.job_dir)
        if options.job_dir
        else MagData(run_file=options.run_file)
    )
    plot_m_vs_T(data, options.show)


if __name__ == "__main__":
    main()
llg3d/simulation.py DELETED
@@ -1,104 +0,0 @@
1
- """
2
- Define the Simulation class.
3
-
4
- Example usage:
5
-
6
- >>> from llg3d.main import Simulation
7
- >>> from llg3d.parameters import parameters
8
- >>> run_parameters = {name: value["default"] for name, value in parameters.items()}
9
- >>> sim = Simulation(run_parameters)
10
- >>> sim.run()
11
- >>> sim.save()
12
-
13
- """
14
-
15
- import inspect
16
-
17
- from . import rank, size
18
- from .element import get_element_class
19
- from .parameters import Parameter
20
- from .output import write_json
21
-
22
-
23
class Simulation:
    """
    Class to encapsulate the simulation logic.

    Build it from a parameter dictionary, call :meth:`run`, then call
    :meth:`save`. On rank 0, ``save`` prints a summary and writes
    :attr:`json_file`.
    """

    json_file = "run.json"  #: JSON file to store the results

    def __init__(self, params: dict[str, Parameter]):
        """
        Initializes the simulation with parameters.

        Args:
            params: Dictionary of simulation parameters. It is copied, so the
                caller's dict is not modified. ``params["solver"]`` selects
                the solver module and ``params["element"]`` the element class.
                The copy is augmented with ``"np"`` (number of processes) and
                ``"element_class"``.

        Raises:
            ImportError: If the solver module cannot be imported (unknown
                solver name, or a back-end library it needs is missing).
        """
        self.params: dict[str, Parameter] = params.copy()  #: simulation parameters
        self.simulate: callable = self._get_simulate_function_from_name(
            self.params["solver"]
        )  #: simulation function imported from the solver module
        self.total_time: None | float = None  #: total simulation time
        self.filenames: list[str] = []  #: list of output filenames
        self.m1_mean: None | float = None  #: space and time average of m1
        self.params["np"] = size  # Add a parameter for the number of processes
        # Reference the element class from the element string
        self.params["element_class"] = get_element_class(params["element"])

    def run(self):
        """
        Runs the simulation and stores the results.

        The solver's ``simulate`` function receives every parameter as a
        keyword argument. It returns ``(total_time, filenames, m1_mean)``.
        """
        self.total_time, self.filenames, self.m1_mean = self.simulate(**self.params)

    def _get_simulate_function_from_name(self, name: str) -> callable:
        """
        Retrieves the simulation function for a given solver name.

        Args:
            name: Name of the solver

        Returns:
            callable: The simulation function

        Example:

            >>> simulate = self._get_simulate_function_from_name("mpi")

        Will return the `simulate` function from the `llg3d.solver.mpi` module.
        """
        import importlib

        module = importlib.import_module(f"llg3d.solver.{name}")
        return inspect.getattr_static(module, "simulate")

    def save(self):
        """
        Saves the results of the simulation to a JSON file.

        Only rank 0 prints the summary and writes :attr:`json_file`. The
        ``element_class`` entry is removed from the saved parameters because
        a class object cannot be serialized.

        Raises:
            RuntimeError: On rank 0, if :meth:`run` has not been called yet.
        """
        params = self.params.copy()  # save the parameters
        del params["element_class"]  # remove class object before serialization
        if rank != 0:
            return
        if self.total_time is None:
            raise RuntimeError("Simulation.save() called before Simulation.run()")

        results = {"total_time": self.total_time}
        # Export the integral of m1
        if self.filenames:
            results["integral_file"] = self.filenames[0]
            print(f"Integral of m1 in {self.filenames[0]}")
        # Export the x-profiles of m1, m2 and m3
        # NOTE(review): the index i starts at 0 although the profiles are
        # described as m1, m2, m3. Keys and labels are kept as-is because
        # readers of run.json may rely on "xprofile_m0".
        for i, filename in enumerate(self.filenames[1:]):
            results[f"xprofile_m{i}"] = filename
            print(f"x-profile of m{i} in {filename}")

        n_iter = params["N"]
        # Avoid a ZeroDivisionError when no iteration was run.
        time_per_iter = self.total_time / n_iter if n_iter else float("nan")
        print(
            f"""\
N iterations = {n_iter}
total_time [s] = {self.total_time:.03f}
time/ite [s/iter] = {time_per_iter:.03e}\
"""
        )
        # Export the mean of m1
        if params["N"] > params["start_averaging"]:
            print(f"m1_mean = {self.m1_mean:e}")
            results["m1_mean"] = float(self.m1_mean)

        write_json(self.json_file, {"params": params, "results": results})
        print(f"Summary in {self.json_file}")
llg3d/solver/__init__.py DELETED
@@ -1,45 +0,0 @@
1
- """
2
- Solver module for LLG3D
3
-
4
- This module contains different solver implementations.
5
- """
6
-
7
import importlib.util

__all__ = ["numpy", "solver", "rank", "size", "comm", "status"]

# Record which optional back-end libraries are importable. Available ones are
# also appended to __all__.
# NOTE(review): find_spec only uses `package` for relative (dot-prefixed)
# names, so "opencl" is looked up as a top-level module literally named
# "opencl". Confirm this is the intended import name (e.g. not "pyopencl").
LIB_AVAILABLE: dict[str, bool] = {}
for lib in ("opencl", "jax", "mpi4py"):
    LIB_AVAILABLE[lib] = importlib.util.find_spec(lib, package=__package__) is not None
    if LIB_AVAILABLE[lib]:
        __all__.append(lib)


# MPI utilities
if not LIB_AVAILABLE["mpi4py"]:
    # No MPI: run as a single process with placeholder communicator/status.
    class DummyComm:
        pass

    class DummyStatus:
        pass

    comm = DummyComm()
    rank = 0
    size = 1
    status = DummyStatus()
else:
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    status = MPI.Status()

# Submodules are imported last, after comm/rank/size/status exist
# (presumably because they import these names from this package; TODO confirm).
from . import numpy, solver
llg3d/solver/jax.py DELETED
@@ -1,383 +0,0 @@
1
- """
2
- LLG3D solver using XLA compilation
3
- """
4
-
5
- import os
6
- import time
7
-
8
- import jax
9
- import jax.numpy as jnp
10
- from jax import random
11
-
12
- from ..output import progress_bar, get_output_files, close_output_files
13
- from ..grid import Grid
14
- from ..element import Element
15
-
16
-
17
- # JIT compile individual components for better performance and modularity
18
@jax.jit
def compute_H_anisotropy(
    m: jnp.ndarray, coeff_2: float, anisotropy: int
) -> jnp.ndarray:
    """
    Compute the anisotropy field (JIT compiled).

    Both candidate fields (uniaxial and cubic) are built. ``jnp.where`` then
    picks the one matching ``anisotropy``, which may be a traced value. Any
    other value of ``anisotropy`` gives a zero field.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        coeff_2: Coefficient for anisotropy
        anisotropy: Anisotropy type (0: uniaxial, 1: cubic)

    Returns:
        Anisotropy field array (shape (3, nx, ny, nz))
    """
    m1, m2, m3 = m
    sq1, sq2, sq3 = m1 * m1, m2 * m2, m3 * m3
    zero = jnp.zeros_like(m1)

    uniaxial = jnp.stack([m1, zero, zero], axis=0)
    cubic = jnp.stack(
        [
            -(1 - sq1 + sq2 * sq3) * m1,
            -(1 - sq2 + sq1 * sq3) * m2,
            -(1 - sq3 + sq1 * sq2) * m3,
        ],
        axis=0,
    )
    selected = jnp.where(
        anisotropy == 0,
        uniaxial,
        jnp.where(anisotropy == 1, cubic, jnp.zeros_like(uniaxial)),
    )
    return coeff_2 * selected
68
-
69
-
70
@jax.jit
def laplacian3D(
    m_i: jnp.ndarray, dx2_inv: float, dy2_inv: float, dz2_inv: float, center_coeff: float
) -> jnp.ndarray:
    """
    Seven-point Laplacian of one component with Neumann boundaries (JIT compiled).

    Args:
        m_i: Single component of magnetization (shape (nx, ny, nz))
        dx2_inv: Inverse of squared grid spacing in x direction
        dy2_inv: Inverse of squared grid spacing in y direction
        dz2_inv: Inverse of squared grid spacing in z direction
        center_coeff: Coefficient for the center point

    Returns:
        Laplacian of m_i (shape (nx, ny, nz))
    """
    # One ghost layer on every face. "reflect" mirrors about the boundary node
    # (the ghost value equals the first interior neighbour), which gives a
    # zero normal derivative at the boundary.
    padded = jnp.pad(m_i, ((1, 1), (1, 1), (1, 1)), mode="reflect")
    core = slice(1, -1)
    x_pair = padded[2:, core, core] + padded[:-2, core, core]
    y_pair = padded[core, 2:, core] + padded[core, :-2, core]
    z_pair = padded[core, core, 2:] + padded[core, core, :-2]
    return dx2_inv * x_pair + dy2_inv * y_pair + dz2_inv * z_pair + center_coeff * m_i
94
-
95
-
96
@jax.jit
def compute_laplacian(
    m: jnp.ndarray, dx: float, dy: float, dz: float
) -> jnp.ndarray:
    """
    Apply the Neumann-boundary 3D Laplacian to each magnetization component (JIT compiled).

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        dx: Grid spacing in x direction
        dy: Grid spacing in y direction
        dz: Grid spacing in z direction

    Returns:
        Laplacian of m (shape (3, nx, ny, nz))
    """
    inv_x, inv_y, inv_z = 1 / dx**2, 1 / dy**2, 1 / dz**2
    diagonal = -2 * (inv_x + inv_y + inv_z)
    per_component = [
        laplacian3D(m[comp], inv_x, inv_y, inv_z, diagonal) for comp in range(3)
    ]
    return jnp.stack(per_component, axis=0)
123
-
124
-
125
@jax.jit
def compute_space_average_jax(m1: jnp.ndarray) -> float:
    """
    Weighted space average of a 3D field using trapezoidal weights (JIT compiled).

    Along each axis the first and last nodes get weight 0.5 and interior
    nodes weight 1. A node's 3D weight is the product of its three axis
    weights, so edges get 0.25 and corners 0.125. The weighted sum is divided
    by the sum of the weights, so a constant field averages to itself.

    The weights are an outer product of three 1D vectors. This gives the same
    values as building full-size index grids with ``meshgrid``, without
    allocating three integer arrays the size of the field.

    Args:
        m1: First component of magnetization (shape (nx, ny, nz))

    Returns:
        Space average of m1, as a 0-d array (callers convert with ``float``)
    """

    def axis_weights(n: int) -> jnp.ndarray:
        # 0.5 on the first and last node, 1 elsewhere (0.5 when n == 1).
        ones = jnp.ones(n, dtype=m1.dtype)
        idx = jnp.arange(n)
        return jnp.where((idx == 0) | (idx == n - 1), ones * 0.5, ones)

    nx, ny, nz = m1.shape
    weights = (
        axis_weights(nx)[:, None, None]
        * axis_weights(ny)[None, :, None]
        * axis_weights(nz)[None, None, :]
    )
    # Normalise by the total weight (the effective cell count).
    return jnp.sum(weights * m1) / jnp.sum(weights)
160
-
161
-
162
@jax.jit
def cross_product(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    r"""
    Pointwise cross product :math:`a \times b` of two vector fields (JIT compiled).

    Args:
        a: First vector (shape (3, nx, ny, nz))
        b: Second vector (shape (3, nx, ny, nz))

    Returns:
        Cross product :math:`a \times b` (shape (3, nx, ny, nz))
    """
    # The three vector components live on axis 0 for both inputs and output.
    return jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
176
-
177
-
178
- # JIT compile the slope computation for performance
179
@jax.jit
def compute_slope(
    g_params: dict, e_params: dict, m: jnp.ndarray, R_random: jnp.ndarray
) -> jnp.ndarray:
    """
    Right-hand side of the magnetization update (JIT compiled).

    Args:
        g_params: Grid parameters dict (dx, dy, dz)
        e_params: Element parameters dict (coeff_1, coeff_2, coeff_3, lambda_G, anisotropy)
        m: Magnetization array (shape (3, nx, ny, nz))
        R_random: Random field array (shape (3, nx, ny, nz))

    Returns:
        Slope array (shape (3, nx, ny, nz)): -(m x R_eff + lambda_G * m x (m x R_eff))
    """
    H_aniso = compute_H_anisotropy(m, e_params["coeff_2"], e_params["anisotropy"])
    lap = compute_laplacian(m, g_params["dx"], g_params["dy"], g_params["dz"])

    # Effective field: scaled Laplacian + random field + anisotropy field,
    # plus the constant coeff_3 on the first component only.
    field = e_params["coeff_1"] * lap + R_random + H_aniso
    field = field.at[0].add(e_params["coeff_3"])

    m_x_field = cross_product(m, field)
    m_x_m_x_field = cross_product(m, m_x_field)
    return -(m_x_field + e_params["lambda_G"] * m_x_m_x_field)
216
-
217
-
218
def _configure_device(device: str) -> None:
    """
    Select the JAX platform (and optionally a specific GPU) from ``device``.

    Args:
        device: 'cpu', 'gpu', 'gpu:<id>' or 'auto'. 'auto' (and any other
            value) leaves the choice to JAX.
    """
    if device == "cpu":
        jax.config.update("jax_platform_name", "cpu")
    elif device == "gpu":
        jax.config.update("jax_platform_name", "gpu")
    elif device.startswith("gpu:"):
        # Select specific GPU using environment variable
        jax.config.update("jax_platform_name", "gpu")
        gpu_id = device.split(":")[1]
        # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if it is set
        # before the JAX GPU backend is initialized. Confirm nothing has
        # touched the backend before this point.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
            print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
        else:
            print(f"Using external CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")


def simulate(
    N: int,
    Jx: int,
    Jy: int,
    Jz: int,
    dx: float,
    T: float,
    H_ext: float,
    dt: float,
    start_averaging: int,
    n_mean: int,
    n_profile: int,
    element_class: Element,
    precision: str,
    seed: int,
    device: str = "auto",
    **_,
) -> tuple[float, str, float]:
    """
    Simulates the system for N iterations using JAX.

    Each iteration draws a new random field and advances m with a
    predictor-corrector (Heun) step that reuses that field in both stages.
    It then renormalizes m point-wise to unit length.

    Args:
        N: Number of iterations
        Jx: Number of grid points in x direction
        Jy: Number of grid points in y direction
        Jz: Number of grid points in z direction
        dx: Grid spacing
        T: Temperature in Kelvin
        H_ext: External magnetic field strength
        dt: Time step for the simulation
        start_averaging: Iteration index from which sampled values of the
            space average of m1 are accumulated into the time average
        n_mean: Write the space average of m1 every n_mean iterations
            (0 disables it)
        n_profile: Forwarded to get_output_files. NOTE(review): this loop
            never writes profile data to the profile files.
        element_class: Element class of the sample, instantiated as
            ``element_class(T, H_ext, grid, dt)``
        precision: "double" enables 64-bit floats; any other value uses float32
        seed: Random seed for temperature fluctuations
        device: Device to use ('cpu', 'gpu', 'gpu:0', 'gpu:1', etc., or 'auto')

    Returns:
        - The wall-clock time of the time loop (JIT warm-up excluded)
        - The output filenames returned by get_output_files
        - The time average of m1 (0.0 when N <= start_averaging)
    """
    _configure_device(device)

    # Set precision
    if precision == "double":
        jax.config.update("jax_enable_x64", True)
        jnp_float = jnp.float64
    else:
        jax.config.update("jax_enable_x64", False)
        jnp_float = jnp.float32

    print(f"Available JAX devices: {jax.devices()}")
    print(f"Using JAX on device: {jax.devices()[0]}")
    print(f"Precision: {precision} ({jnp_float})")

    # Initialize random key for JAX
    key = random.PRNGKey(seed)

    g = Grid(Jx, Jy, Jz, dx)
    dims = g.dims

    e = element_class(T, H_ext, g, dt)
    print(f"CFL = {e.get_CFL()}")

    # Prepare parameters for JIT compilation using to_dict methods
    g_params = g.to_dict()
    e_params = e.to_dict()

    # --- Initialization ---
    def theta_init(shape):
        """Initialization of theta"""
        return jnp.zeros(shape, dtype=jnp_float)

    def phi_init(t, shape):
        """Initialization of phi"""
        return jnp.zeros(shape, dtype=jnp_float) + e.gamma_0 * H_ext * t

    m_n = jnp.zeros((3,) + dims, dtype=jnp_float)

    theta = theta_init(dims)
    phi = phi_init(0, dims)

    m_n = m_n.at[0].set(jnp.cos(theta))
    m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
    m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))

    f_mean, f_profiles, output_filenames = get_output_files(g, T, n_mean, n_profile)

    t = 0.0
    m1_average = 0.0

    # === JIT WARMUP: pre-compile the kernels so compilation is not timed ===
    print("Warming up JIT compilation...")

    # Generate dummy random field for warmup
    warmup_key = random.PRNGKey(42)
    R_warmup = e.coeff_4 * random.normal(warmup_key, (3,) + dims, dtype=jnp_float)

    warmup_outputs = [compute_slope(g_params, e_params, m_n, R_warmup)]
    if n_mean != 0:
        warmup_outputs.append(compute_space_average_jax(m_n[0]))

    # Wait for the warm-up results themselves. m_n was already materialized,
    # so blocking on it would not wait for the warm-up work to finish.
    jax.block_until_ready(warmup_outputs)
    print("JIT warmup completed.")

    start_time = time.perf_counter()

    for n in progress_bar(range(1, N + 1), "Iteration : ", 40):
        t += dt

        # Generate random field for temperature effect
        key, subkey = random.split(key)
        R_random = e.coeff_4 * random.normal(subkey, (3,) + dims, dtype=jnp_float)

        # Predictor-corrector step with the same random field in both stages
        s_pre = compute_slope(g_params, e_params, m_n, R_random)
        m_pre = m_n + dt * s_pre
        s_cor = compute_slope(g_params, e_params, m_pre, R_random)

        # Update magnetization
        m_n = m_n + dt * 0.5 * (s_pre + s_cor)

        # Renormalize to unit sphere
        norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
        m_n = m_n / norm

        # Export the average of m1 to a file
        if n_mean != 0 and n % n_mean == 0:
            # float() transfers the scalar to the host (and synchronizes)
            m1_mean = float(compute_space_average_jax(m_n[0]))
            # NOTE(review): samples taken at n >= start_averaging are weighted
            # by n_mean and divided by N - start_averaging below. Confirm the
            # sample count matches the other solvers' convention.
            if n >= start_averaging:
                m1_average += m1_mean * n_mean
            f_mean.write(f"{t:10.8e} {m1_mean:10.8e}\n")

    # JAX dispatches work asynchronously: wait for the final state before
    # stopping the clock so pending device work is included in the timing.
    jax.block_until_ready(m_n)
    total_time = time.perf_counter() - start_time

    close_output_files(f_mean, f_profiles)

    # Test on N, not on the loop variable n, which is unbound when N == 0.
    if N > start_averaging:
        m1_average /= N - start_averaging

    return total_time, output_filenames, m1_average