llg3d 2.0.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. llg3d/__init__.py +3 -3
  2. llg3d/__main__.py +2 -2
  3. llg3d/benchmarks/__init__.py +1 -0
  4. llg3d/benchmarks/compare_commits.py +321 -0
  5. llg3d/benchmarks/efficiency.py +451 -0
  6. llg3d/benchmarks/utils.py +25 -0
  7. llg3d/element.py +118 -31
  8. llg3d/grid.py +51 -64
  9. llg3d/io.py +395 -0
  10. llg3d/main.py +36 -38
  11. llg3d/parameters.py +159 -49
  12. llg3d/post/__init__.py +1 -1
  13. llg3d/post/extract.py +105 -0
  14. llg3d/post/info.py +178 -0
  15. llg3d/post/m1_vs_T.py +90 -0
  16. llg3d/post/m1_vs_time.py +56 -0
  17. llg3d/post/process.py +82 -75
  18. llg3d/post/utils.py +38 -0
  19. llg3d/post/x_profiles.py +141 -0
  20. llg3d/py.typed +1 -0
  21. llg3d/solvers/__init__.py +153 -0
  22. llg3d/solvers/base.py +345 -0
  23. llg3d/solvers/experimental/__init__.py +9 -0
  24. llg3d/solvers/experimental/jax.py +361 -0
  25. llg3d/solvers/math_utils.py +41 -0
  26. llg3d/solvers/mpi.py +370 -0
  27. llg3d/solvers/numpy.py +126 -0
  28. llg3d/solvers/opencl.py +439 -0
  29. llg3d/solvers/profiling.py +38 -0
  30. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/METADATA +6 -3
  31. llg3d-3.0.0.dist-info/RECORD +36 -0
  32. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/WHEEL +1 -1
  33. llg3d-3.0.0.dist-info/entry_points.txt +9 -0
  34. llg3d/output.py +0 -108
  35. llg3d/post/plot_results.py +0 -65
  36. llg3d/post/temperature.py +0 -83
  37. llg3d/simulation.py +0 -104
  38. llg3d/solver/__init__.py +0 -45
  39. llg3d/solver/jax.py +0 -383
  40. llg3d/solver/mpi.py +0 -449
  41. llg3d/solver/numpy.py +0 -210
  42. llg3d/solver/opencl.py +0 -329
  43. llg3d/solver/solver.py +0 -93
  44. llg3d-2.0.0.dist-info/RECORD +0 -25
  45. llg3d-2.0.0.dist-info/entry_points.txt +0 -4
  46. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/licenses/AUTHORS +0 -0
  47. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/licenses/LICENSE +0 -0
  48. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,361 @@
1
+ """
2
+ LLG3D solver using XLA compilation.
3
+
4
+ .. warning::
5
+
6
+ It is experimental and not maintained.
7
+ """
8
+
9
+ from typing import ClassVar
10
+
11
+ import os
12
+ import time
13
+
14
+ import jax
15
+ import jax.numpy as jnp
16
+ from jax import random
17
+
18
+ from ..base import BaseSolver
19
+
20
+
21
+ # JIT compile individual components for better performance and modularity
22
@jax.jit
def compute_H_anisotropy(
    m: jnp.ndarray, coeff_2: float, anisotropy: int
) -> jnp.ndarray:
    """
    Compute the anisotropy contribution to the effective field (JIT compiled).

    Both candidate fields (uniaxial and cubic) are evaluated, and the one that
    matches ``anisotropy`` is picked with ``jnp.where``. Under ``jax.jit``,
    ``anisotropy`` is a traced value, so it cannot drive Python control flow.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        coeff_2: Coefficient for anisotropy (scales the returned field)
        anisotropy: Anisotropy type (0: uniaxial, 1: cubic; any other value
            yields a zero field)

    Returns:
        Anisotropy field array (shape (3, nx, ny, nz))
    """
    mx, my, mz = m

    sq_x = mx * mx
    sq_y = my * my
    sq_z = mz * mz

    zero = jnp.zeros_like(mx)

    # Uniaxial: only the first component is non-zero
    uniaxial = jnp.stack([mx, zero, zero], axis=0)

    # Cubic anisotropy, one expression per component
    cubic = jnp.stack(
        [
            -(1 - sq_x + sq_y * sq_z) * mx,
            -(1 - sq_y + sq_x * sq_z) * my,
            -(1 - sq_z + sq_x * sq_y) * mz,
        ],
        axis=0,
    )

    selected = jnp.where(
        anisotropy == 0,
        uniaxial,
        jnp.where(anisotropy == 1, cubic, jnp.zeros_like(cubic)),
    )
    return coeff_2 * selected
