llg3d 1.4.1__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llg3d/post/temperature.py CHANGED
@@ -1,30 +1,24 @@
1
1
  #!/usr/bin/env python3
2
- """
3
- Plot the magnetization vs temperature and determine the Curie temperature.
4
- """
2
+ """Plot the magnetization vs temperature and determine the Curie temperature."""
5
3
 
6
4
  import argparse
7
5
  from pathlib import Path
8
6
 
9
7
  import matplotlib.pyplot as plt
10
- import numpy as np
11
8
 
12
9
  from .process import MagData
13
10
 
14
11
 
15
12
  def plot_m_vs_T(m: MagData, show: bool):
16
13
  """
17
- Plots the data (T, <m>), interpolates the values,
18
- calculates the Curie temperature.
19
- Exports to PNG.
14
+ Plots the data (T, <m>).
15
+
16
+ Interpolates the values, calculates the Curie temperature, exports to PNG.
20
17
 
21
18
  Args:
22
- data: numpy array (T, <m>)
23
- parentdir: path to the directory containing the runs
24
- run: descriptive dictionary of the run
19
+ m: Magnetization data object
25
20
  show: display the graph in a graphical window
26
21
  """
27
-
28
22
  print(f"T_Curie = {m.T_Curie:.0f} K")
29
23
 
30
24
  fig, ax = plt.subplots()
@@ -57,9 +51,7 @@ def plot_m_vs_T(m: MagData, show: bool):
57
51
 
58
52
 
59
53
  def main():
60
- """
61
- Parses the command line to execute processing functions
62
- """
54
+ """Parses the command line to execute processing functions."""
63
55
  parser = argparse.ArgumentParser(description=__doc__)
64
56
  parser.add_argument("--job_dir", type=Path, help="Slurm main job directory")
