llg3d 1.4.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,45 @@
+"""
+Solver module for LLG3D
+
+This module contains different solver implementations.
+"""
+
+import importlib.util
+
+
+__all__ = ["numpy", "solver", "rank", "size", "comm", "status"]
+
+LIB_AVAILABLE: dict[str, bool] = {}
+
+# Check for other solver availability
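+# find_spec only checks whether each module can be located; it does not import it,
+# so a missing optional dependency does not raise an ImportError here.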
+for lib in "opencl", "jax", "mpi4py":
+    if importlib.util.find_spec(lib, package=__package__) is not None:
+        LIB_AVAILABLE[lib] = True
+        __all__.append(lib)
+    else:
+        LIB_AVAILABLE[lib] = False
+
+
+# MPI utilities
+if LIB_AVAILABLE["mpi4py"]:
+    from mpi4py import MPI
+
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+    status = MPI.Status()
+else:
+    # MPI library is not available: use dummy values
+    class DummyComm:
+        pass
+
+    comm = DummyComm()
+    rank = 0
+    size = 1
+
+    class DummyStatus:
+        pass
+
+    status = DummyStatus()
+
+from . import numpy, solver
llg3d/solver/jax.py ADDED
@@ -0,0 +1,383 @@
+"""
+LLG3D solver using XLA compilation
+"""
+
+import os
+import time
+
+import jax
+import jax.numpy as jnp
+from jax import random
+
+from ..output import progress_bar, get_output_files, close_output_files
+from ..grid import Grid
+from ..element import Element
+
+
+# JIT compile individual components for better performance and modularity
+@jax.jit
+def compute_H_anisotropy(
+    m: jnp.ndarray, coeff_2: float, anisotropy: int
+) -> jnp.ndarray:
+    """
+    Compute anisotropy field (JIT compiled)
+
+    Args:
+        m: Magnetization array (shape (3, nx, ny, nz))
+        coeff_2: Coefficient for anisotropy
+        anisotropy: Anisotropy type (0: uniaxial, 1: cubic)
+
+    Returns:
+        Anisotropy field array (shape (3, nx, ny, nz))
+    """
+
+    m1, m2, m3 = m
+
+    m1m1 = m1 * m1
+    m2m2 = m2 * m2
+    m3m3 = m3 * m3
+
+    # Uniaxial anisotropy
+    aniso_1_uniaxial = m1
+    aniso_2_uniaxial = jnp.zeros_like(m1)
+    aniso_3_uniaxial = jnp.zeros_like(m1)
+
+    # Cubic anisotropy
+    aniso_1_cubic = -(1 - m1m1 + m2m2 * m3m3) * m1
+    aniso_2_cubic = -(1 - m2m2 + m1m1 * m3m3) * m2
+    aniso_3_cubic = -(1 - m3m3 + m1m1 * m2m2) * m3
+
+    # Select based on anisotropy type
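+    # Both candidate fields are computed and jnp.where selects between them
+    # elementwise, so the function stays traceable under jax.jit without
+    # Python-level branching on traced values.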
+    aniso_1 = jnp.where(
+        anisotropy == 0,
+        aniso_1_uniaxial,
+        jnp.where(anisotropy == 1, aniso_1_cubic, jnp.zeros_like(m1)),
+    )
+    aniso_2 = jnp.where(
+        anisotropy == 0,
+        aniso_2_uniaxial,
+        jnp.where(anisotropy == 1, aniso_2_cubic, jnp.zeros_like(m1)),
+    )
+    aniso_3 = jnp.where(
+        anisotropy == 0,
+        aniso_3_uniaxial,
+        jnp.where(anisotropy == 1, aniso_3_cubic, jnp.zeros_like(m1)),
+    )
+
+    return coeff_2 * jnp.stack([aniso_1, aniso_2, aniso_3], axis=0)
+
+
+@jax.jit
+def laplacian3D(
+    m_i: jnp.ndarray, dx2_inv: float, dy2_inv: float, dz2_inv: float, center_coeff: float
+) -> jnp.ndarray:
+    """
+    Compute Laplacian for a single component with Neumann boundary conditions (JIT compiled)
+
+    Args:
+        m_i: Single component of magnetization (shape (nx, ny, nz))
+        dx2_inv: Inverse of squared grid spacing in x direction
+        dy2_inv: Inverse of squared grid spacing in y direction
+        dz2_inv: Inverse of squared grid spacing in z direction
+        center_coeff: Coefficient for the center point
+
+    Returns:
+        Laplacian of m_i (shape (nx, ny, nz))
+    """
+    m_i_padded = jnp.pad(m_i, ((1, 1), (1, 1), (1, 1)), mode="reflect")
+    return (
+        dx2_inv * (m_i_padded[2:, 1:-1, 1:-1] + m_i_padded[:-2, 1:-1, 1:-1])
+        + dy2_inv * (m_i_padded[1:-1, 2:, 1:-1] + m_i_padded[1:-1, :-2, 1:-1])
+        + dz2_inv * (m_i_padded[1:-1, 1:-1, 2:] + m_i_padded[1:-1, 1:-1, :-2])
+        + center_coeff * m_i
+    )
+
+
+@jax.jit
+def compute_laplacian(
+    m: jnp.ndarray, dx: float, dy: float, dz: float
+) -> jnp.ndarray:
+    """
+    Compute 3D Laplacian with Neumann boundary conditions (JIT compiled)
+
+    Args:
+        m: Magnetization array (shape (3, nx, ny, nz))
+        dx: Grid spacing in x direction
+        dy: Grid spacing in y direction
+        dz: Grid spacing in z direction
+
+    Returns:
+        Laplacian of m (shape (3, nx, ny, nz))
+    """
+    dx2_inv, dy2_inv, dz2_inv = 1 / dx**2, 1 / dy**2, 1 / dz**2
+    center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)
+
+    return jnp.stack(
+        [
+            laplacian3D(m[0], dx2_inv, dy2_inv, dz2_inv, center_coeff),
+            laplacian3D(m[1], dx2_inv, dy2_inv, dz2_inv, center_coeff),
+            laplacian3D(m[2], dx2_inv, dy2_inv, dz2_inv, center_coeff),
+        ],
+        axis=0,
+    )
+
+
+@jax.jit
+def compute_space_average_jax(m1: jnp.ndarray) -> float:
+    """
+    Compute space average using midpoint method on GPU (JIT compiled)
+
+    Args:
+        m1: First component of magnetization (shape (nx, ny, nz))
+
+    Returns:
+        Space average of m1
+    """
+    # Get dimensions directly from the array shape
+    Jx, Jy, Jz = m1.shape
+
+    # Create 1D index arrays along each axis
+    i_coords = jnp.arange(Jx)
+    j_coords = jnp.arange(Jy)
+    k_coords = jnp.arange(Jz)
+
+    # Create 3D coordinate grids
+    ii, jj, kk = jnp.meshgrid(i_coords, j_coords, k_coords, indexing='ij')
+
+    # Apply midpoint weights (0.5 on boundary faces, 1.0 elsewhere)
+    weights = jnp.ones_like(m1)
+    weights = jnp.where((ii == 0) | (ii == Jx - 1), weights * 0.5, weights)
+    weights = jnp.where((jj == 0) | (jj == Jy - 1), weights * 0.5, weights)
+    weights = jnp.where((kk == 0) | (kk == Jz - 1), weights * 0.5, weights)
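+    # Cells lying on two or three boundary faces accumulate the 0.5 factors,
+    # giving 0.25 on domain edges and 0.125 at corners.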
+
+    # Compute weighted sum and normalize
+    weighted_sum = jnp.sum(weights * m1)
+
+    # Compute ncell from the weights (this is the effective cell count)
+    ncell = jnp.sum(weights)
+
+    return weighted_sum / ncell
+
+
+@jax.jit
+def cross_product(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+    r"""
+    Compute cross product :math:`a \times b` (JIT compiled)
+
+    Args:
+        a: First vector (shape (3, nx, ny, nz))
+        b: Second vector (shape (3, nx, ny, nz))
+
+    Returns:
+        Cross product :math:`a \times b` (shape (3, nx, ny, nz))
+    """
+    # Use JAX's optimized cross product function directly on axis 0
+    return jnp.cross(a, b, axis=0)
+
+
+# JIT compile the slope computation for performance
+@jax.jit
+def compute_slope(
+    g_params: dict, e_params: dict, m: jnp.ndarray, R_random: jnp.ndarray
+) -> jnp.ndarray:
+    """
+    Compute the right-hand side (slope) of the LLG equation using modular sub-functions (JIT compiled)
+
+    Args:
+        g_params: Grid parameters dict (dx, dy, dz)
+        e_params: Element parameters dict (coeff_1, coeff_2, coeff_3, lambda_G, anisotropy)
+        m: Magnetization array (shape (3, nx, ny, nz))
+        R_random: Random field array (shape (3, nx, ny, nz))
+
+    Returns:
+        Slope array (shape (3, nx, ny, nz))
+    """
+    # Extract parameters
+    dx, dy, dz = g_params["dx"], g_params["dy"], g_params["dz"]
+    coeff_1 = e_params["coeff_1"]
+    coeff_2 = e_params["coeff_2"]
+    coeff_3 = e_params["coeff_3"]
+    lambda_G = e_params["lambda_G"]
+    anisotropy = e_params["anisotropy"]
+
+    # Compute components using modular sub-functions
+    H_aniso = compute_H_anisotropy(m, coeff_2, anisotropy)
+    laplacian_m = compute_laplacian(m, dx, dy, dz)
+
+    # Effective field
+    R_eff = coeff_1 * laplacian_m + R_random + H_aniso
+    R_eff = R_eff.at[0].add(coeff_3)
+
+    # Cross products using modular functions
+    m_cross_R_eff = cross_product(m, R_eff)
+    m_cross_m_cross_R_eff = cross_product(m, m_cross_R_eff)
+
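+    # Landau-Lifshitz-Gilbert right-hand side, with R_eff as the effective field:
+    #   dm/dt = -(m x R_eff + lambda_G * m x (m x R_eff))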
+    return -(m_cross_R_eff + lambda_G * m_cross_m_cross_R_eff)
+
+
+def simulate(
+    N: int,
+    Jx: int,
+    Jy: int,
+    Jz: int,
+    dx: float,
+    T: float,
+    H_ext: float,
+    dt: float,
+    start_averaging: int,
+    n_mean: int,
+    n_profile: int,
+    element_class: type[Element],
+    precision: str,
+    seed: int,
+    device: str = "auto",
+    **_,
+) -> tuple[float, str, float]:
+    """
+    Simulates the system for N iterations using JAX
+
+    Args:
+        N: Number of iterations
+        Jx: Number of grid points in x direction
+        Jy: Number of grid points in y direction
+        Jz: Number of grid points in z direction
+        dx: Grid spacing
+        T: Temperature in Kelvin
+        H_ext: External magnetic field strength
+        dt: Time step for the simulation
+        start_averaging: Iteration from which time-averaging of the magnetization starts
+        n_mean: Iteration interval between outputs of the space-averaged magnetization (0 disables them)
+        n_profile: Iteration interval between profile outputs
+        element_class: Element class of the sample (default: Cobalt)
+        precision: Precision of the simulation (single or double)
+        seed: Random seed for temperature fluctuations
+        device: Device to use ('cpu', 'gpu', 'gpu:0', 'gpu:1', etc., or 'auto')
+
+    Returns:
+        - The time taken for the simulation
+        - The output filenames
+        - The average magnetization
+    """
+
+    # Configure JAX
+    if device == "auto":
+        # Let JAX choose the best available device
+        pass
+    elif device == "cpu":
+        jax.config.update("jax_platform_name", "cpu")
+    elif device == "gpu":
+        jax.config.update("jax_platform_name", "gpu")
+    elif device.startswith("gpu:"):
+        # Select specific GPU using environment variable
+        jax.config.update("jax_platform_name", "gpu")
+        gpu_id = device.split(":")[1]
+        # Check if CUDA_VISIBLE_DEVICES is already set externally
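+        # Note: setting CUDA_VISIBLE_DEVICES only takes effect if JAX has not yet
+        # initialized its GPU backend; the variable is read when devices are first created.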
+        if "CUDA_VISIBLE_DEVICES" not in os.environ:
+            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+            print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
+        else:
+            print(f"Using external CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
+
+    # Set precision
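+    # jax_enable_x64 must be set before any arrays are created; with it disabled,
+    # JAX keeps arrays in float32 even when float64 is requested.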
+    if precision == "double":
+        jax.config.update("jax_enable_x64", True)
+        jnp_float = jnp.float64
+    else:
+        jax.config.update("jax_enable_x64", False)
+        jnp_float = jnp.float32
+
+    print(f"Available JAX devices: {jax.devices()}")
+    print(f"Using JAX on device: {jax.devices()[0]}")
+    print(f"Precision: {precision} ({jnp_float})")
+
+    # Initialize random key for JAX
+    key = random.PRNGKey(seed)
+
+    g = Grid(Jx, Jy, Jz, dx)
+    dims = g.dims
+
+    e = element_class(T, H_ext, g, dt)
+    print(f"CFL = {e.get_CFL()}")
+
+    # Prepare parameters for JIT compilation using to_dict methods
+    g_params = g.to_dict()
+    e_params = e.to_dict()
+
+    # --- Initialization ---
+    def theta_init(shape):
+        """Initialization of theta"""
+        return jnp.zeros(shape, dtype=jnp_float)
+
+    def phi_init(t, shape):
+        """Initialization of phi"""
+        return jnp.zeros(shape, dtype=jnp_float) + e.gamma_0 * H_ext * t
+
+    m_n = jnp.zeros((3,) + dims, dtype=jnp_float)
+
+    theta = theta_init(dims)
+    phi = phi_init(0, dims)
+
+    m_n = m_n.at[0].set(jnp.cos(theta))
+    m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
+    m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))
+
+    f_mean, f_profiles, output_filenames = get_output_files(g, T, n_mean, n_profile)
+
+    t = 0.0
+    m1_average = 0.0
+
+    # === JIT WARMUP: Pre-compile all functions to exclude compilation time ===
+    print("Warming up JIT compilation...")
+
+    # Generate dummy random field for warmup
+    warmup_key = random.PRNGKey(42)
+    R_warmup = e.coeff_4 * random.normal(warmup_key, (3,) + dims, dtype=jnp_float)
+
+    # Warmup all JIT functions with actual data shapes
+    _ = compute_slope(g_params, e_params, m_n, R_warmup)
+    if n_mean != 0:
+        _ = compute_space_average_jax(m_n[0])
+
+    # Force compilation and execution to complete
+    jax.block_until_ready(m_n)
+    print("JIT warmup completed.")
+
+    start_time = time.perf_counter()
+
+    for n in progress_bar(range(1, N + 1), "Iteration : ", 40):
+        t += dt
+
+        # Generate random field for temperature effect
+        key, subkey = random.split(key)
+        R_random = e.coeff_4 * random.normal(subkey, (3,) + dims, dtype=jnp_float)
+
+        # Use JIT-compiled version for better performance
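+        # Heun predictor-corrector step: take a full Euler step with the current
+        # slope, re-evaluate the slope at the predicted state, then advance with
+        # the average of the two slopes; the same random field is used in both stages.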
+        s_pre = compute_slope(g_params, e_params, m_n, R_random)
+        m_pre = m_n + dt * s_pre
+        s_cor = compute_slope(g_params, e_params, m_pre, R_random)
+
+        # Update magnetization
+        m_n = m_n + dt * 0.5 * (s_pre + s_cor)
+
+        # Renormalize to unit sphere
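+        # The explicit update does not preserve |m| = 1 exactly, so the
+        # magnetization is projected back onto the unit sphere each iteration.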
+        norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
+        m_n = m_n / norm
+
+        # Export the average of m1 to a file (optimized for GPU)
+        if n_mean != 0 and n % n_mean == 0:
+            # Compute space average directly on GPU without CPU transfer
+            m1_mean = compute_space_average_jax(m_n[0])
+            # Convert to Python float for file writing
+            m1_mean = float(m1_mean)
+            if n >= start_averaging:
+                m1_average += m1_mean * n_mean
+            f_mean.write(f"{t:10.8e} {m1_mean:10.8e}\n")
+
+    total_time = time.perf_counter() - start_time
+
+    close_output_files(f_mean, f_profiles)
+
+    if n > start_averaging:
+        m1_average /= N - start_averaging
+
+    return total_time, output_filenames, m1_average