71
+
72
+
73
@jax.jit
def laplacian3D(
    m_i: jnp.ndarray,
    dx2_inv: float,
    dy2_inv: float,
    dz2_inv: float,
    center_coeff: float,
) -> jnp.ndarray:
    """
    Apply the 7-point Laplacian stencil to one magnetization component.

    (JIT compiled)

    Neumann boundary conditions are imposed by padding one cell on each side
    with ``mode="reflect"``. This mirrors about the boundary node, so the ghost
    value equals the first interior neighbour.

    Args:
        m_i: Single component of magnetization (shape (nx, ny, nz))
        dx2_inv: Inverse of squared grid spacing in x direction
        dy2_inv: Inverse of squared grid spacing in y direction
        dz2_inv: Inverse of squared grid spacing in z direction
        center_coeff: Coefficient for the center point

    Returns:
        Laplacian of m_i (shape (nx, ny, nz))
    """
    padded = jnp.pad(m_i, 1, mode="reflect")
    inner = slice(1, -1)

    # Sum of the two neighbours along each axis
    x_pair = padded[2:, inner, inner] + padded[:-2, inner, inner]
    y_pair = padded[inner, 2:, inner] + padded[inner, :-2, inner]
    z_pair = padded[inner, inner, 2:] + padded[inner, inner, :-2]

    return dx2_inv * x_pair + dy2_inv * y_pair + dz2_inv * z_pair + center_coeff * m_i
103
+
104
+
105
@jax.jit
def compute_laplacian(m: jnp.ndarray, dx: float, dy: float, dz: float) -> jnp.ndarray:
    """
    Apply the Neumann-boundary 3D Laplacian to each magnetization component.

    (JIT compiled)

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        dx: Grid spacing in x direction
        dy: Grid spacing in y direction
        dz: Grid spacing in z direction

    Returns:
        Laplacian of m (shape (3, nx, ny, nz))
    """
    inv_x2 = 1 / dx**2
    inv_y2 = 1 / dy**2
    inv_z2 = 1 / dz**2
    # Diagonal weight of the 7-point stencil
    diag = -2 * (inv_x2 + inv_y2 + inv_z2)

    components = [
        laplacian3D(m[axis], inv_x2, inv_y2, inv_z2, diag) for axis in range(3)
    ]
    return jnp.stack(components, axis=0)
130
+
131
+
132
@jax.jit
def compute_space_average_jax(m1: jnp.ndarray) -> float:
    """
    Average ``m1`` over the grid with midpoint (trapezoidal-type) weights.

    (JIT compiled)

    Each grid node's weight is halved once for every axis along which it
    lies on the first or last index. The weighted sum is then normalized by
    the sum of the weights (the effective cell count).

    Args:
        m1: First component of magnetization (shape (nx, ny, nz))

    Returns:
        Space average of m1
    """
    nx, ny, nz = m1.shape

    # Broadcastable index arrays, one per axis (no full 3D meshgrid needed)
    axes = (
        (jnp.arange(nx)[:, None, None], nx),
        (jnp.arange(ny)[None, :, None], ny),
        (jnp.arange(nz)[None, None, :], nz),
    )

    weights = jnp.ones_like(m1)
    for idx, size in axes:
        on_edge = (idx == 0) | (idx == size - 1)
        weights = jnp.where(on_edge, 0.5 * weights, weights)

    return jnp.sum(weights * m1) / jnp.sum(weights)  # type: ignore
