llg3d 2.0.0__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llg3d/__init__.py +3 -3
- llg3d/__main__.py +2 -2
- llg3d/benchmarks/__init__.py +1 -0
- llg3d/benchmarks/compare_commits.py +321 -0
- llg3d/benchmarks/efficiency.py +451 -0
- llg3d/benchmarks/utils.py +25 -0
- llg3d/element.py +118 -31
- llg3d/grid.py +51 -64
- llg3d/io.py +395 -0
- llg3d/main.py +36 -38
- llg3d/parameters.py +159 -49
- llg3d/post/__init__.py +1 -1
- llg3d/post/extract.py +105 -0
- llg3d/post/info.py +178 -0
- llg3d/post/m1_vs_T.py +90 -0
- llg3d/post/m1_vs_time.py +56 -0
- llg3d/post/process.py +82 -75
- llg3d/post/utils.py +38 -0
- llg3d/post/x_profiles.py +141 -0
- llg3d/py.typed +1 -0
- llg3d/solvers/__init__.py +153 -0
- llg3d/solvers/base.py +345 -0
- llg3d/solvers/experimental/__init__.py +9 -0
- llg3d/solvers/experimental/jax.py +361 -0
- llg3d/solvers/math_utils.py +41 -0
- llg3d/solvers/mpi.py +370 -0
- llg3d/solvers/numpy.py +126 -0
- llg3d/solvers/opencl.py +439 -0
- llg3d/solvers/profiling.py +38 -0
- {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/METADATA +6 -3
- llg3d-3.0.0.dist-info/RECORD +36 -0
- {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/WHEEL +1 -1
- llg3d-3.0.0.dist-info/entry_points.txt +9 -0
- llg3d/output.py +0 -108
- llg3d/post/plot_results.py +0 -65
- llg3d/post/temperature.py +0 -83
- llg3d/simulation.py +0 -104
- llg3d/solver/__init__.py +0 -45
- llg3d/solver/jax.py +0 -383
- llg3d/solver/mpi.py +0 -449
- llg3d/solver/numpy.py +0 -210
- llg3d/solver/opencl.py +0 -329
- llg3d/solver/solver.py +0 -93
- llg3d-2.0.0.dist-info/RECORD +0 -25
- llg3d-2.0.0.dist-info/entry_points.txt +0 -4
- {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/licenses/AUTHORS +0 -0
- {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/licenses/LICENSE +0 -0
- {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,361 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLG3D solver using XLA compilation.
|
|
3
|
+
|
|
4
|
+
.. warning::
|
|
5
|
+
|
|
6
|
+
It is experimental and not maintained.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import ClassVar
|
|
10
|
+
|
|
11
|
+
import os
|
|
12
|
+
import time
|
|
13
|
+
|
|
14
|
+
import jax
|
|
15
|
+
import jax.numpy as jnp
|
|
16
|
+
from jax import random
|
|
17
|
+
|
|
18
|
+
from ..base import BaseSolver
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# JIT compile individual components for better performance and modularity
|
|
22
|
+
@jax.jit
def compute_H_anisotropy(
    m: jnp.ndarray, coeff_2: float, anisotropy: int
) -> jnp.ndarray:
    """
    Evaluate the anisotropy field (JIT compiled).

    Both the uniaxial and the cubic expressions are evaluated and the
    result is picked with ``jnp.where``, so ``anisotropy`` may be a traced
    value.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        coeff_2: Coefficient for anisotropy
        anisotropy: Anisotropy type (0: uniaxial, 1: cubic); any other
            value selects a zero field

    Returns:
        Anisotropy field array (shape (3, nx, ny, nz))
    """
    mx, my, mz = m
    mx_sq, my_sq, mz_sq = mx * mx, my * my, mz * mz
    nothing = jnp.zeros_like(mx)

    # Candidate fields, one entry per component.
    uniaxial = (mx, nothing, nothing)
    cubic = (
        -(1 - mx_sq + my_sq * mz_sq) * mx,
        -(1 - my_sq + mx_sq * mz_sq) * my,
        -(1 - mz_sq + mx_sq * my_sq) * mz,
    )

    wants_uniaxial = anisotropy == 0
    wants_cubic = anisotropy == 1

    # Component-wise selection between the two candidates (zero otherwise).
    selected = [
        jnp.where(wants_uniaxial, uni, jnp.where(wants_cubic, cub, nothing))
        for uni, cub in zip(uniaxial, cubic)
    ]

    return coeff_2 * jnp.stack(selected, axis=0)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@jax.jit
def laplacian3D(
    m_i: jnp.ndarray,
    dx2_inv: float,
    dy2_inv: float,
    dz2_inv: float,
    center_coeff: float,
) -> jnp.ndarray:
    """
    Apply the 7-point Laplacian stencil to one magnetization component.

    The array is extended by one layer on every face with ``reflect``
    padding (used here to impose the Neumann boundary conditions) before
    the neighbours are summed. (JIT compiled)

    Args:
        m_i: Single component of magnetization (shape (nx, ny, nz))
        dx2_inv: Inverse of squared grid spacing in x direction
        dy2_inv: Inverse of squared grid spacing in y direction
        dz2_inv: Inverse of squared grid spacing in z direction
        center_coeff: Coefficient for the center point

    Returns:
        Laplacian of m_i (shape (nx, ny, nz))
    """
    ghosted = jnp.pad(m_i, ((1, 1),) * 3, mode="reflect")

    # Slices into the padded array: the original cells and their neighbours.
    here = slice(1, -1)
    ahead = slice(2, None)
    behind = slice(None, -2)

    x_neighbours = ghosted[ahead, here, here] + ghosted[behind, here, here]
    y_neighbours = ghosted[here, ahead, here] + ghosted[here, behind, here]
    z_neighbours = ghosted[here, here, ahead] + ghosted[here, here, behind]

    return (
        dx2_inv * x_neighbours
        + dy2_inv * y_neighbours
        + dz2_inv * z_neighbours
        + center_coeff * m_i
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@jax.jit
def compute_laplacian(m: jnp.ndarray, dx: float, dy: float, dz: float) -> jnp.ndarray:
    """
    Compute the Laplacian of each magnetization component (JIT compiled).

    The inverse squared spacings and the centre coefficient are derived
    once and handed to ``laplacian3D`` for each of the three components.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        dx: Grid spacing in x direction
        dy: Grid spacing in y direction
        dz: Grid spacing in z direction

    Returns:
        Laplacian of m (shape (3, nx, ny, nz))
    """
    inv_dx2 = 1 / dx**2
    inv_dy2 = 1 / dy**2
    inv_dz2 = 1 / dz**2
    diagonal = -2 * (inv_dx2 + inv_dy2 + inv_dz2)

    per_component = [
        laplacian3D(m[c], inv_dx2, inv_dy2, inv_dz2, diagonal) for c in range(3)
    ]
    return jnp.stack(per_component, axis=0)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@jax.jit
def compute_space_average_jax(m1: jnp.ndarray) -> float:
    """
    Compute the weighted space average of m1 (JIT compiled).

    Every cell starts with weight 1; the weight is halved once for each
    axis along which the cell sits on the first or last index. The result
    is the weighted sum divided by the sum of the weights.

    Args:
        m1: First component of magnetization (shape (nx, ny, nz))

    Returns:
        Space average of m1
    """
    nx, ny, nz = m1.shape

    weights = jnp.ones_like(m1)
    for axis, size in enumerate((nx, ny, nz)):
        index = jnp.arange(size)
        on_edge = (index == 0) | (index == size - 1)
        # Reshape the 1D mask so it broadcasts along the other two axes.
        mask_shape = [1, 1, 1]
        mask_shape[axis] = size
        weights = jnp.where(on_edge.reshape(mask_shape), weights * 0.5, weights)

    total = jnp.sum(weights * m1)
    effective_cells = jnp.sum(weights)

    return total / effective_cells  # type: ignore
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@jax.jit
def cross_product(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    r"""
    Return the cross product :math:`a \times b` (JIT compiled).

    The vector components are stored along the leading axis of both inputs
    and of the result.

    Args:
        a: First vector (shape (3, nx, ny, nz))
        b: Second vector (shape (3, nx, ny, nz))

    Returns:
        Cross product :math:`a \times b` (shape (3, nx, ny, nz))
    """
    # Delegate to JAX's cross product, with components on axis 0 everywhere.
    product = jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
    return product
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# JIT compile the slope computation for performance
|
|
186
|
+
@jax.jit
def compute_slope(
    g_params: dict, e_params: dict, m: jnp.ndarray, R_random: jnp.ndarray
) -> jnp.ndarray:
    """
    Compute the time slope of the magnetization (JIT compiled).

    The effective field is assembled from the scaled Laplacian, the random
    field and the anisotropy field, with ``coeff_3`` added to its first
    component. The slope is then built from ``m x R_eff`` and
    ``m x (m x R_eff)``.

    Args:
        g_params: Grid parameters dict (dx, dy, dz)
        e_params: Element parameters dict (coeff_1, coeff_2, coeff_3, lambda_G,
            anisotropy)
        m: Magnetization array (shape (3, nx, ny, nz))
        R_random: Random field array (shape (3, nx, ny, nz))

    Returns:
        Slope array (shape (3, nx, ny, nz))
    """
    scaled_laplacian = e_params["coeff_1"] * compute_laplacian(
        m, g_params["dx"], g_params["dy"], g_params["dz"]
    )
    anisotropy_field = compute_H_anisotropy(
        m, e_params["coeff_2"], e_params["anisotropy"]
    )

    # Effective field; coeff_3 only contributes to the first component.
    field = scaled_laplacian + R_random + anisotropy_field
    field = field.at[0].add(e_params["coeff_3"])

    first_cross = cross_product(m, field)
    second_cross = cross_product(m, first_cross)

    return -(first_cross + e_params["lambda_G"] * second_cross)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class JaxSolver(BaseSolver):
    """JAX-based LLG3D solver."""

    solver_type: ClassVar[str] = "jax"  #: Solver type name

    def _xyz_average(self, m1: jnp.ndarray) -> float:
        """Compute the space average of m1 using JAX."""
        return compute_space_average_jax(m1)

    def _simulate(self) -> float:
        """
        Simulates the system for N iterations using JAX.

        The JAX platform is chosen from ``self.device`` ('auto', 'cpu',
        'gpu' or 'gpu:<id>') and the floating-point width from
        ``self.precision``. The jitted functions are run once before the
        timer starts, so the returned time covers the time loop only.

        Each iteration performs a predictor step followed by a corrector
        step using the same random field, then renormalizes the
        magnetization and hands it to ``self._record``.

        Returns:
            The time taken for the simulation

        Raises:
            NotImplementedError: If n_profile is not zero
        """
        if self.n_profile != 0:
            raise NotImplementedError(
                "Saving x-profiles is not implemented for the JAX solver."
            )

        # Configure JAX
        if self.device == "auto":
            # Let JAX choose the best available device
            pass
        elif self.device == "cpu":
            jax.config.update("jax_platform_name", "cpu")
        elif self.device == "gpu":
            jax.config.update("jax_platform_name", "gpu")
        elif self.device.startswith("gpu:"):
            # Select specific GPU using environment variable
            jax.config.update("jax_platform_name", "gpu")
            gpu_id = self.device.split(":")[1]
            # An externally provided CUDA_VISIBLE_DEVICES takes precedence
            if "CUDA_VISIBLE_DEVICES" not in os.environ:
                os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
                print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
            else:
                cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
                print(f"Using external CUDA_VISIBLE_DEVICES={cuda_visible_devices}")

        # Set precision
        if self.precision == "double":
            jax.config.update("jax_enable_x64", True)
            jnp_float = jnp.float64
        else:
            jax.config.update("jax_enable_x64", False)
            jnp_float = jnp.float32

        print(f"Available JAX devices: {jax.devices()}")
        print(f"Using JAX on device: {jax.devices()[0]}")
        print(f"Precision: {self.precision} ({jnp_float})")

        # Initialize random key for JAX
        key = random.PRNGKey(self.seed)

        # Plain dicts of grid and element parameters, passed to the jitted slope
        g_params = self.grid.as_dict()
        e_params = self.elem.to_dict()

        # --- Initialization ---
        def theta_init(shape):
            """Initialization of theta."""
            return jnp.zeros(shape, dtype=jnp_float)

        def phi_init(t, shape):
            """Initialization of phi."""
            return (
                jnp.zeros(shape, dtype=jnp_float) + self.elem.gamma_0 * self.H_ext * t
            )

        m_n = jnp.zeros((3, *self.grid.dims), dtype=jnp_float)

        theta = theta_init(self.grid.dims)
        phi = phi_init(0, self.grid.dims)

        m_n = m_n.at[0].set(jnp.cos(theta))
        m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
        m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))

        t = 0.0

        # === JIT WARMUP: Pre-compile all functions to exclude compilation time ===
        print("Warming up JIT compilation...")

        # Generate dummy random field for warmup
        warmup_key = random.PRNGKey(42)
        R_warmup = self.elem.coeff_4 * random.normal(
            warmup_key, (3, *self.grid.dims), dtype=jnp_float
        )

        # Warmup all JIT functions with actual data shapes.
        # JAX dispatches asynchronously, so block on the warmup *results*
        # (not on the input array) to be sure they finished before timing.
        jax.block_until_ready(compute_slope(g_params, e_params, m_n, R_warmup))
        if self.n_mean != 0:
            jax.block_until_ready(compute_space_average_jax(m_n[0]))

        print("JIT warmup completed.")

        start_time = time.perf_counter()

        for n in self._progress_bar():
            t += self.dt

            # Generate random field for temperature effect
            key, subkey = random.split(key)
            R_random = self.elem.coeff_4 * random.normal(
                subkey, (3, *self.grid.dims), dtype=jnp_float
            )

            # Predictor and corrector slopes share the same random field
            s_pre = compute_slope(g_params, e_params, m_n, R_random)
            m_pre = m_n + self.dt * s_pre
            s_cor = compute_slope(g_params, e_params, m_pre, R_random)

            # Update magnetization
            m_n = m_n + self.dt * 0.5 * (s_pre + s_cor)

            # Renormalize to unit sphere
            norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
            m_n = m_n / norm

            self._record(m_n, t, n)

        # Wait for the last update to actually finish: without this the
        # timer could stop while work is still queued on the device.
        jax.block_until_ready(m_n)
        total_time = time.perf_counter() - start_time

        self._finalize()

        return total_time
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Mathematical utility functions for solvers."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def cross_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    r"""
    Return the cross product :math:`a \times b`, computed component-wise.

    The three components are read from the leading axis of each input.
    The original author notes that this explicit form is faster than
    np.cross for large arrays.

    Args:
        a: First vector (shape (3, nx, ny, nz))
        b: Second vector (shape (3, nx, ny, nz))

    Returns:
        Cross product :math:`a \times b` (shape (3, nx, ny, nz))
    """
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]

    cx = ay * bz - az * by
    cy = az * bx - ax * bz
    cz = ax * by - ay * bx

    return np.stack((cx, cy, cz), axis=0)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def normalize(m_n: np.ndarray) -> None:
    r"""
    Rescale the magnetization to unit length, modifying the array in place.

    .. math::

        \mathbf{m}_n = \frac{\mathbf{m}_n}{|\mathbf{m}_n|}

    Args:
        m_n: Magnetization array at time step n (shape (3, nx, ny, nz)).
    """
    squared_length = np.square(m_n[0]) + np.square(m_n[1]) + np.square(m_n[2])
    length = np.sqrt(squared_length)
    # Write the quotient back into the caller's array.
    np.divide(m_n, length, out=m_n)
|