llg3d 1.4.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llg3d/__init__.py +4 -1
- llg3d/__main__.py +6 -0
- llg3d/element.py +128 -0
- llg3d/grid.py +126 -0
- llg3d/main.py +66 -0
- llg3d/output.py +108 -0
- llg3d/parameters.py +75 -0
- llg3d/post/plot_results.py +65 -0
- llg3d/post/temperature.py +0 -1
- llg3d/simulation.py +104 -0
- llg3d/solver/__init__.py +45 -0
- llg3d/solver/jax.py +383 -0
- llg3d/solver/mpi.py +449 -0
- llg3d/solver/numpy.py +210 -0
- llg3d/solver/opencl.py +329 -0
- llg3d/solver/solver.py +93 -0
- {llg3d-1.4.0.dist-info → llg3d-2.0.0.dist-info}/METADATA +13 -20
- llg3d-2.0.0.dist-info/RECORD +25 -0
- {llg3d-1.4.0.dist-info → llg3d-2.0.0.dist-info}/WHEEL +1 -1
- llg3d-2.0.0.dist-info/entry_points.txt +4 -0
- llg3d/llg3d.py +0 -742
- llg3d/llg3d_seq.py +0 -447
- llg3d-1.4.0.dist-info/RECORD +0 -13
- llg3d-1.4.0.dist-info/entry_points.txt +0 -3
- {llg3d-1.4.0.dist-info → llg3d-2.0.0.dist-info/licenses}/AUTHORS +0 -0
- {llg3d-1.4.0.dist-info → llg3d-2.0.0.dist-info/licenses}/LICENSE +0 -0
- {llg3d-1.4.0.dist-info → llg3d-2.0.0.dist-info}/top_level.txt +0 -0
llg3d/solver/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Solver module for LLG3D
|
|
3
|
+
|
|
4
|
+
This module contains different solver implementations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
__all__ = ["numpy", "solver", "rank", "size", "comm", "status"]

# Availability of the optional backend libraries. The keys ("opencl", "jax",
# "mpi4py") are the ones looked up elsewhere (e.g. LIB_AVAILABLE["mpi4py"] in
# the MPI section of this module), so they are kept unchanged.
LIB_AVAILABLE: dict[str, bool] = {}

# (LIB_AVAILABLE key, importable library to probe, solver submodule to export)
# importlib.util.find_spec() resolves a name against ``package`` only when the
# name starts with a dot, so each entry must name the real top-level library:
# - the OpenCL backend is probed through ``pyopencl``; probing "opencl" looks
#   for a top-level module of that name, which is normally absent, so the
#   backend was always reported as missing.
#   NOTE(review): assumes solver/opencl.py is built on pyopencl - confirm.
# - the MPI solver submodule is ``mpi`` (llg3d/solver/mpi.py); exporting
#   "mpi4py" in __all__ made ``from llg3d.solver import *`` raise
#   AttributeError because no such attribute exists on this package.
_OPTIONAL_BACKENDS = (
    ("opencl", "pyopencl", "opencl"),
    ("jax", "jax", "jax"),
    ("mpi4py", "mpi4py", "mpi"),
)

for _key, _library, _submodule in _OPTIONAL_BACKENDS:
    LIB_AVAILABLE[_key] = importlib.util.find_spec(_library) is not None
    if LIB_AVAILABLE[_key]:
        __all__.append(_submodule)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# MPI utilities
|
|
24
|
+
if not LIB_AVAILABLE["mpi4py"]:
    # mpi4py cannot be imported: expose placeholder objects and the values
    # describing a single process (rank 0 out of 1).

    class DummyComm:
        """Placeholder for ``MPI.COMM_WORLD`` when mpi4py is unavailable."""

    class DummyStatus:
        """Placeholder for ``MPI.Status`` when mpi4py is unavailable."""

    comm = DummyComm()
    status = DummyStatus()
    rank, size = 0, 1
else:
    from mpi4py import MPI

    # World communicator plus this process's rank and the process count
    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()
    status = MPI.Status()
