llg3d 1.4.1__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llg3d/__init__.py +6 -1
- llg3d/__main__.py +6 -0
- llg3d/element.py +134 -0
- llg3d/grid.py +123 -0
- llg3d/main.py +67 -0
- llg3d/output.py +107 -0
- llg3d/parameters.py +75 -0
- llg3d/post/__init__.py +1 -1
- llg3d/post/plot_results.py +61 -0
- llg3d/post/process.py +18 -13
- llg3d/post/temperature.py +6 -14
- llg3d/simulation.py +95 -0
- llg3d/solver/__init__.py +45 -0
- llg3d/solver/jax.py +387 -0
- llg3d/solver/mpi.py +450 -0
- llg3d/solver/numpy.py +207 -0
- llg3d/solver/opencl.py +330 -0
- llg3d/solver/solver.py +89 -0
- {llg3d-1.4.1.dist-info → llg3d-2.0.1.dist-info}/METADATA +14 -22
- llg3d-2.0.1.dist-info/RECORD +25 -0
- {llg3d-1.4.1.dist-info → llg3d-2.0.1.dist-info}/WHEEL +1 -1
- llg3d-2.0.1.dist-info/entry_points.txt +4 -0
- llg3d/llg3d.py +0 -742
- llg3d/llg3d_seq.py +0 -447
- llg3d-1.4.1.dist-info/RECORD +0 -13
- llg3d-1.4.1.dist-info/entry_points.txt +0 -3
- {llg3d-1.4.1.dist-info → llg3d-2.0.1.dist-info/licenses}/AUTHORS +0 -0
- {llg3d-1.4.1.dist-info → llg3d-2.0.1.dist-info/licenses}/LICENSE +0 -0
- {llg3d-1.4.1.dist-info → llg3d-2.0.1.dist-info}/top_level.txt +0 -0
llg3d/post/temperature.py
CHANGED
|
@@ -1,30 +1,24 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
|
-
"""
|
|
3
|
-
Plot the magnetization vs temperature and determine the Curie temperature.
|
|
4
|
-
"""
|
|
2
|
+
"""Plot the magnetization vs temperature and determine the Curie temperature."""
|
|
5
3
|
|
|
6
4
|
import argparse
|
|
7
5
|
from pathlib import Path
|
|
8
6
|
|
|
9
7
|
import matplotlib.pyplot as plt
|
|
10
|
-
import numpy as np
|
|
11
8
|
|
|
12
9
|
from .process import MagData
|
|
13
10
|
|
|
14
11
|
|
|
15
12
|
def plot_m_vs_T(m: MagData, show: bool):
|
|
16
13
|
"""
|
|
17
|
-
Plots the data (T, <m>)
|
|
18
|
-
|
|
19
|
-
|
|
14
|
+
Plots the data (T, <m>).
|
|
15
|
+
|
|
16
|
+
Interpolates the values, calculates the Curie temperature, exports to PNG.
|
|
20
17
|
|
|
21
18
|
Args:
|
|
22
|
-
|
|
23
|
-
parentdir: path to the directory containing the runs
|
|
24
|
-
run: descriptive dictionary of the run
|
|
19
|
+
m: Magnetization data object
|
|
25
20
|
show: display the graph in a graphical window
|
|
26
21
|
"""
|
|
27
|
-
|
|
28
22
|
print(f"T_Curie = {m.T_Curie:.0f} K")
|
|
29
23
|
|
|
30
24
|
fig, ax = plt.subplots()
|
|
@@ -57,9 +51,7 @@ def plot_m_vs_T(m: MagData, show: bool):
|
|
|
57
51
|
|
|
58
52
|
|
|
59
53
|
def main():
|
|
60
|
-
"""
|
|
61
|
-
Parses the command line to execute processing functions
|
|
62
|
-
"""
|
|
54
|
+
"""Parses the command line to execute processing functions."""
|
|
63
55
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
64
56
|
parser.add_argument("--job_dir", type=Path, help="Slurm main job directory")
|
|
65
57
|
parser.add_argument(
|
llg3d/simulation.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Define the Simulation class.
|
|
3
|
+
|
|
4
|
+
Example usage:
|
|
5
|
+
|
|
6
|
+
>>> from llg3d.simulation import Simulation
|
|
7
|
+
>>> from llg3d.parameters import parameters
|
|
8
|
+
>>> run_parameters = {name: value["default"] for name, value in parameters.items()}
|
|
9
|
+
>>> sim = Simulation(run_parameters)
|
|
10
|
+
>>> sim.run()
|
|
11
|
+
>>> sim.save()
|
|
12
|
+
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
|
|
17
|
+
from . import rank, size
|
|
18
|
+
from .element import get_element_class
|
|
19
|
+
from .parameters import Parameter
|
|
20
|
+
from .output import write_json
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Simulation:
    """
    Class to encapsulate the simulation logic.

    The solver named by ``params["solver"]`` is resolved once at construction
    time; :meth:`run` executes it and :meth:`save` writes a JSON summary.

    Args:
        params: Dictionary of simulation parameters. A shallow copy is stored,
            so entries added here do not leak into the caller's dictionary.
    """

    json_file = "run.json"  #: JSON file to store the results

    def __init__(self, params: dict[str, Parameter]):
        self.params: dict[str, Parameter] = params.copy()  #: simulation parameters
        self.simulate: callable = self._get_simulate_function_from_name(
            self.params["solver"]
        )  #: simulation function imported from the solver module
        self.total_time: None | float = None  #: total simulation time
        self.filenames: list[str] = []  #: list of output filenames
        self.m1_mean: None | float = None  #: space and time average of m1
        self.params["np"] = size  # Add a parameter for the number of processes
        # Reference the element class from the element string
        self.params["element_class"] = get_element_class(params["element"])

    def run(self):
        """Run the simulation and store its results on the instance."""
        self.total_time, self.filenames, self.m1_mean = self.simulate(**self.params)

    def _get_simulate_function_from_name(self, name: str) -> callable:
        """
        Retrieve the simulation function for a given solver name.

        Args:
            name: Name of the solver

        Returns:
            callable: The simulation function

        Example:
            >>> simulate = self._get_simulate_function_from_name("mpi")

            Will return the `simulate` function from the `llg3d.solver.mpi` module.
        """
        # Local import: importlib.import_module is the documented replacement
        # for calling __import__ directly and returns the submodule itself.
        import importlib

        module = importlib.import_module(f"llg3d.solver.{name}")
        return inspect.getattr_static(module, "simulate")

    def save(self):
        """
        Save the results of the simulation to a JSON file.

        Only the process of rank 0 prints and writes; the other ranks return
        without side effects.

        Raises:
            RuntimeError: On rank 0, if :meth:`run` has not been called yet.
        """
        params = self.params.copy()  # save the parameters
        del params["element_class"]  # remove class object before serialization
        if rank != 0:
            return

        if self.total_time is None:
            # Without this guard the formatting below fails with an opaque
            # TypeError on None.
            raise RuntimeError("Simulation.run() must be called before save()")

        results = {"total_time": self.total_time}
        # Export the integral of m1
        if len(self.filenames) > 0:
            results["integral_file"] = self.filenames[0]
            print(f"Integral of m1 in {self.filenames[0]}")
            # Export the x-profiles of m1, m2 and m3: components are numbered
            # from 1, hence start=1.
            for i, filename in enumerate(self.filenames[1:], start=1):
                results[f"xprofile_m{i}"] = filename
                print(f"x-profile of m{i} in {filename}")

        n_iterations = params["N"]
        # Avoid a ZeroDivisionError when no iteration was requested
        time_per_iteration = (
            self.total_time / n_iterations if n_iterations else float("nan")
        )
        print(
            f"""\
N iterations = {n_iterations}
total_time [s] = {self.total_time:.03f}
time/ite [s/iter] = {time_per_iteration:.03e}\
"""
        )
        # Export the mean of m1
        if n_iterations > params["start_averaging"]:
            print(f"m1_mean = {self.m1_mean:e}")
            results["m1_mean"] = float(self.m1_mean)

        write_json(self.json_file, {"params": params, "results": results})
        print(f"Summary in {self.json_file}")
|
llg3d/solver/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Solver module for LLG3D.
|
|
3
|
+
|
|
4
|
+
This module contains different solver implementations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
__all__ = ["numpy", "solver", "rank", "size", "comm", "status"]

#: Availability of each optional third-party library, keyed by the probed name
LIB_AVAILABLE: dict[str, bool] = {}

# Probed library name -> solver submodule of this package that depends on it.
# __all__ must list names that exist in this package: "mpi4py" is the external
# library, the matching submodule is "mpi".
# NOTE(review): "opencl" probes a top-level module literally named `opencl`;
# if the OpenCL solver imports a differently named binding this flag stays
# False - confirm against solver/opencl.py.
_SOLVER_FOR_LIB = {"opencl": "opencl", "jax": "jax", "mpi4py": "mpi"}

# Check for other solver availability
for lib, solver_name in _SOLVER_FOR_LIB.items():
    # The names are absolute, so no `package` anchor is needed
    LIB_AVAILABLE[lib] = importlib.util.find_spec(lib) is not None
    if LIB_AVAILABLE[lib]:
        __all__.append(solver_name)


# MPI utilities
if LIB_AVAILABLE["mpi4py"]:
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    status = MPI.Status()
else:
    # MPI library is not available: use dummy values
    class DummyComm:
        """Placeholder for the MPI communicator when mpi4py is missing."""

    comm = DummyComm()
    rank = 0
    size = 1

    class DummyStatus:
        """Placeholder for the MPI status object when mpi4py is missing."""

    status = DummyStatus()
|
|
44
|
+
|
|
45
|
+
from . import numpy, solver # noqa: E402
|
llg3d/solver/jax.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
"""LLG3D solver using XLA compilation."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from jax import random
|
|
9
|
+
|
|
10
|
+
from ..output import progress_bar, get_output_files, close_output_files
|
|
11
|
+
from ..grid import Grid
|
|
12
|
+
from ..element import Element
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# JIT compile individual components for better performance and modularity
|
|
16
|
+
@jax.jit
def compute_H_anisotropy(
    m: jnp.ndarray, coeff_2: float, anisotropy: int
) -> jnp.ndarray:
    """
    Evaluate the anisotropy field for the requested anisotropy type (JIT compiled).

    Both the uniaxial and the cubic expressions are evaluated and the matching
    one is picked with ``jnp.where``. Any ``anisotropy`` value other than 0 or 1
    yields a zero field.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        coeff_2: Coefficient multiplying the anisotropy field
        anisotropy: Anisotropy type (0: uniaxial, 1: cubic)

    Returns:
        Anisotropy field array (shape (3, nx, ny, nz))
    """
    mx, my, mz = m
    sq_x, sq_y, sq_z = mx * mx, my * my, mz * mz
    zero = jnp.zeros_like(mx)

    # One (component 1, component 2, component 3) triple per anisotropy type
    uniaxial = (mx, zero, zero)
    cubic = (
        -(1 - sq_x + sq_y * sq_z) * mx,
        -(1 - sq_y + sq_x * sq_z) * my,
        -(1 - sq_z + sq_x * sq_y) * mz,
    )

    is_uniaxial = anisotropy == 0
    is_cubic = anisotropy == 1
    components = [
        jnp.where(is_uniaxial, uni, jnp.where(is_cubic, cub, zero))
        for uni, cub in zip(uniaxial, cubic)
    ]
    return coeff_2 * jnp.stack(components, axis=0)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@jax.jit
def laplacian3D(
    m_i: jnp.ndarray,
    dx2_inv: float,
    dy2_inv: float,
    dz2_inv: float,
    center_coeff: float,
) -> jnp.ndarray:
    """
    Apply the 7-point Laplacian stencil to one magnetization component.

    (JIT compiled)

    The array is extended by one layer on every face with ``mode="reflect"``
    padding, which is how the Neumann boundary conditions are imposed.

    Args:
        m_i: Single component of magnetization (shape (nx, ny, nz))
        dx2_inv: Inverse of squared grid spacing in x direction
        dy2_inv: Inverse of squared grid spacing in y direction
        dz2_inv: Inverse of squared grid spacing in z direction
        center_coeff: Coefficient for the center point

    Returns:
        Laplacian of m_i (shape (nx, ny, nz))
    """
    ghosted = jnp.pad(m_i, ((1, 1),) * 3, mode="reflect")

    # Slices selecting, in the padded array, the original cells and their
    # neighbours one step forward / backward along an axis
    inner = slice(1, -1)
    ahead = slice(2, None)
    behind = slice(None, -2)

    neighbours_x = ghosted[ahead, inner, inner] + ghosted[behind, inner, inner]
    neighbours_y = ghosted[inner, ahead, inner] + ghosted[inner, behind, inner]
    neighbours_z = ghosted[inner, inner, ahead] + ghosted[inner, inner, behind]

    return (
        dx2_inv * neighbours_x
        + dy2_inv * neighbours_y
        + dz2_inv * neighbours_z
        + center_coeff * m_i
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@jax.jit
def compute_laplacian(m: jnp.ndarray, dx: float, dy: float, dz: float) -> jnp.ndarray:
    """
    Laplacian of the three magnetization components (JIT compiled).

    The stencil coefficients are derived once from the grid spacings and
    ``laplacian3D`` is then applied to each of the three components.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        dx: Grid spacing in x direction
        dy: Grid spacing in y direction
        dz: Grid spacing in z direction

    Returns:
        Laplacian of m (shape (3, nx, ny, nz))
    """
    inv_sq_x, inv_sq_y, inv_sq_z = (1 / spacing**2 for spacing in (dx, dy, dz))
    diagonal = -2 * (inv_sq_x + inv_sq_y + inv_sq_z)

    per_component = [
        laplacian3D(m[component], inv_sq_x, inv_sq_y, inv_sq_z, diagonal)
        for component in range(3)
    ]
    return jnp.stack(per_component, axis=0)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@jax.jit
def compute_space_average_jax(m1: jnp.ndarray) -> float:
    """
    Weighted space average of one magnetization component (JIT compiled).

    Every cell starts with a weight of 1.0; the weight is halved once for each
    axis along which the cell lies on the first or last index. The result is
    the weighted sum divided by the sum of the weights.

    Args:
        m1: First component of magnetization (shape (nx, ny, nz))

    Returns:
        Space average of m1
    """
    nx, ny, nz = m1.shape  # also enforces a rank-3 input

    # Integer index of every cell along each axis
    index_grids = jnp.meshgrid(
        jnp.arange(nx), jnp.arange(ny), jnp.arange(nz), indexing="ij"
    )

    # Halve the weight on the two end planes of each axis in turn
    weights = jnp.ones_like(m1)
    for index_grid, extent in zip(index_grids, (nx, ny, nz)):
        on_end_plane = (index_grid == 0) | (index_grid == extent - 1)
        weights = jnp.where(on_end_plane, weights * 0.5, weights)

    # The sum of the weights plays the role of the effective cell count
    return jnp.sum(weights * m1) / jnp.sum(weights)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@jax.jit
def cross_product(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    r"""
    Cross product :math:`a \times b` of two vector fields (JIT compiled).

    The vector components are stored along the leading axis of both inputs
    and of the result.

    Args:
        a: First vector (shape (3, nx, ny, nz))
        b: Second vector (shape (3, nx, ny, nz))

    Returns:
        Cross product :math:`a \times b` (shape (3, nx, ny, nz))
    """
    # Components live on axis 0 for both operands and for the output
    component_axis = 0
    return jnp.cross(
        a, b, axisa=component_axis, axisb=component_axis, axisc=component_axis
    )
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# JIT compile the slope computation for performance
|
|
180
|
+
@jax.jit
def compute_slope(
    g_params: dict, e_params: dict, m: jnp.ndarray, R_random: jnp.ndarray
) -> jnp.ndarray:
    """
    Time derivative of the magnetization (JIT compiled).

    The effective field is built as
    ``coeff_1 * laplacian(m) + R_random + H_anisotropy``, with ``coeff_3``
    added to its first component. The returned slope is
    ``-(m x R_eff + lambda_G * m x (m x R_eff))``.

    Args:
        g_params: Grid parameters dict (dx, dy, dz)
        e_params: Element parameters dict (coeff_1, coeff_2, coeff_3, lambda_G,
            anisotropy)
        m: Magnetization array (shape (3, nx, ny, nz))
        R_random: Random field array (shape (3, nx, ny, nz))

    Returns:
        Slope array (shape (3, nx, ny, nz))
    """
    spacings = (g_params["dx"], g_params["dy"], g_params["dz"])
    lap_coeff, aniso_coeff, first_component_shift, lam = (
        e_params[key] for key in ("coeff_1", "coeff_2", "coeff_3", "lambda_G")
    )

    # Field contributions from the dedicated sub-functions
    aniso_field = compute_H_anisotropy(m, aniso_coeff, e_params["anisotropy"])
    lap_field = compute_laplacian(m, *spacings)

    # Effective field; coeff_3 only enters the first component
    effective = lap_coeff * lap_field + R_random + aniso_field
    effective = effective.at[0].add(first_component_shift)

    single_cross = cross_product(m, effective)
    double_cross = cross_product(m, single_cross)

    return -(single_cross + lam * double_cross)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def simulate(
    N: int,
    Jx: int,
    Jy: int,
    Jz: int,
    dx: float,
    T: float,
    H_ext: float,
    dt: float,
    start_averaging: int,
    n_mean: int,
    n_profile: int,
    element_class: type[Element],
    precision: str,
    seed: int,
    device: str = "auto",
    **_,
) -> tuple[float, list[str], float]:
    """
    Simulates the system for N iterations using JAX.

    Each iteration draws a random field, takes a predictor step followed by a
    corrector step (both evaluated with the same random field), and
    renormalizes the magnetization to unit length.

    Args:
        N: Number of iterations
        Jx: Number of grid points in x direction
        Jy: Number of grid points in y direction
        Jz: Number of grid points in z direction
        dx: Grid spacing
        T: Temperature in Kelvin
        H_ext: External magnetic field strength
        dt: Time step for the simulation
        start_averaging: Iteration from which the m1 samples are accumulated
            into the time average
        n_mean: The space average of m1 is written every ``n_mean``
            iterations (0 disables this output)
        n_profile: Forwarded to ``get_output_files``
        element_class: Element class of the sample; it is instantiated as
            ``element_class(T, H_ext, grid, dt)``
        precision: Precision of the simulation (single or double)
        seed: Random seed for temperature fluctuations
        device: Device to use ('cpu', 'gpu', 'gpu:0', 'gpu:1', etc., or 'auto')

    Returns:
        - The time taken for the simulation
        - The output filenames
        - The average magnetization
    """
    # Configure JAX
    if device == "auto":
        # Let JAX choose the best available device
        pass
    elif device == "cpu":
        jax.config.update("jax_platform_name", "cpu")
    elif device == "gpu":
        jax.config.update("jax_platform_name", "gpu")
    elif device.startswith("gpu:"):
        # Select specific GPU using environment variable
        jax.config.update("jax_platform_name", "gpu")
        gpu_id = device.split(":")[1]
        # Check if CUDA_VISIBLE_DEVICES is already set externally
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
            print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
        else:
            cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
            print(
                f"Using external CUDA_VISIBLE_DEVICES={cuda_visible_devices}"
            )

    # Set precision
    if precision == "double":
        jax.config.update("jax_enable_x64", True)
        jnp_float = jnp.float64
    else:
        jax.config.update("jax_enable_x64", False)
        jnp_float = jnp.float32

    print(f"Available JAX devices: {jax.devices()}")
    print(f"Using JAX on device: {jax.devices()[0]}")
    print(f"Precision: {precision} ({jnp_float})")

    # Initialize random key for JAX
    key = random.PRNGKey(seed)

    g = Grid(Jx, Jy, Jz, dx)
    dims = g.dims

    e = element_class(T, H_ext, g, dt)
    print(f"CFL = {e.get_CFL()}")

    # Prepare parameters for JIT compilation using to_dict methods
    g_params = g.to_dict()
    e_params = e.to_dict()

    # --- Initialization ---
    def theta_init(shape):
        """Initialization of theta."""
        return jnp.zeros(shape, dtype=jnp_float)

    def phi_init(t, shape):
        """Initialization of phi."""
        return jnp.zeros(shape, dtype=jnp_float) + e.gamma_0 * H_ext * t

    m_n = jnp.zeros((3,) + dims, dtype=jnp_float)

    theta = theta_init(dims)
    phi = phi_init(0, dims)

    m_n = m_n.at[0].set(jnp.cos(theta))
    m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
    m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))

    f_mean, f_profiles, output_filenames = get_output_files(g, T, n_mean, n_profile)

    t = 0.0
    m1_average = 0.0

    # From here on the output files are open: make sure they are closed even
    # if the warm-up or an iteration raises.
    try:
        # === JIT WARMUP: Pre-compile all functions to exclude compilation time ===
        print("Warming up JIT compilation...")

        # Generate dummy random field for warmup
        warmup_key = random.PRNGKey(42)
        R_warmup = e.coeff_4 * random.normal(warmup_key, (3,) + dims, dtype=jnp_float)

        # Warmup all JIT functions with actual data shapes
        warmup_results = [compute_slope(g_params, e_params, m_n, R_warmup)]
        if n_mean != 0:
            warmup_results.append(compute_space_average_jax(m_n[0]))

        # Wait for the warm-up *results* (not the inputs) so that neither
        # compilation nor the warm-up execution leaks into the timed section
        jax.block_until_ready(warmup_results)
        print("JIT warmup completed.")

        start_time = time.perf_counter()

        for n in progress_bar(range(1, N + 1), "Iteration : ", 40):
            t += dt

            # Generate random field for temperature effect
            key, subkey = random.split(key)
            R_random = e.coeff_4 * random.normal(subkey, (3,) + dims, dtype=jnp_float)

            # Predictor and corrector slopes share the same random field
            s_pre = compute_slope(g_params, e_params, m_n, R_random)
            m_pre = m_n + dt * s_pre
            s_cor = compute_slope(g_params, e_params, m_pre, R_random)

            # Update magnetization
            m_n = m_n + dt * 0.5 * (s_pre + s_cor)

            # Renormalize to unit sphere
            norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
            m_n = m_n / norm

            # Export the average of m1 to a file
            if n_mean != 0 and n % n_mean == 0:
                # float() brings the scalar back to the host for writing
                m1_mean = float(compute_space_average_jax(m_n[0]))
                # NOTE(review): each sample is weighted by n_mean and the sum
                # is later divided by N - start_averaging; confirm this
                # normalisation when start_averaging is a multiple of n_mean
                # (the >= test then includes one extra sample).
                if n >= start_averaging:
                    m1_average += m1_mean * n_mean
                f_mean.write(f"{t:10.8e} {m1_mean:10.8e}\n")

        # JAX dispatches work asynchronously: wait for the last update before
        # stopping the clock, otherwise pending work is not measured.
        jax.block_until_ready(m_n)
        total_time = time.perf_counter() - start_time
    finally:
        close_output_files(f_mean, f_profiles)

    # Compare N rather than the loop variable, which is unbound when N == 0
    if N > start_averaging:
        m1_average /= N - start_averaging

    return total_time, output_filenames, m1_average
|