167
+
168
+
169
@jax.jit
def cross_product(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    r"""
    Pointwise cross product :math:`a \times b` of two vector fields (JIT compiled).

    The vector components are stored along the leading axis of both inputs,
    and the result keeps them there too.

    Args:
        a: First vector (shape (3, nx, ny, nz))
        b: Second vector (shape (3, nx, ny, nz))

    Returns:
        Cross product :math:`a \times b` (shape (3, nx, ny, nz))
    """
    return jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
183
+
184
+
185
+ # JIT compile the slope computation for performance
186
@jax.jit
def compute_slope(
    g_params: dict, e_params: dict, m: jnp.ndarray, R_random: jnp.ndarray
) -> jnp.ndarray:
    """
    Evaluate the right-hand side (slope) of the LLG equation (JIT compiled).

    The effective field is the exchange term (``coeff_1`` times the
    Laplacian), plus the random field, plus the anisotropy field, plus
    ``coeff_3`` added to the first component. The slope is then
    ``-(m x R_eff + lambda_G * m x (m x R_eff))``.

    Args:
        g_params: Grid parameters dict (dx, dy, dz)
        e_params: Element parameters dict (coeff_1, coeff_2, coeff_3, lambda_G,
            anisotropy)
        m: Magnetization array (shape (3, nx, ny, nz))
        R_random: Random field array (shape (3, nx, ny, nz))

    Returns:
        Slope array (shape (3, nx, ny, nz))
    """
    H_aniso = compute_H_anisotropy(m, e_params["coeff_2"], e_params["anisotropy"])
    lap_m = compute_laplacian(m, g_params["dx"], g_params["dy"], g_params["dz"])

    # Effective field; the constant coeff_3 term only acts on the x-component
    R_eff = e_params["coeff_1"] * lap_m + R_random + H_aniso
    R_eff = R_eff.at[0].add(e_params["coeff_3"])

    m_x_R = cross_product(m, R_eff)
    m_x_m_x_R = cross_product(m, m_x_R)

    return -(m_x_R + e_params["lambda_G"] * m_x_m_x_R)
224
+
225
+
226
class JaxSolver(BaseSolver):
    """JAX-based LLG3D solver."""

    solver_type: ClassVar[str] = "jax"  #: Solver type name

    def _xyz_average(self, m1: jnp.ndarray) -> float:
        """
        Compute the space average of m1 using JAX.

        Args:
            m1: First component of magnetization (shape (nx, ny, nz))

        Returns:
            Midpoint-weighted space average of ``m1``
        """
        return compute_space_average_jax(m1)

    def _simulate(self) -> float:
        """
        Simulate the system for N iterations using JAX.

        Each iteration draws a fresh random field and advances the
        magnetization with a predictor-corrector (Heun) step. The result is
        then renormalized onto the unit sphere.

        JIT compilation is triggered, and waited for, before the timer starts.
        Pending asynchronous device work is waited for before the timer stops.
        The returned duration therefore covers only the time-stepping loop.

        JAX is configured from ``self.device`` ('cpu', 'gpu', 'gpu:<id>' or
        'auto') and ``self.precision`` ('double' selects float64, anything
        else float32).

        Returns:
            The time taken for the simulation (in seconds)

        Raises:
            NotImplementedError: If n_profile is not zero
        """
        if self.n_profile != 0:
            raise NotImplementedError(
                "Saving x-profiles is not implemented for the JAX solver."
            )

        # Configure JAX.
        # NOTE(review): ``jax_platform_name`` and ``CUDA_VISIBLE_DEVICES`` only
        # take effect if the JAX backend has not been initialized yet in this
        # process. Confirm that nothing queries JAX devices before this point.
        if self.device == "auto":
            # Let JAX choose the best available device
            pass
        elif self.device == "cpu":
            jax.config.update("jax_platform_name", "cpu")
        elif self.device == "gpu":
            jax.config.update("jax_platform_name", "gpu")
        elif self.device.startswith("gpu:"):
            # Select specific GPU using environment variable
            jax.config.update("jax_platform_name", "gpu")
            gpu_id = self.device.split(":")[1]
            # Respect a CUDA_VISIBLE_DEVICES already set externally
            if "CUDA_VISIBLE_DEVICES" not in os.environ:
                os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
                print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
            else:
                cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
                print(f"Using external CUDA_VISIBLE_DEVICES={cuda_visible_devices}")

        # Set precision before any array of this run is created
        if self.precision == "double":
            jax.config.update("jax_enable_x64", True)
            jnp_float = jnp.float64
        else:
            jax.config.update("jax_enable_x64", False)
            jnp_float = jnp.float32

        print(f"Available JAX devices: {jax.devices()}")
        print(f"Using JAX on device: {jax.devices()[0]}")
        print(f"Precision: {self.precision} ({jnp_float})")

        # Initialize random key for JAX
        key = random.PRNGKey(self.seed)

        # Parameter dicts, passed as pytrees to the JIT-compiled compute_slope
        g_params = self.grid.as_dict()
        e_params = self.elem.to_dict()

        # --- Initialization ---
        def theta_init(shape):
            """Initialization of theta."""
            return jnp.zeros(shape, dtype=jnp_float)

        def phi_init(t, shape):
            """Initialization of phi."""
            return (
                jnp.zeros(shape, dtype=jnp_float) + self.elem.gamma_0 * self.H_ext * t
            )

        m_n = jnp.zeros((3, *self.grid.dims), dtype=jnp_float)

        theta = theta_init(self.grid.dims)
        phi = phi_init(0, self.grid.dims)

        # Spherical angles -> Cartesian magnetization components
        m_n = m_n.at[0].set(jnp.cos(theta))
        m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
        m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))

        t = 0.0

        # === JIT WARMUP: Pre-compile all functions to exclude compilation time ===
        print("Warming up JIT compilation...")

        # Generate dummy random field for warmup (separate key: the
        # simulation's own key stream is left untouched)
        warmup_key = random.PRNGKey(42)
        R_warmup = self.elem.coeff_4 * random.normal(
            warmup_key, (3, *self.grid.dims), dtype=jnp_float
        )

        # Exercise the calls made in the time loop with actual data shapes
        warmup_outputs = [
            random.split(warmup_key),
            compute_slope(g_params, e_params, m_n, R_warmup),
        ]
        if self.n_mean != 0:
            warmup_outputs.append(compute_space_average_jax(m_n[0]))

        # Block on the warm-up results themselves. Dispatch is asynchronous, so
        # waiting on m_n (already computed) would let warm-up execution leak
        # into the timed loop.
        jax.block_until_ready(warmup_outputs)
        print("JIT warmup completed.")

        start_time = time.perf_counter()

        for n in self._progress_bar():
            t += self.dt

            # Random field for temperature effect; the same draw is used by
            # both the predictor and the corrector stage
            key, subkey = random.split(key)
            R_random = self.elem.coeff_4 * random.normal(
                subkey, (3, *self.grid.dims), dtype=jnp_float
            )

            # Heun predictor-corrector step
            s_pre = compute_slope(g_params, e_params, m_n, R_random)
            m_pre = m_n + self.dt * s_pre
            s_cor = compute_slope(g_params, e_params, m_pre, R_random)

            # Update magnetization
            m_n = m_n + self.dt * 0.5 * (s_pre + s_cor)

            # Renormalize to unit sphere
            norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
            m_n = m_n / norm

            self._record(m_n, t, n)

        # Dispatch is asynchronous: wait for the final state before stopping
        # the timer so that all queued computation is included.
        jax.block_until_ready(m_n)
        total_time = time.perf_counter() - start_time

        self._finalize()

        return total_time