|
|
44
|
+
|
|
45
|
+
from . import numpy, solver
|
llg3d/solver/jax.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLG3D solver using XLA compilation
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
import jax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
from jax import random
|
|
11
|
+
|
|
12
|
+
from ..output import progress_bar, get_output_files, close_output_files
|
|
13
|
+
from ..grid import Grid
|
|
14
|
+
from ..element import Element
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# JIT compile individual components for better performance and modularity
|
|
18
|
+
@jax.jit
def compute_H_anisotropy(
    m: jnp.ndarray, coeff_2: float, anisotropy: int
) -> jnp.ndarray:
    """
    Anisotropy contribution to the effective field (JIT compiled).

    Both candidate fields are built and the result is chosen with
    ``jnp.where``: under ``jax.jit`` the ``anisotropy`` selector is a traced
    value, so it cannot drive Python ``if`` statements.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        coeff_2: Scaling coefficient applied to the selected field
        anisotropy: Anisotropy selector (0: uniaxial, 1: cubic, any other
            value: zero field)

    Returns:
        Anisotropy field array (shape (3, nx, ny, nz))
    """
    x, y, z = m
    xx, yy, zz = x * x, y * y, z * z
    zero = jnp.zeros_like(x)

    # Uniaxial case: only the first component contributes
    uniaxial = jnp.stack([x, zero, zero], axis=0)

    # Cubic case
    cubic = jnp.stack(
        [
            -(1 - xx + yy * zz) * x,
            -(1 - yy + xx * zz) * y,
            -(1 - zz + xx * yy) * z,
        ],
        axis=0,
    )

    # Innermost choice first: cubic if selector == 1, otherwise zeros;
    # then override with the uniaxial field when selector == 0.
    selected = jnp.where(anisotropy == 1, cubic, jnp.zeros_like(cubic))
    selected = jnp.where(anisotropy == 0, uniaxial, selected)

    return coeff_2 * selected
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@jax.jit
def laplacian3D(
    m_i: jnp.ndarray, dx2_inv: float, dy2_inv: float, dz2_inv: float, center_coeff: float
) -> jnp.ndarray:
    """
    Seven-point Laplacian of one scalar component (JIT compiled).

    One layer of ghost cells is added on every face with ``mode="reflect"``
    (mirror image that excludes the edge value). The centred stencil then
    sees a zero normal derivative, i.e. Neumann boundary conditions.

    Args:
        m_i: Single component of magnetization (shape (nx, ny, nz))
        dx2_inv: 1 / dx**2
        dy2_inv: 1 / dy**2
        dz2_inv: 1 / dz**2
        center_coeff: Weight applied to the centre point itself

    Returns:
        Laplacian of m_i (shape (nx, ny, nz))
    """
    ghost = jnp.pad(m_i, ((1, 1), (1, 1), (1, 1)), mode="reflect")

    inner = slice(1, -1)
    ahead = slice(2, None)
    behind = slice(None, -2)

    # Sum of the two neighbours along each axis
    x_pair = ghost[ahead, inner, inner] + ghost[behind, inner, inner]
    y_pair = ghost[inner, ahead, inner] + ghost[inner, behind, inner]
    z_pair = ghost[inner, inner, ahead] + ghost[inner, inner, behind]

    return dx2_inv * x_pair + dy2_inv * y_pair + dz2_inv * z_pair + center_coeff * m_i
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@jax.jit
def compute_laplacian(
    m: jnp.ndarray, dx: float, dy: float, dz: float
) -> jnp.ndarray:
    """
    Neumann-boundary Laplacian of each magnetization component (JIT compiled).

    The inverse squared spacings and the stencil's centre weight are
    computed once. ``laplacian3D`` is then applied to components 0, 1 and 2,
    and the results are stacked back along the leading axis.

    Args:
        m: Magnetization array (shape (3, nx, ny, nz))
        dx: Grid spacing in x direction
        dy: Grid spacing in y direction
        dz: Grid spacing in z direction

    Returns:
        Laplacian of m (shape (3, nx, ny, nz))
    """
    inv_dx2 = 1 / dx**2
    inv_dy2 = 1 / dy**2
    inv_dz2 = 1 / dz**2
    # Centre weight of the stencil: minus twice the sum of the axis weights
    diag = -2 * (inv_dx2 + inv_dy2 + inv_dz2)

    per_component = [
        laplacian3D(m[axis], inv_dx2, inv_dy2, inv_dz2, diag) for axis in range(3)
    ]
    return jnp.stack(per_component, axis=0)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@jax.jit
def compute_space_average_jax(m1: jnp.ndarray) -> float:
    """
    Weighted spatial mean of a 3D scalar field, on device (JIT compiled).

    Every grid node starts with weight 1. The weight is halved once for each
    axis on which the node sits at the first or last index (trapezoidal-style
    edge weights). The weighted sum is divided by the sum of the weights.

    Args:
        m1: First component of magnetization (shape (nx, ny, nz))

    Returns:
        Weighted space average of m1 (0-d JAX array)
    """
    nx, ny, nz = m1.shape
    ii, jj, kk = jnp.indices((nx, ny, nz))

    weights = jnp.ones_like(m1)
    for index, extent in ((ii, nx), (jj, ny), (kk, nz)):
        on_face = (index == 0) | (index == extent - 1)
        weights = jnp.where(on_face, weights * 0.5, weights)

    # Effective number of cells = total weight
    return jnp.sum(weights * m1) / jnp.sum(weights)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@jax.jit