65
57
  parser.add_argument(
llg3d/simulation.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ Define the Simulation class.
3
+
4
+ Example usage:
5
+
6
+ >>> from llg3d.simulation import Simulation
7
+ >>> from llg3d.parameters import parameters
8
+ >>> run_parameters = {name: value["default"] for name, value in parameters.items()}
9
+ >>> sim = Simulation(run_parameters)
10
+ >>> sim.run()
11
+ >>> sim.save()
12
+
13
+ """
14
+
15
+ import inspect
16
+
17
+ from . import rank, size
18
+ from .element import get_element_class
19
+ from .parameters import Parameter
20
+ from .output import write_json
21
+
22
+
23
class Simulation:
    """
    Encapsulate one simulation run: solver lookup, execution and export.

    Args:
        params: Dictionary of simulation parameters. The dictionary is
            copied, so the caller's object is not modified. The ``"solver"``
            and ``"element"`` entries are read at construction time.
    """

    json_file = "run.json"  #: JSON file to store the results

    def __init__(self, params: dict[str, Parameter]):
        self.params: dict[str, Parameter] = params.copy()  #: simulation parameters
        self.simulate: callable = self._get_simulate_function_from_name(
            self.params["solver"]
        )  #: simulation function imported from the solver module
        self.total_time: None | float = None  #: total simulation time
        self.filenames: list[str] = []  #: list of output filenames
        self.m1_mean: None | float = None  #: space and time average of m1
        self.params["np"] = size  # Add a parameter for the number of processes
        # Reference the element class from the element string
        self.params["element_class"] = get_element_class(params["element"])

    def run(self):
        """Run the simulation and store the results on the instance."""
        self.total_time, self.filenames, self.m1_mean = self.simulate(**self.params)

    def _get_simulate_function_from_name(self, name: str) -> callable:
        """
        Retrieve the simulation function for a given solver name.

        Args:
            name: Name of the solver

        Returns:
            callable: The simulation function

        Example:
            >>> simulate = self._get_simulate_function_from_name("mpi")

            Will return the `simulate` function from the `llg3d.solver.mpi` module.
        """
        # Local import: keeps this block self-contained without touching the
        # file-level import section.
        import importlib

        module = importlib.import_module(f"llg3d.solver.{name}")
        return inspect.getattr_static(module, "simulate")

    def save(self):
        """
        Save the parameters and results of the simulation to a JSON file.

        Only the process of rank 0 prints the summary and writes the file;
        the other ranks return without doing anything.

        Raises:
            RuntimeError: If called on rank 0 before :meth:`run`.
        """
        params = self.params.copy()  # save the parameters
        # Remove the class object before serialization
        params.pop("element_class", None)
        if rank != 0:
            return

        if self.total_time is None:
            raise RuntimeError("No results to save: call run() before save()")

        results = {"total_time": self.total_time}
        # Export the integral of m1
        if self.filenames:
            results["integral_file"] = self.filenames[0]
            print(f"Integral of m1 in {self.filenames[0]}")
        # Export the x-profiles of m1, m2 and m3 (components are 1-indexed)
        for i, filename in enumerate(self.filenames[1:], start=1):
            results[f"xprofile_m{i}"] = filename
            print(f"x-profile of m{i} in {filename}")

        n_iter = params["N"]
        # Avoid a ZeroDivisionError when no iteration was requested
        time_per_iter = self.total_time / n_iter if n_iter else float("nan")
        print(
            f"N iterations = {n_iter}\n"
            f"total_time [s] = {self.total_time:.03f}\n"
            f"time/ite [s/iter] = {time_per_iter:.03e}"
        )
        # Export the mean of m1
        if n_iter > params["start_averaging"]:
            print(f"m1_mean = {self.m1_mean:e}")
            results["m1_mean"] = float(self.m1_mean)

        write_json(self.json_file, {"params": params, "results": results})
        print(f"Summary in {self.json_file}")
@@ -0,0 +1,45 @@
1
+ """
2
+ Solver module for LLG3D.
3
+
4
+ This module contains different solver implementations.
5
+ """
6
+
7
+ import importlib.util
8
+
9
+
10
__all__ = ["numpy", "solver", "rank", "size", "comm", "status"]

# Availability flag of each optional library, keyed by the probed name
LIB_AVAILABLE: dict[str, bool] = {}

# Probe the optional libraries; the name of each one found is exported too.
# NOTE(review): the probed names are absolute, so `package` presumably has no
# effect on the lookup and "opencl" is searched as a top-level module --
# confirm this is the intended import name.
# NOTE(review): "mpi4py" may be appended to __all__ although only `MPI` is
# bound below -- verify that `from llg3d.solver import *` still works.
for lib in ("opencl", "jax", "mpi4py"):
    LIB_AVAILABLE[lib] = (
        importlib.util.find_spec(lib, package=__package__) is not None
    )
    if LIB_AVAILABLE[lib]:
        __all__.append(lib)


# MPI utilities
if not LIB_AVAILABLE["mpi4py"]:
    # MPI library is not available: use dummy values
    class DummyComm:
        """Empty stand-in for the communicator when mpi4py is missing."""

    class DummyStatus:
        """Empty stand-in for the status object when mpi4py is missing."""

    comm = DummyComm()
    rank = 0
    size = 1
    status = DummyStatus()
else:
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    status = MPI.Status()

from . import numpy, solver  # noqa: E402
llg3d/solver/jax.py ADDED
@@ -0,0 +1,387 @@
1
+ """LLG3D solver using XLA compilation."""
2
+
3
+ import os
4
+ import time
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from jax import random
9
+
10
+ from ..output import progress_bar, get_output_files, close_output_files
11
+ from ..grid import Grid
12
+ from ..element import Element
13
+
14
+
15
+ # JIT compile individual components for better performance and modularity
16
@jax.jit
def compute_H_anisotropy(
    m: jnp.ndarray, coeff_2: float, anisotropy: int
) -> jnp.ndarray:
    """
    Compute the anisotropy field (JIT compiled).

    Both candidate fields are built for the three components at once and the
    one matching `anisotropy` is selected; any other value yields a zero field.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        coeff_2: Coefficient multiplying the selected field
        anisotropy: Anisotropy type (0: uniaxial, 1: cubic)

    Returns:
        Anisotropy field array (shape (3, nx, ny, nz))
    """
    mx, my, mz = m
    zero = jnp.zeros_like(mx)
    sq_x, sq_y, sq_z = mx * mx, my * my, mz * mz

    # Uniaxial case: only the first component is non-zero
    uniaxial = jnp.stack([mx, zero, zero], axis=0)

    # Cubic case
    cubic = jnp.stack(
        [
            -(1 - sq_x + sq_y * sq_z) * mx,
            -(1 - sq_y + sq_x * sq_z) * my,
            -(1 - sq_z + sq_x * sq_y) * mz,
        ],
        axis=0,
    )

    # `anisotropy` is traced under jit, hence the array-level selection
    fallback = jnp.where(anisotropy == 1, cubic, jnp.zeros_like(uniaxial))
    selected = jnp.where(anisotropy == 0, uniaxial, fallback)

    return coeff_2 * selected
65
+
66
+
67
@jax.jit
def laplacian3D(
    m_i: jnp.ndarray,
    dx2_inv: float,
    dy2_inv: float,
    dz2_inv: float,
    center_coeff: float,
) -> jnp.ndarray:
    """
    Compute the Laplacian of one component with Neumann boundary conditions.

    (JIT compiled)

    The array is padded by one cell on every side in "reflect" mode, then the
    two neighbours along each axis are summed and weighted.

    Args:
        m_i: Single component of magnetization (shape (nx, ny, nz))
        dx2_inv: Inverse of squared grid spacing in x direction
        dy2_inv: Inverse of squared grid spacing in y direction
        dz2_inv: Inverse of squared grid spacing in z direction
        center_coeff: Coefficient for the center point

    Returns:
        Laplacian of m_i (shape (nx, ny, nz))
    """
    padded = jnp.pad(m_i, ((1, 1), (1, 1), (1, 1)), mode="reflect")
    inner = slice(1, -1)  # original (unpadded) extent along one axis

    x_neighbours = padded[2:, inner, inner] + padded[:-2, inner, inner]
    y_neighbours = padded[inner, 2:, inner] + padded[inner, :-2, inner]
    z_neighbours = padded[inner, inner, 2:] + padded[inner, inner, :-2]

    return (
        dx2_inv * x_neighbours
        + dy2_inv * y_neighbours
        + dz2_inv * z_neighbours
        + center_coeff * m_i
    )
97
+
98
+
99
@jax.jit
def compute_laplacian(m: jnp.ndarray, dx: float, dy: float, dz: float) -> jnp.ndarray:
    """
    Compute the 3D Laplacian with Neumann boundary conditions (JIT compiled).

    The stencil coefficients are derived once from the grid spacings and
    `laplacian3D` is applied to each of the three components.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        dx: Grid spacing in x direction
        dy: Grid spacing in y direction
        dz: Grid spacing in z direction

    Returns:
        Laplacian of m (shape (3, nx, ny, nz))
    """
    inv_dx2 = 1 / dx**2
    inv_dy2 = 1 / dy**2
    inv_dz2 = 1 / dz**2
    center = -2 * (inv_dx2 + inv_dy2 + inv_dz2)

    per_component = [
        laplacian3D(m[i], inv_dx2, inv_dy2, inv_dz2, center) for i in range(3)
    ]
    return jnp.stack(per_component, axis=0)
124
+
125
+
126
@jax.jit
def compute_space_average_jax(m1: jnp.ndarray) -> float:
    """
    Compute the weighted space average of a 3D field (JIT compiled).

    Along each axis the first and last cells are weighted 0.5 and the others
    1.0; the weight of a cell is the product of its three per-axis weights.
    The result is the weighted sum divided by the sum of the weights.

    Args:
        m1: First component of magnetization (shape (nx, ny, nz))

    Returns:
        Space average of m1
    """

    def axis_weights(length):
        """Return the 1D weights (0.5 on both ends, 1.0 elsewhere)."""
        index = jnp.arange(length)
        full = jnp.ones(length, dtype=m1.dtype)
        return jnp.where((index == 0) | (index == length - 1), full * 0.5, full)

    # Dimensions come directly from the (static) array shape
    nx, ny, nz = m1.shape

    # Broadcast the per-axis weights to the full 3D grid
    weights = (
        axis_weights(nx)[:, None, None]
        * axis_weights(ny)[None, :, None]
        * axis_weights(nz)[None, None, :]
    )

    # The sum of the weights is the effective cell count
    return jnp.sum(weights * m1) / jnp.sum(weights)
161
+
162
+
163
@jax.jit
def cross_product(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    r"""
    Compute the cross product :math:`a \times b` (JIT compiled).

    The vector components are stored along the first axis of both inputs
    and of the result.

    Args:
        a: First vector (shape (3, nx, ny, nz))
        b: Second vector (shape (3, nx, ny, nz))

    Returns:
        Cross product :math:`a \times b` (shape (3, nx, ny, nz))
    """
    # Components live on axis 0 for both operands and for the output
    product = jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
    return product
177
+
178
+
179
+ # JIT compile the slope computation for performance
180
@jax.jit
def compute_slope(
    g_params: dict, e_params: dict, m: jnp.ndarray, R_random: jnp.ndarray
) -> jnp.ndarray:
    """
    Compute the slope of the magnetization (JIT compiled).

    The effective field is assembled from the scaled Laplacian, the random
    field and the anisotropy field, with `coeff_3` added to its first
    component. The slope then combines `m x R_eff` and `m x (m x R_eff)`.

    Args:
        g_params: Grid parameters dict (dx, dy, dz)
        e_params: Element parameters dict (coeff_1, coeff_2, coeff_3, lambda_G,
            anisotropy)
        m: Magnetization array (shape (3, nx, ny, nz))
        R_random: Random field array (shape (3, nx, ny, nz))

    Returns:
        Slope array (shape (3, nx, ny, nz))
    """
    spacings = (g_params["dx"], g_params["dy"], g_params["dz"])
    coeff_1, coeff_2, coeff_3 = (
        e_params["coeff_1"],
        e_params["coeff_2"],
        e_params["coeff_3"],
    )
    lambda_G = e_params["lambda_G"]

    # Field contributions computed by the modular sub-functions
    anisotropy_field = compute_H_anisotropy(m, coeff_2, e_params["anisotropy"])
    laplacian_term = compute_laplacian(m, *spacings)

    # Effective field; coeff_3 only acts on the first component
    field = coeff_1 * laplacian_term + R_random + anisotropy_field
    field = field.at[0].add(coeff_3)

    # Single and double cross products with the magnetization
    m_x_field = cross_product(m, field)
    m_x_m_x_field = cross_product(m, m_x_field)

    return -(m_x_field + lambda_G * m_x_m_x_field)
218
+
219
+
220
def simulate(
    N: int,
    Jx: int,
    Jy: int,
    Jz: int,
    dx: float,
    T: float,
    H_ext: float,
    dt: float,
    start_averaging: int,
    n_mean: int,
    n_profile: int,
    element_class: type[Element],
    precision: str,
    seed: int,
    device: str = "auto",
    **_,
) -> tuple[float, list[str], float]:
    """
    Simulate the system for N iterations using JAX.

    Args:
        N: Number of iterations
        Jx: Number of grid points in x direction
        Jy: Number of grid points in y direction
        Jz: Number of grid points in z direction
        dx: Grid spacing
        T: Temperature in Kelvin
        H_ext: External magnetic field strength
        dt: Time step for the simulation
        start_averaging: Iteration from which the m1 samples are accumulated
            into the time average
        n_mean: Output period (in iterations) of the space average of m1;
            0 disables this output
        n_profile: Number of iterations for profile output (only forwarded
            to `get_output_files` here)
        element_class: Class of the element of the sample; it is instantiated
            as `element_class(T, H_ext, grid, dt)`
        precision: Precision of the simulation ("double" selects 64-bit
            floats, any other value 32-bit floats)
        seed: Random seed for temperature fluctuations
        device: Device to use ('cpu', 'gpu', 'gpu:0', 'gpu:1', etc., or 'auto')

    Returns:
        - The time taken for the simulation
        - The output filenames
        - The average magnetization
    """
    # Configure JAX ("auto" and unrecognized values let JAX pick the device)
    if device == "cpu":
        jax.config.update("jax_platform_name", "cpu")
    elif device == "gpu":
        jax.config.update("jax_platform_name", "gpu")
    elif device.startswith("gpu:"):
        # Select specific GPU using environment variable
        jax.config.update("jax_platform_name", "gpu")
        gpu_id = device.split(":")[1]
        # Do not override a CUDA_VISIBLE_DEVICES that was set externally
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
            print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
        else:
            cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
            print(f"Using external CUDA_VISIBLE_DEVICES={cuda_visible_devices}")

    # Set precision
    if precision == "double":
        jax.config.update("jax_enable_x64", True)
        jnp_float = jnp.float64
    else:
        jax.config.update("jax_enable_x64", False)
        jnp_float = jnp.float32

    print(f"Available JAX devices: {jax.devices()}")
    print(f"Using JAX on device: {jax.devices()[0]}")
    print(f"Precision: {precision} ({jnp_float})")

    # Initialize random key for JAX
    key = random.PRNGKey(seed)

    g = Grid(Jx, Jy, Jz, dx)
    dims = g.dims

    e = element_class(T, H_ext, g, dt)
    print(f"CFL = {e.get_CFL()}")

    # Prepare parameters for JIT compilation using to_dict methods
    g_params = g.to_dict()
    e_params = e.to_dict()

    # --- Initialization ---
    def theta_init(shape):
        """Initialization of theta."""
        return jnp.zeros(shape, dtype=jnp_float)

    def phi_init(t, shape):
        """Initialization of phi."""
        return jnp.zeros(shape, dtype=jnp_float) + e.gamma_0 * H_ext * t

    m_n = jnp.zeros((3,) + dims, dtype=jnp_float)

    theta = theta_init(dims)
    phi = phi_init(0, dims)

    m_n = m_n.at[0].set(jnp.cos(theta))
    m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
    m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))

    f_mean, f_profiles, output_filenames = get_output_files(g, T, n_mean, n_profile)

    t = 0.0
    m1_average = 0.0

    # The output files are closed even if the time loop raises
    try:
        # === JIT WARMUP: Pre-compile all functions to exclude compilation time ===
        print("Warming up JIT compilation...")

        # Generate dummy random field for warmup
        warmup_key = random.PRNGKey(42)
        R_warmup = e.coeff_4 * random.normal(warmup_key, (3,) + dims, dtype=jnp_float)

        # Warmup all JIT functions with actual data shapes. JAX dispatches
        # asynchronously, so wait on the *outputs* of the warmup calls to make
        # sure they have finished before the timer starts.
        jax.block_until_ready(compute_slope(g_params, e_params, m_n, R_warmup))
        if n_mean != 0:
            jax.block_until_ready(compute_space_average_jax(m_n[0]))
        print("JIT warmup completed.")

        start_time = time.perf_counter()

        for n in progress_bar(range(1, N + 1), "Iteration : ", 40):
            t += dt

            # Generate random field for temperature effect
            key, subkey = random.split(key)
            R_random = e.coeff_4 * random.normal(subkey, (3,) + dims, dtype=jnp_float)

            # Predictor slope, predicted state, then corrector slope
            s_pre = compute_slope(g_params, e_params, m_n, R_random)
            m_pre = m_n + dt * s_pre
            s_cor = compute_slope(g_params, e_params, m_pre, R_random)

            # Update magnetization with the mean of both slopes
            m_n = m_n + dt * 0.5 * (s_pre + s_cor)

            # Renormalize to unit sphere
            norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
            m_n = m_n / norm

            # Export the average of m1 to a file
            if n_mean != 0 and n % n_mean == 0:
                # Convert to Python float for file writing
                m1_mean = float(compute_space_average_jax(m_n[0]))
                if n >= start_averaging:
                    m1_average += m1_mean * n_mean
                f_mean.write(f"{t:10.8e} {m1_mean:10.8e}\n")

        # Wait for pending asynchronous work so the elapsed time covers the
        # whole computation, also when no host transfer happened in the loop.
        jax.block_until_ready(m_n)
        total_time = time.perf_counter() - start_time
    finally:
        close_output_files(f_mean, f_profiles)

    # Test N rather than the loop variable, which is unbound when N == 0
    if N > start_averaging:
        m1_average /= N - start_averaging

    return total_time, output_filenames, m1_average