@@ -0,0 +1,41 @@
1
+ """Mathematical utility functions for solvers."""
2
+
3
+ import numpy as np
4
+
5
+
6
def cross_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    r"""
    Return the pointwise cross product :math:`a \times b` of two vector fields.

    The three components are written out explicitly, which is faster than
    ``np.cross`` for large arrays. Components are indexed along the leading
    axis of both inputs and of the result.

    Args:
        a: First vector (shape (3, nx, ny, nz))
        b: Second vector (shape (3, nx, ny, nz))

    Returns:
        Cross product :math:`a \times b` (shape (3, nx, ny, nz))
    """
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]

    cx = ay * bz - az * by
    cy = az * bx - ax * bz
    cz = ax * by - ay * bx

    return np.stack((cx, cy, cz), axis=0)
27
+
28
+
29
def normalize(m_n: np.ndarray):
    r"""
    Rescale the magnetization to unit length at every grid point, in place.

    .. math::

        \mathbf{m}_n = \frac{\mathbf{m}_n}{|\mathbf{m}_n|}

    The norm is built from ``m_n[0]``, ``m_n[1]`` and ``m_n[2]``. Nothing is
    returned: ``m_n`` itself is overwritten.

    Args:
        m_n: Magnetization array at time step n (shape (3, nx, ny, nz)).
    """
    magnitude = np.sqrt(m_n[0] * m_n[0] + m_n[1] * m_n[1] + m_n[2] * m_n[2])
    np.divide(m_n, magnitude, out=m_n)