def cross_product(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    r"""
    Pointwise vector product :math:`a \times b` on a grid (JIT compiled).

    The vector components are stored on the leading axis of both inputs,
    and the result uses the same layout.

    Args:
        a: First vector field (shape (3, nx, ny, nz))
        b: Second vector field (shape (3, nx, ny, nz))

    Returns:
        :math:`a \times b` (shape (3, nx, ny, nz))
    """
    return jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# JIT compile the slope computation for performance
|
|
179
|
+
@jax.jit
def compute_slope(
    g_params: dict, e_params: dict, m: jnp.ndarray, R_random: jnp.ndarray
) -> jnp.ndarray:
    """
    Right-hand side of the magnetization update, dm/dt (JIT compiled).

    Builds an effective field from the Laplacian of ``m``, the supplied
    random field and the anisotropy field. It then adds ``coeff_3`` to the
    first component and returns ``-(m x H + lambda_G * m x (m x H))``.

    Args:
        g_params: Grid parameters dict (keys "dx", "dy", "dz")
        e_params: Element parameters dict (keys "coeff_1", "coeff_2",
            "coeff_3", "lambda_G", "anisotropy")
        m: Magnetization array (shape (3, nx, ny, nz))
        R_random: Random field array (shape (3, nx, ny, nz))

    Returns:
        Slope array (shape (3, nx, ny, nz))
    """
    dx, dy, dz = (g_params[k] for k in ("dx", "dy", "dz"))
    coeff_1, coeff_2, coeff_3 = (e_params[k] for k in ("coeff_1", "coeff_2", "coeff_3"))
    lambda_G = e_params["lambda_G"]

    aniso_field = compute_H_anisotropy(m, coeff_2, e_params["anisotropy"])
    lap_field = compute_laplacian(m, dx, dy, dz)

    # Effective field; coeff_3 is a constant offset on the first component
    # (presumably the applied field along x - confirm in Element.to_dict).
    h_eff = (coeff_1 * lap_field + R_random + aniso_field).at[0].add(coeff_3)

    m_x_h = cross_product(m, h_eff)
    m_x_m_x_h = cross_product(m, m_x_h)

    return -(m_x_h + lambda_G * m_x_m_x_h)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _configure_jax(device: str, precision: str):
    """Apply device/precision settings to JAX, print them, return the float dtype."""
    if device == "auto":
        # Let JAX choose the best available device
        pass
    elif device == "cpu":
        jax.config.update("jax_platform_name", "cpu")
    elif device == "gpu":
        jax.config.update("jax_platform_name", "gpu")
    elif device.startswith("gpu:"):
        # Select a specific GPU through CUDA_VISIBLE_DEVICES (presumably only
        # effective before JAX initialises its GPU backend - confirm).
        jax.config.update("jax_platform_name", "gpu")
        gpu_id = device.split(":")[1]
        # Respect a CUDA_VISIBLE_DEVICES value set externally
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
            print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
        else:
            print(f"Using external CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
    # Any other device string is ignored and JAX keeps its default device.

    # Set precision
    if precision == "double":
        jax.config.update("jax_enable_x64", True)
        jnp_float = jnp.float64
    else:
        jax.config.update("jax_enable_x64", False)
        jnp_float = jnp.float32

    print(f"Available JAX devices: {jax.devices()}")
    print(f"Using JAX on device: {jax.devices()[0]}")
    print(f"Precision: {precision} ({jnp_float})")
    return jnp_float


def simulate(
    N: int,
    Jx: int,
    Jy: int,
    Jz: int,
    dx: float,
    T: float,
    H_ext: float,
    dt: float,
    start_averaging: int,
    n_mean: int,
    n_profile: int,
    element_class: Element,
    precision: str,
    seed: int,
    device: str = "auto",
    **_,
) -> tuple[float, str, float]:
    """
    Simulates the system for N iterations using JAX

    Each iteration draws a fresh random field and advances the magnetization
    with a Heun predictor-corrector step, then renormalizes it to unit
    length. Every ``n_mean`` iterations, the weighted spatial mean of the
    first component is written to the mean output file.

    Args:
        N: Number of iterations
        Jx: Number of grid points in x direction
        Jy: Number of grid points in y direction
        Jz: Number of grid points in z direction
        dx: Grid spacing
        T: Temperature in Kelvin
        H_ext: External magnetic field strength
        dt: Time step for the simulation
        start_averaging: Iteration index from which sampled means of m1 are
            accumulated into the returned time average
        n_mean: Sampling period (in iterations) of the spatial mean of m1;
            0 disables the mean output and the average
        n_profile: Number of iterations for profile output (only forwarded
            to get_output_files; this solver writes no profiles)
        element_class: Element of the sample (default: Cobalt)
        precision: Precision of the simulation (single or double)
        seed: Random seed for temperature fluctuations
        device: Device to use ('cpu', 'gpu', 'gpu:0', 'gpu:1', etc., or 'auto')

    Returns:
        - The time taken for the simulation (time-stepping loop only,
          excluding JIT warm-up)
        - The output filenames
        - The average magnetization (time average of the spatial mean of m1;
          0.0 when nothing was averaged)
    """
    jnp_float = _configure_jax(device, precision)

    # Initialize random key for JAX
    key = random.PRNGKey(seed)

    g = Grid(Jx, Jy, Jz, dx)
    dims = g.dims
    field_shape = (3,) + dims

    e = element_class(T, H_ext, g, dt)
    print(f"CFL = {e.get_CFL()}")

    # Prepare parameters for JIT compilation using to_dict methods
    g_params = g.to_dict()
    e_params = e.to_dict()

    # --- Initialization ---
    def theta_init(shape):
        """Initialization of theta"""
        return jnp.zeros(shape, dtype=jnp_float)

    def phi_init(t, shape):
        """Initialization of phi"""
        return jnp.zeros(shape, dtype=jnp_float) + e.gamma_0 * H_ext * t

    m_n = jnp.zeros(field_shape, dtype=jnp_float)

    theta = theta_init(dims)
    phi = phi_init(0, dims)

    m_n = m_n.at[0].set(jnp.cos(theta))
    m_n = m_n.at[1].set(jnp.sin(theta) * jnp.cos(phi))
    m_n = m_n.at[2].set(jnp.sin(theta) * jnp.sin(phi))

    f_mean, f_profiles, output_filenames = get_output_files(g, T, n_mean, n_profile)

    t = 0.0
    m1_average = 0.0

    # Always close the output files, even if warm-up or the loop fails.
    try:
        # === JIT WARMUP: Pre-compile all functions to exclude compilation time ===
        print("Warming up JIT compilation...")

        warmup_key = random.PRNGKey(42)
        R_warmup = e.coeff_4 * random.normal(warmup_key, field_shape, dtype=jnp_float)

        warmup_results = [compute_slope(g_params, e_params, m_n, R_warmup)]
        if n_mean != 0:
            warmup_results.append(compute_space_average_jax(m_n[0]))

        # JAX dispatches asynchronously: wait for the warm-up *results* (not
        # the already-computed m_n) so their execution does not leak into
        # the timed loop.
        jax.block_until_ready(warmup_results)
        print("JIT warmup completed.")

        start_time = time.perf_counter()

        for n in progress_bar(range(1, N + 1), "Iteration : ", 40):
            t += dt

            # Generate random field for temperature effect
            key, subkey = random.split(key)
            R_random = e.coeff_4 * random.normal(subkey, field_shape, dtype=jnp_float)

            # Heun step: Euler predictor, then average of the two slopes
            s_pre = compute_slope(g_params, e_params, m_n, R_random)
            m_pre = m_n + dt * s_pre
            s_cor = compute_slope(g_params, e_params, m_pre, R_random)
            m_n = m_n + dt * 0.5 * (s_pre + s_cor)

            # Renormalize to unit sphere
            norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
            m_n = m_n / norm

            # Export the average of m1 to a file
            if n_mean != 0 and n % n_mean == 0:
                # float() transfers the device scalar to the host
                m1_mean = float(compute_space_average_jax(m_n[0]))
                if n >= start_averaging:
                    m1_average += m1_mean * n_mean
                f_mean.write(f"{t:10.8e} {m1_mean:10.8e}\n")

        # Wait for all queued device work before stopping the clock;
        # otherwise (notably when n_mean == 0) pending steps are not timed.
        jax.block_until_ready(m_n)
        total_time = time.perf_counter() - start_time
    finally:
        close_output_files(f_mean, f_profiles)

    # Use N rather than the loop variable, which is unbound when N == 0.
    if N > start_averaging:
        m1_average /= N - start_averaging

    return total_time, output_filenames, m1_average
|