llg3d 2.0.1__py3-none-any.whl → 3.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llg3d/__init__.py +2 -4
- llg3d/benchmarks/__init__.py +1 -0
- llg3d/benchmarks/compare_commits.py +321 -0
- llg3d/benchmarks/efficiency.py +451 -0
- llg3d/benchmarks/utils.py +25 -0
- llg3d/element.py +98 -17
- llg3d/grid.py +48 -58
- llg3d/io.py +395 -0
- llg3d/main.py +32 -35
- llg3d/parameters.py +159 -49
- llg3d/post/__init__.py +1 -1
- llg3d/post/extract.py +112 -0
- llg3d/post/info.py +192 -0
- llg3d/post/m1_vs_T.py +107 -0
- llg3d/post/m1_vs_time.py +81 -0
- llg3d/post/process.py +87 -85
- llg3d/post/utils.py +38 -0
- llg3d/post/x_profiles.py +161 -0
- llg3d/py.typed +1 -0
- llg3d/solvers/__init__.py +153 -0
- llg3d/solvers/base.py +345 -0
- llg3d/solvers/experimental/__init__.py +9 -0
- llg3d/{solver → solvers/experimental}/jax.py +117 -143
- llg3d/solvers/math_utils.py +41 -0
- llg3d/solvers/mpi.py +370 -0
- llg3d/solvers/numpy.py +126 -0
- llg3d/solvers/opencl.py +439 -0
- llg3d/solvers/profiling.py +38 -0
- {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/METADATA +5 -2
- llg3d-3.1.0.dist-info/RECORD +36 -0
- {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/WHEEL +1 -1
- llg3d-3.1.0.dist-info/entry_points.txt +9 -0
- llg3d/output.py +0 -107
- llg3d/post/plot_results.py +0 -61
- llg3d/post/temperature.py +0 -76
- llg3d/simulation.py +0 -95
- llg3d/solver/__init__.py +0 -45
- llg3d/solver/mpi.py +0 -450
- llg3d/solver/numpy.py +0 -207
- llg3d/solver/opencl.py +0 -330
- llg3d/solver/solver.py +0 -89
- llg3d-2.0.1.dist-info/RECORD +0 -25
- llg3d-2.0.1.dist-info/entry_points.txt +0 -4
- {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/licenses/AUTHORS +0 -0
- {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/licenses/LICENSE +0 -0
- {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/top_level.txt +0 -0
llg3d/solvers/base.py
ADDED
@@ -0,0 +1,345 @@
+"""Define the base solver class."""
+
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, ClassVar
+
+import numpy as np
+from tqdm import tqdm
+
+from ..element import Element, get_element_class
+from ..grid import Grid
+from ..io import (
+    Metrics,
+    Observables,
+    RecordsBuffer,
+    format_profiling_table,
+    get_tqdm_file,
+    save_results,
+)
+from ..parameters import InitType, RunParameters
+from . import rank, size
+from .math_utils import normalize
+from .profiling import ProfilingStats, timeit
+
+
+@dataclass
+class BaseSolver(ABC, RunParameters):
+    """Abstract data base class for LLG3D solvers."""
+
+    solver_type: ClassVar[str] = "base"  #: Solver type name
+
+    def __post_init__(self) -> None:
+        """Initialize the solver after dataclass creation."""
+        # Ensure solver name matches class solver_type
+        if self.solver != self.solver_type:
+            raise ValueError(
+                f"Trying to initialize a {self.__class__.__name__}, but solver name "
+                f'mismatch: expected "{self.solver_type}", got "{self.solver}"'
+            )
+
+        # Initialize results structure aligned with RunResults
+        self.metrics: Metrics = {
+            "total_time": 0.0,
+            "time_per_ite": 0.0,
+            "efficiency": 0.0,
+            "CFL": 0.0,
+        }
+        # Physical observables
+        self.observables: Observables = {}
+        # Records with optional x_profiles and xyz_average
+        self.records: RecordsBuffer = {}
+
+        self.np = size  # Add a parameter for the number of processes
+
+        self.np_float: np.dtype = np.dtype(
+            np.float64 if self.precision == "double" else np.float32
+        )
+        self.grid: Grid = Grid(self.Jx, self.Jy, self.Jz, self.dx)
+        if rank == 0:
+            print(self.grid)
+        # Reference the element class from the element string
+        ElementClass: type[Element] = get_element_class(self.element)
+        # Pass dtype to Element so its scalar coefficients have the correct
+        # numpy dtype (`self.np_float`) and avoid implicit promotion.
+        self.elem: Element = ElementClass(
+            self.T, self.H_ext, self.grid, self.dt, dtype=self.np_float
+        )
+
+        self.rng: np.random.Generator = self._init_rng()
+        self.profiling_stats: ProfilingStats = defaultdict(
+            lambda: {"time": 0.0, "calls": 0}
+        )
+        self._tqdm_file = get_tqdm_file()  #: File object for tqdm output
+
+    def theta_init_0(self, t: float) -> np.ndarray:
+        """Initialization of theta with 0."""
+        return np.zeros(self.grid.dims, dtype=self.np_float)
+
+    def theta_init_dw(self, t: float) -> np.ndarray:
+        """Initialization of theta with a domain wall profile."""
+        x, _, _ = self.grid.get_mesh(local=size > 1, dtype=self.np_float)
+        return 2.0 * np.arctan(
+            np.exp(
+                -(
+                    x
+                    - self.grid.Lx / 2
+                    + self.elem.d_0 * self.elem.coeff_3 * self.elem.lambda_G * t
+                )
+                / self.elem.d_0
+            )
+        )
+
+    def phi_init(self, t: float) -> np.ndarray:
+        """Initialization of phi."""
+        return (
+            np.zeros(self.grid.dims, dtype=self.np_float)
+            + self.elem.gamma_0 * self.elem.H_ext * t
+        )
+
+    def _init_m_n(self) -> np.ndarray:
+        """Initialize the magnetization array at time step n."""
+        m_n = np.zeros((3,) + self.grid.dims, dtype=self.np_float)
+
+        if self.init_type == "0":
+            theta = self.theta_init_0(0)
+        elif self.init_type == "dw":
+            theta = self.theta_init_dw(0)
+        else:
+            raise ValueError(
+                f"Unknown initialization type: {self.init_type}, "
+                f"should be in {InitType}"
+            )
+
+        phi = self.phi_init(0)
+
+        m_n[0] = np.cos(theta)
+        m_n[1] = np.sin(theta) * np.cos(phi)
+        m_n[2] = np.sin(theta) * np.sin(phi)
+        # renormalize to verify the constraint of being on the sphere
+        normalize(m_n)
+        return m_n
+
+    def _init_rng(self) -> np.random.Generator:
+        """
+        Initialize a random number generator for temperature fluctuations.
+
+        Returns:
+            A numpy random number generator
+        """
+        # Initialize a sequence of random seeds
+        # See: https://numpy.org/doc/stable/reference/random/parallel.html
+        ss = np.random.SeedSequence(self.seed)
+
+        # Deploy size x SeedSequence to pass to child processes
+        child_seeds = ss.spawn(size)
+        streams = [np.random.default_rng(s) for s in child_seeds]
+        rng = streams[rank]
+        return rng
+
+    @timeit
+    def _get_R_random(self) -> np.ndarray:
+        """
+        Generate the random field for temperature fluctuations.
+
+        Returns:
+            Random field array (shape (3, nx, ny, nz))
+        """
+        R_random = self.elem.coeff_4 * self.rng.standard_normal(
+            (3, *self.grid.dims), dtype=self.np_float
+        )
+        return R_random
+
+    @timeit
+    def _normalize(self, m_n: np.ndarray) -> None:
+        r"""
+        Normalize the magnetization array (in place).
+
+        .. math::
+
+            \mathbf{m}_n = \frac{\mathbf{m}_n}{|\mathbf{m}_n|}
+
+        Args:
+            m_n: Magnetization array at time step n (shape (3, nx, ny, nz)).
+        """
+        normalize(m_n)
+
+    @timeit
+    def _xyz_average(self, m: np.ndarray) -> float:
+        """
+        Returns the spatial average of m with shape (g.dims) using the midpoint method.
+
+        Args:
+            m: Array to be integrated
+
+        Returns:
+            Spatial average of m
+        """
+        mm = m.copy()  # copy m to avoid modifying its value
+
+        # on the edges, we divide the contribution by 2
+        # x
+        mm[0, :, :] /= 2
+        mm[-1, :, :] /= 2
+        # y
+        mm[:, 0, :] /= 2
+        mm[:, -1, :] /= 2
+        # z
+        mm[:, :, 0] /= 2
+        mm[:, :, -1] /= 2
+
+        average = mm.sum() / self.grid.ncell
+        return float(average)
+
+    def _record_xyz_average(self, m_n: np.ndarray, t: float, n: int) -> None:
+        """Update the time average of m1."""
+        xyz_average = self._xyz_average(m_n[0])
+        if rank == 0:
+            # Ensure xyz_average list exists
+            if "xyz_average" not in self.records:
+                self.records["xyz_average"] = []
+            # Record the mean value at time t
+            self.records["xyz_average"].append((t, xyz_average))
+            # Update time average of m1
+            if n >= self.start_averaging:
+                # Initialize m1_mean on first use
+                if "m1_mean" not in self.observables:
+                    self.observables["m1_mean"] = 0.0
+                # Accumulate time average (each sample contributes equally)
+                self.observables["m1_mean"] += xyz_average
+
+    def _finalize(self) -> None:
+        """Normalize m1_mean by the actual number of samples accumulated."""
+        if rank == 0:
+            if "m1_mean" in self.observables:
+                # Divide by actual number of samples
+                # (accounting for n_mean sampling interval)
+                num_samples = (self.N - self.start_averaging) // self.n_mean
+                self.observables["m1_mean"] /= num_samples
+
+    @timeit
+    def _yz_average(self, m_i: np.ndarray) -> np.ndarray:
+        """
+        Returns the spatial average of m using the midpoint method along y and z.
+
+        Args:
+            m_i: Array to be integrated
+
+        Returns:
+            Spatial average of m in y and z of shape (g.dims[0],)
+        """
+        # Make a copy of m to avoid modifying its value
+        mm = m_i.copy()
+
+        # On y and z edges, divide the contribution by 2
+        mm[:, 0, :] /= 2
+        mm[:, -1, :] /= 2
+        mm[:, :, 0] /= 2
+        mm[:, :, -1] /= 2
+
+        n_cell_yz = (mm.shape[1] - 1) * (mm.shape[2] - 1)
+        return mm.sum(axis=(1, 2)) / n_cell_yz
+
+    @timeit
+    def _update_x_profiles(self, m_n: np.ndarray, t: float) -> None:
+        """Update x profiles of the averaged m_i in y and z."""
+        # Initialize x_profiles on first use
+        if "x_profiles" not in self.records:
+            self.records["x_profiles"] = {
+                "t": [],
+                "m1": [],
+                "m2": [],
+                "m3": [],
+            }
+        x_prof = self.records["x_profiles"]
+        x_prof["t"].append(t)
+        x_prof["m1"].append(self._yz_average(m_n[0]))
+        x_prof["m2"].append(self._yz_average(m_n[1]))
+        x_prof["m3"].append(self._yz_average(m_n[2]))
+
+    def _record(self, m_n: Any, t: float, n: int) -> None:
+        """Record simulation data."""
+        # Record the average of m1
+        if self.n_mean != 0 and n % self.n_mean == 0:
+            self._record_xyz_average(m_n, t, n)
+        # Record the x profiles of the averaged m_i in y and z
+        if self.n_profile != 0 and n % self.n_profile == 0:
+            self._update_x_profiles(m_n, t)
+
+    def _progress_bar(self):
+        """Return a progress bar for the given range using tqdm."""
+        if self._tqdm_file is None:
+            # corresponds to rank != 0 in MPI
+            return range(self.N)
+
+        return tqdm(
+            range(self.N),
+            file=self._tqdm_file,
+            dynamic_ncols=True,
+            leave=True,
+        )
+
+    @abstractmethod
+    def _simulate(self) -> float:
+        """
+        Simulates the system for N iterations.
+
+        Returns:
+            total_time: Total simulation time
+        """
+
+    def run(self) -> None:
+        """Runs the simulation and store the results."""
+        total_time = self._simulate()
+        self.metrics["total_time"] = total_time
+        time_per_ite = total_time / self.N if self.N > 0 else 0.0
+        self.metrics["time_per_ite"] = time_per_ite
+        self.metrics["efficiency"] = time_per_ite / self.grid.ntot
+        self.metrics["CFL"] = float(self.elem.get_CFL())
+        if rank == 0:
+            # Store only profiling stats for functions/kernels that were actually called
+            profiling_filtered = {
+                k: v for k, v in self.profiling_stats.items() if v.get("calls", 0) > 0
+            }
+            if profiling_filtered:
+                self.metrics["profiling_stats"] = profiling_filtered
+
+    def _format_profiling(self) -> str:
+        """Format the profiling information for display."""
+        return format_profiling_table(self.profiling_stats, self.metrics["total_time"])
+
+    def save(self, dir_path: str | Path = ".") -> None:
+        """
+        Saves the results of the simulation to a .npz file.
+
+        Args:
+            dir_path: Directory path to save the results
+        """
+        if rank == 0:
+            if self.metrics["total_time"] > 0:
+                s = f"""\
+N iterations = {self.N}
+total_time [s] = {self.metrics["total_time"]:.03f}
+time/ite [s/ite] = {self.metrics["time_per_ite"]:.03e}
+efficiency [s/ite/pt] = {self.metrics["efficiency"]:.03e}
+CFL = {self.metrics["CFL"]:.03e}"""
+                print(s)
+
+                # Print profiling info only if enabled
+                if self.profiling:
+                    print(f"Profiling info:\n{self._format_profiling()}")
+
+            # Export the mean of m1 over space and time
+            if "m1_mean" in self.observables:
+                print(f"m1_mean = {self.observables['m1_mean']:e}")
+
+            print(f"Saving {self.result_file}")
+            save_results(
+                Path(dir_path) / self.result_file,
+                self,
+                self.metrics,
+                observables=self.observables if self.observables else None,
+                records_buffer=self.records if self.records else None,
+            )
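The new `BaseSolver` concentrates what was previously duplicated across backends: parameter validation, grid/element setup, per-rank seeded RNG streams, recording of averages and x-profiles, metrics, and `.npz` output. A concrete backend only has to implement `_simulate()`. The sketch below is illustrative and not taken from the package: the class name and the update step are made up, while the helper methods (`_init_m_n`, `_get_R_random`, `_normalize`, `_record`, `_progress_bar`, `_finalize`) are the ones defined in the diff above.

```python
# Hypothetical sketch of a concrete backend hooking into BaseSolver.
import time
from dataclasses import dataclass
from typing import ClassVar

from llg3d.solvers.base import BaseSolver  # module added in this release


@dataclass
class ToySolver(BaseSolver):
    solver_type: ClassVar[str] = "toy"  # must equal the `solver` field passed in

    def _simulate(self) -> float:
        m_n = self._init_m_n()              # (3, nx, ny, nz) magnetization on the unit sphere
        t = 0.0
        start = time.perf_counter()
        for n in self._progress_bar():      # range(N), wrapped in tqdm on rank 0
            t += self.dt
            R_random = self._get_R_random() # stochastic thermal field, unused in this stub
            # ... the backend's predictor/corrector update of m_n would go here ...
            self._normalize(m_n)            # project back onto the unit sphere
            self._record(m_n, t, n)         # xyz average / x-profile bookkeeping
        self._finalize()                    # turn the accumulated m1 sum into a mean
        return time.perf_counter() - start
```

Calling `run()` on such a solver fills `metrics` (total_time, time_per_ite, efficiency, CFL) and `save()` writes them to the `.npz` result file via `save_results`; the constructor fields (N, Jx, Jy, Jz, dx, dt, T, H_ext, precision, seed, ...) come from the `RunParameters` dataclass in `llg3d/parameters.py`.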
llg3d/{solver → solvers/experimental}/jax.py
@@ -1,4 +1,12 @@
-"""
+"""
+LLG3D solver using XLA compilation.
+
+.. warning::
+
+    It is experimental and not maintained.
+"""
+
+from typing import ClassVar

 import os
 import time
@@ -7,9 +15,7 @@ import jax
 import jax.numpy as jnp
 from jax import random

-from ..
-from ..grid import Grid
-from ..element import Element
+from ..base import BaseSolver


 # JIT compile individual components for better performance and modularity
@@ -157,7 +163,7 @@ def compute_space_average_jax(m1: jnp.ndarray) -> float:
     # Compute ncell from the weights (this is the effective cell count)
     ncell = jnp.sum(weights)

-    return weighted_sum / ncell
+    return weighted_sum / ncell  # type: ignore


 @jax.jit
@@ -217,171 +223,139 @@ def compute_slope(
     return -(m_cross_R_eff + lambda_G * m_cross_m_cross_R_eff)


-
-
-    Jx: int,
-    Jy: int,
-    Jz: int,
-    dx: float,
-    T: float,
-    H_ext: float,
-    dt: float,
-    start_averaging: int,
-    n_mean: int,
-    n_profile: int,
-    element_class: Element,
-    precision: str,
-    seed: int,
-    device: str = "auto",
-    **_,
-) -> tuple[float, str, float]:
-    """
-    Simulates the system for N iterations using JAX.
+class JaxSolver(BaseSolver):
+    """JAX-based LLG3D solver."""

-
-        N: Number of iterations
-        Jx: Number of grid points in x direction
-        Jy: Number of grid points in y direction
-        Jz: Number of grid points in z direction
-        dx: Grid spacing
-        T: Temperature in Kelvin
-        H_ext: External magnetic field strength
-        dt: Time step for the simulation
-        start_averaging: Number of iterations for averaging
-        n_mean: Number of iterations for integral output
-        n_profile: Number of iterations for profile output
-        element_class: Element of the sample (default: Cobalt)
-        precision: Precision of the simulation (single or double)
-        seed: Random seed for temperature fluctuations
-        device: Device to use ('cpu', 'gpu', 'gpu:0', 'gpu:1', etc., or 'auto')
+    solver_type: ClassVar[str] = "jax"  #: Solver type name

-
-
-
-        - The average magnetization
-    """
-    # Configure JAX
-    if device == "auto":
-        # Let JAX choose the best available device
-        pass
-    elif device == "cpu":
-        jax.config.update("jax_platform_name", "cpu")
-    elif device == "gpu":
-        jax.config.update("jax_platform_name", "gpu")
-    elif device.startswith("gpu:"):
-        # Select specific GPU using environment variable
-        jax.config.update("jax_platform_name", "gpu")
-        gpu_id = device.split(":")[1]
-        # Check if CUDA_VISIBLE_DEVICES is already set externally
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
-            print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
-        else:
-            cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
-            print(
-                f"Using external CUDA_VISIBLE_DEVICES={cuda_visible_devices}"
-            )
+    def _xyz_average(self, m1: jnp.ndarray) -> float:
+        """Compute the space average of m1 using JAX."""
+        return compute_space_average_jax(m1)

-
-
-
-        jnp_float = jnp.float64
-    else:
-        jax.config.update("jax_enable_x64", False)
-        jnp_float = jnp.float32
+    def _simulate(self) -> float:
+        """
+        Simulates the system for N iterations using JAX.

-
-
-    print(f"Precision: {precision} ({jnp_float})")
+        Attributes:
+            device: Device to use ('cpu', 'gpu', 'gpu:0', 'gpu:1', etc., or 'auto')

-
-
+        Returns:
+            The time taken for the simulation

-
-
+        Raises:
+            NotImplementedError: If n_profile is not zero
+        """
+        if self.n_profile != 0:
+            raise NotImplementedError(
+                "Saving x-profiles is not implemented for the JAX solver."
+            )

-
-
+        # Configure JAX
+        if self.device == "auto":
+            # Let JAX choose the best available device
+            pass
+        elif self.device == "cpu":
+            jax.config.update("jax_platform_name", "cpu")
+        elif self.device == "gpu":
+            jax.config.update("jax_platform_name", "gpu")
+        elif self.device.startswith("gpu:"):
+            # Select specific GPU using environment variable
+            jax.config.update("jax_platform_name", "gpu")
+            gpu_id = self.device.split(":")[1]
+            # Check if CUDA_VISIBLE_DEVICES is already set externally
+            if "CUDA_VISIBLE_DEVICES" not in os.environ:
+                os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+                print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
+            else:
+                cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
+                print(f"Using external CUDA_VISIBLE_DEVICES={cuda_visible_devices}")
+
+        # Set precision
+        if self.precision == "double":
+            jax.config.update("jax_enable_x64", True)
+            jnp_float = jnp.float64
+        else:
+            jax.config.update("jax_enable_x64", False)
+            jnp_float = jnp.float32

-
-
-
+        print(f"Available JAX devices: {jax.devices()}")
+        print(f"Using JAX on device: {jax.devices()[0]}")
+        print(f"Precision: {self.precision} ({jnp_float})")

-
-
-        """Initialization of theta."""
-        return jnp.zeros(shape, dtype=jnp_float)
+        # Initialize random key for JAX
+        key = random.PRNGKey(self.seed)

-
-
-
+        # Prepare parameters for JIT compilation using to_dict methods
+        g_params = self.grid.as_dict()
+        e_params = self.elem.to_dict()

-
+        # --- Initialization ---
+        def theta_init(shape):
+            """Initialization of theta."""
+            return jnp.zeros(shape, dtype=jnp_float)

-
-
+        def phi_init(t, shape):
+            """Initialization of phi."""
+            return (
+                jnp.zeros(shape, dtype=jnp_float) + self.elem.gamma_0 * self.H_ext * t
+            )

-
-    m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
-    m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))
+        m_n = jnp.zeros((3, *self.grid.dims), dtype=jnp_float)

-
+        theta = theta_init(self.grid.dims)
+        phi = phi_init(0, self.grid.dims)

-
-
+        m_n = m_n.at[0].set(jnp.cos(theta))
+        m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
+        m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))

-
-    print("Warming up JIT compilation...")
+        t = 0.0

-
-
-    R_warmup = e.coeff_4 * random.normal(warmup_key, (3,) + dims, dtype=jnp_float)
+        # === JIT WARMUP: Pre-compile all functions to exclude compilation time ===
+        print("Warming up JIT compilation...")

-
-
-
-
+        # Generate dummy random field for warmup
+        warmup_key = random.PRNGKey(42)
+        R_warmup = self.elem.coeff_4 * random.normal(
+            warmup_key, (3, *self.grid.dims), dtype=jnp_float
+        )

-
-
-
+        # Warmup all JIT functions with actual data shapes
+        _ = compute_slope(g_params, e_params, m_n, R_warmup)
+        if self.n_mean != 0:
+            _ = compute_space_average_jax(m_n[0])

-
+        # Force compilation and execution to complete
+        jax.block_until_ready(m_n)
+        print("JIT warmup completed.")

-
-        t += dt
+        start_time = time.perf_counter()

-
-
-        R_random = e.coeff_4 * random.normal(subkey, (3,) + dims, dtype=jnp_float)
+        for n in self._progress_bar():
+            t += self.dt

-
-
-
-
+            # Generate random field for temperature effect
+            key, subkey = random.split(key)
+            R_random = self.elem.coeff_4 * random.normal(
+                subkey, (3, *self.grid.dims), dtype=jnp_float
+            )

-
-
+            # Use JIT-compiled version for better performance
+            s_pre = compute_slope(g_params, e_params, m_n, R_random)
+            m_pre = m_n + self.dt * s_pre
+            s_cor = compute_slope(g_params, e_params, m_pre, R_random)

-
-
-        m_n = m_n / norm
+            # Update magnetization
+            m_n = m_n + self.dt * 0.5 * (s_pre + s_cor)

-
-
-
-            m1_mean = compute_space_average_jax(m_n[0])
-            # Convert to Python float for file writing
-            m1_mean = float(m1_mean)
-            if n >= start_averaging:
-                m1_average += m1_mean * n_mean
-            f_mean.write(f"{t:10.8e} {m1_mean:10.8e}\n")
+            # Renormalize to unit sphere
+            norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
+            m_n = m_n / norm

-
+            self._record(m_n, t, n)

-
+        total_time = time.perf_counter() - start_time

-
-    m1_average /= N - start_averaging
+        self._finalize()

-
+        return total_time
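Net effect of this hunk: the JAX backend moves from a standalone `simulate()` function with a long parameter list to a `JaxSolver` class that inherits the `BaseSolver` machinery (initialization, recording, finalization) and only implements `_simulate()`. A hypothetical usage sketch follows; the values are illustrative, and the exact set of required `RunParameters` fields and their defaults live in `llg3d/parameters.py`.

```python
# Hypothetical sketch -- field values are illustrative, and which fields have
# defaults is an assumption (check llg3d/parameters.py).
from llg3d.solvers.experimental.jax import JaxSolver

solver = JaxSolver(
    solver="jax",          # must match JaxSolver.solver_type (checked in __post_init__)
    N=10_000, Jx=64, Jy=32, Jz=32, dx=1.0e-9, dt=1.0e-15,
    T=300.0, H_ext=0.0,
    device="auto",         # 'cpu', 'gpu', 'gpu:0', ... as documented in _simulate()
    n_profile=0,           # x-profiles raise NotImplementedError in the JAX solver
)
solver.run()   # JIT warmup, then the predictor/corrector time loop
solver.save()  # prints the metrics summary and writes the .npz result file
```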