llg3d 2.0.1__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. llg3d/__init__.py +2 -4
  2. llg3d/benchmarks/__init__.py +1 -0
  3. llg3d/benchmarks/compare_commits.py +321 -0
  4. llg3d/benchmarks/efficiency.py +451 -0
  5. llg3d/benchmarks/utils.py +25 -0
  6. llg3d/element.py +98 -17
  7. llg3d/grid.py +48 -58
  8. llg3d/io.py +395 -0
  9. llg3d/main.py +32 -35
  10. llg3d/parameters.py +159 -49
  11. llg3d/post/__init__.py +1 -1
  12. llg3d/post/extract.py +112 -0
  13. llg3d/post/info.py +192 -0
  14. llg3d/post/m1_vs_T.py +107 -0
  15. llg3d/post/m1_vs_time.py +81 -0
  16. llg3d/post/process.py +87 -85
  17. llg3d/post/utils.py +38 -0
  18. llg3d/post/x_profiles.py +161 -0
  19. llg3d/py.typed +1 -0
  20. llg3d/solvers/__init__.py +153 -0
  21. llg3d/solvers/base.py +345 -0
  22. llg3d/solvers/experimental/__init__.py +9 -0
  23. llg3d/{solver → solvers/experimental}/jax.py +117 -143
  24. llg3d/solvers/math_utils.py +41 -0
  25. llg3d/solvers/mpi.py +370 -0
  26. llg3d/solvers/numpy.py +126 -0
  27. llg3d/solvers/opencl.py +439 -0
  28. llg3d/solvers/profiling.py +38 -0
  29. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/METADATA +5 -2
  30. llg3d-3.1.0.dist-info/RECORD +36 -0
  31. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/WHEEL +1 -1
  32. llg3d-3.1.0.dist-info/entry_points.txt +9 -0
  33. llg3d/output.py +0 -107
  34. llg3d/post/plot_results.py +0 -61
  35. llg3d/post/temperature.py +0 -76
  36. llg3d/simulation.py +0 -95
  37. llg3d/solver/__init__.py +0 -45
  38. llg3d/solver/mpi.py +0 -450
  39. llg3d/solver/numpy.py +0 -207
  40. llg3d/solver/opencl.py +0 -330
  41. llg3d/solver/solver.py +0 -89
  42. llg3d-2.0.1.dist-info/RECORD +0 -25
  43. llg3d-2.0.1.dist-info/entry_points.txt +0 -4
  44. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/licenses/AUTHORS +0 -0
  45. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/licenses/LICENSE +0 -0
  46. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/top_level.txt +0 -0
llg3d/solvers/base.py ADDED
@@ -0,0 +1,345 @@
1
+ """Define the base solver class."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, ClassVar
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from ..element import Element, get_element_class
13
+ from ..grid import Grid
14
+ from ..io import (
15
+ Metrics,
16
+ Observables,
17
+ RecordsBuffer,
18
+ format_profiling_table,
19
+ get_tqdm_file,
20
+ save_results,
21
+ )
22
+ from ..parameters import InitType, RunParameters
23
+ from . import rank, size
24
+ from .math_utils import normalize
25
+ from .profiling import ProfilingStats, timeit
26
+
27
+
28
@dataclass
class BaseSolver(ABC, RunParameters):
    """
    Abstract data base class for LLG3D solvers.

    Subclasses set :attr:`solver_type` and implement :meth:`_simulate`. This
    class provides the shared set-up (grid, element, random generator,
    initial magnetization), the recording of averaged quantities, the
    computation of run metrics and the saving of results.
    """

    solver_type: ClassVar[str] = "base"  #: Solver type name

    def __post_init__(self) -> None:
        """
        Initialize the solver after dataclass creation.

        Raises:
            ValueError: If ``self.solver`` differs from :attr:`solver_type`.
        """
        # Ensure solver name matches class solver_type
        if self.solver != self.solver_type:
            raise ValueError(
                f"Trying to initialize a {self.__class__.__name__}, but solver name "
                f'mismatch: expected "{self.solver_type}", got "{self.solver}"'
            )

        # Initialize results structure aligned with RunResults
        self.metrics: Metrics = {
            "total_time": 0.0,
            "time_per_ite": 0.0,
            "efficiency": 0.0,
            "CFL": 0.0,
        }
        # Physical observables
        self.observables: Observables = {}
        # Records with optional x_profiles and xyz_average
        self.records: RecordsBuffer = {}
        # Number of samples summed into observables["m1_mean"]; _finalize
        # divides the running sum by it to obtain the time average.
        self._m1_mean_samples: int = 0

        self.np = size  # Add a parameter for the number of processes

        self.np_float: np.dtype = np.dtype(
            np.float64 if self.precision == "double" else np.float32
        )
        self.grid: Grid = Grid(self.Jx, self.Jy, self.Jz, self.dx)
        if rank == 0:
            print(self.grid)
        # Reference the element class from the element string
        ElementClass: type[Element] = get_element_class(self.element)
        # Pass dtype to Element so its scalar coefficients have the correct
        # numpy dtype (`self.np_float`) and avoid implicit promotion.
        self.elem: Element = ElementClass(
            self.T, self.H_ext, self.grid, self.dt, dtype=self.np_float
        )

        self.rng: np.random.Generator = self._init_rng()
        self.profiling_stats: ProfilingStats = defaultdict(
            lambda: {"time": 0.0, "calls": 0}
        )
        self._tqdm_file = get_tqdm_file()  #: File object for tqdm output

    def theta_init_0(self, t: float) -> np.ndarray:
        """
        Initialization of theta with 0.

        Args:
            t: Time (unused: this initialization does not depend on time)

        Returns:
            Zero array of shape ``self.grid.dims`` and dtype ``self.np_float``
        """
        return np.zeros(self.grid.dims, dtype=self.np_float)

    def theta_init_dw(self, t: float) -> np.ndarray:
        r"""
        Initialization of theta with a domain wall profile.

        .. math::

            \theta = 2 \arctan\left(\exp\left(-\frac{x - L_x/2
            + d_0 \, c_3 \, \lambda_G \, t}{d_0}\right)\right)

        Args:
            t: Time at which the profile is evaluated

        Returns:
            Theta evaluated on the (local, when running in parallel) mesh
        """
        x, _, _ = self.grid.get_mesh(local=size > 1, dtype=self.np_float)
        return 2.0 * np.arctan(
            np.exp(
                -(
                    x
                    - self.grid.Lx / 2
                    + self.elem.d_0 * self.elem.coeff_3 * self.elem.lambda_G * t
                )
                / self.elem.d_0
            )
        )

    def phi_init(self, t: float) -> np.ndarray:
        """
        Initialization of phi.

        Args:
            t: Time at which phi is evaluated

        Returns:
            Uniform array equal to ``gamma_0 * H_ext * t`` (element values)
        """
        return (
            np.zeros(self.grid.dims, dtype=self.np_float)
            + self.elem.gamma_0 * self.elem.H_ext * t
        )

    def _init_m_n(self) -> np.ndarray:
        """
        Initialize the magnetization array at time step n.

        Returns:
            Normalized magnetization of shape ``(3, *self.grid.dims)``

        Raises:
            ValueError: If ``self.init_type`` is not a known initialization
        """
        m_n = np.zeros((3,) + self.grid.dims, dtype=self.np_float)

        if self.init_type == "0":
            theta = self.theta_init_0(0)
        elif self.init_type == "dw":
            theta = self.theta_init_dw(0)
        else:
            raise ValueError(
                f"Unknown initialization type: {self.init_type}, "
                f"should be in {InitType}"
            )

        phi = self.phi_init(0)

        # Spherical angles -> Cartesian components
        m_n[0] = np.cos(theta)
        m_n[1] = np.sin(theta) * np.cos(phi)
        m_n[2] = np.sin(theta) * np.sin(phi)
        # renormalize to verify the constraint of being on the sphere
        normalize(m_n)
        return m_n

    def _init_rng(self) -> np.random.Generator:
        """
        Initialize a random number generator for temperature fluctuations.

        One independent child stream is spawned per process from
        ``self.seed``, and this process uses the stream of index ``rank``.

        Returns:
            A numpy random number generator
        """
        # Initialize a sequence of random seeds
        # See: https://numpy.org/doc/stable/reference/random/parallel.html
        ss = np.random.SeedSequence(self.seed)

        # Deploy size x SeedSequence to pass to child processes
        child_seeds = ss.spawn(size)
        streams = [np.random.default_rng(s) for s in child_seeds]
        rng = streams[rank]
        return rng

    @timeit
    def _get_R_random(self) -> np.ndarray:
        """
        Generate the random field for temperature fluctuations.

        Returns:
            Random field array (shape (3, nx, ny, nz)), standard normal
            samples scaled by ``self.elem.coeff_4``
        """
        R_random = self.elem.coeff_4 * self.rng.standard_normal(
            (3, *self.grid.dims), dtype=self.np_float
        )
        return R_random

    @timeit
    def _normalize(self, m_n: np.ndarray) -> None:
        r"""
        Normalize the magnetization array (in place).

        .. math::

            \mathbf{m}_n = \frac{\mathbf{m}_n}{|\mathbf{m}_n|}

        Args:
            m_n: Magnetization array at time step n (shape (3, nx, ny, nz)).
        """
        normalize(m_n)

    @timeit
    def _xyz_average(self, m: np.ndarray) -> float:
        """
        Returns the spatial average of m using the trapezoidal rule.

        Each boundary node contribution is halved once per axis on which it
        lies on the edge (so corners are weighted by 1/8), and the weighted
        sum is divided by ``self.grid.ncell``.

        Args:
            m: Array to be integrated (shape ``self.grid.dims``)

        Returns:
            Spatial average of m
        """
        mm = m.copy()  # copy m to avoid modifying its value

        # on the edges, we divide the contribution by 2
        # x
        mm[0, :, :] /= 2
        mm[-1, :, :] /= 2
        # y
        mm[:, 0, :] /= 2
        mm[:, -1, :] /= 2
        # z
        mm[:, :, 0] /= 2
        mm[:, :, -1] /= 2

        average = mm.sum() / self.grid.ncell
        return float(average)

    def _record_xyz_average(self, m_n: np.ndarray, t: float, n: int) -> None:
        """
        Record the space average of m1 and accumulate its time average.

        The space average is computed on every rank, but only rank 0 stores
        the ``(t, average)`` pair and, once ``n >= start_averaging``, adds
        the value to ``observables["m1_mean"]`` (normalized in
        :meth:`_finalize`).

        Args:
            m_n: Magnetization array at time step n (shape (3, nx, ny, nz))
            t: Current time
            n: Current iteration index
        """
        xyz_average = self._xyz_average(m_n[0])
        if rank == 0:
            # Ensure xyz_average list exists
            if "xyz_average" not in self.records:
                self.records["xyz_average"] = []
            # Record the mean value at time t
            self.records["xyz_average"].append((t, xyz_average))
            # Update time average of m1
            if n >= self.start_averaging:
                # Initialize m1_mean on first use
                if "m1_mean" not in self.observables:
                    self.observables["m1_mean"] = 0.0
                # Accumulate time average (each sample contributes equally)
                self.observables["m1_mean"] += xyz_average
                self._m1_mean_samples += 1

    def _finalize(self) -> None:
        """
        Normalize m1_mean by the actual number of samples accumulated.

        The divisor is the number of values summed by
        :meth:`_record_xyz_average`. Counting them directly avoids the
        closed form ``(N - start_averaging) // n_mean``, which miscounts
        when ``start_averaging`` is a multiple of ``n_mean`` (e.g. N=10,
        start_averaging=0, n_mean=3 gives 4 samples, not 3) and can even
        be 0 while samples exist.
        """
        if rank == 0:
            if "m1_mean" in self.observables and self._m1_mean_samples > 0:
                self.observables["m1_mean"] /= self._m1_mean_samples

    @timeit
    def _yz_average(self, m_i: np.ndarray) -> np.ndarray:
        """
        Returns the spatial average of m using the trapezoidal rule along y and z.

        Args:
            m_i: Array to be integrated (3D, first axis is x)

        Returns:
            Spatial average of m in y and z of shape (m_i.shape[0],)
        """
        # Make a copy of m to avoid modifying its value
        mm = m_i.copy()

        # On y and z edges, divide the contribution by 2
        mm[:, 0, :] /= 2
        mm[:, -1, :] /= 2
        mm[:, :, 0] /= 2
        mm[:, :, -1] /= 2

        n_cell_yz = (mm.shape[1] - 1) * (mm.shape[2] - 1)
        return mm.sum(axis=(1, 2)) / n_cell_yz

    @timeit
    def _update_x_profiles(self, m_n: np.ndarray, t: float) -> None:
        """
        Update x profiles of the averaged m_i in y and z.

        Args:
            m_n: Magnetization array at time step n (shape (3, nx, ny, nz))
            t: Current time
        """
        # Initialize x_profiles on first use
        if "x_profiles" not in self.records:
            self.records["x_profiles"] = {
                "t": [],
                "m1": [],
                "m2": [],
                "m3": [],
            }
        x_prof = self.records["x_profiles"]
        x_prof["t"].append(t)
        x_prof["m1"].append(self._yz_average(m_n[0]))
        x_prof["m2"].append(self._yz_average(m_n[1]))
        x_prof["m3"].append(self._yz_average(m_n[2]))

    def _record(self, m_n: Any, t: float, n: int) -> None:
        """
        Record simulation data at iteration n.

        A value of 0 for ``n_mean`` / ``n_profile`` disables the
        corresponding record.

        Args:
            m_n: Magnetization array at time step n
            t: Current time
            n: Current iteration index
        """
        # Record the average of m1
        if self.n_mean != 0 and n % self.n_mean == 0:
            self._record_xyz_average(m_n, t, n)
        # Record the x profiles of the averaged m_i in y and z
        if self.n_profile != 0 and n % self.n_profile == 0:
            self._update_x_profiles(m_n, t)

    def _progress_bar(self):
        """
        Return the iteration range ``range(self.N)``.

        The range is wrapped in a tqdm progress bar when a tqdm output file
        is available, and returned bare otherwise.
        """
        if self._tqdm_file is None:
            # corresponds to rank != 0 in MPI
            return range(self.N)

        return tqdm(
            range(self.N),
            file=self._tqdm_file,
            dynamic_ncols=True,
            leave=True,
        )

    @abstractmethod
    def _simulate(self) -> float:
        """
        Simulates the system for N iterations.

        Returns:
            total_time: Total simulation time
        """

    def run(self) -> None:
        """
        Runs the simulation and store the results.

        Fills ``self.metrics`` with the total time, time per iteration,
        efficiency (time per iteration per grid point) and CFL number; on
        rank 0, also stores the profiling stats of functions that were
        called at least once.
        """
        total_time = self._simulate()
        self.metrics["total_time"] = total_time
        time_per_ite = total_time / self.N if self.N > 0 else 0.0
        self.metrics["time_per_ite"] = time_per_ite
        self.metrics["efficiency"] = time_per_ite / self.grid.ntot
        self.metrics["CFL"] = float(self.elem.get_CFL())
        if rank == 0:
            # Store only profiling stats for functions/kernels that were actually called
            profiling_filtered = {
                k: v for k, v in self.profiling_stats.items() if v.get("calls", 0) > 0
            }
            if profiling_filtered:
                self.metrics["profiling_stats"] = profiling_filtered

    def _format_profiling(self) -> str:
        """Format the profiling information for display."""
        return format_profiling_table(self.profiling_stats, self.metrics["total_time"])

    def save(self, dir_path: str | Path = ".") -> None:
        """
        Saves the results of the simulation to a .npz file.

        Only rank 0 prints the run summary and writes the file.

        Args:
            dir_path: Directory path to save the results
        """
        if rank == 0:
            if self.metrics["total_time"] > 0:
                s = f"""\
N iterations          = {self.N}
total_time [s]        = {self.metrics["total_time"]:.03f}
time/ite [s/ite]      = {self.metrics["time_per_ite"]:.03e}
efficiency [s/ite/pt] = {self.metrics["efficiency"]:.03e}
CFL                   = {self.metrics["CFL"]:.03e}"""
                print(s)

            # Print profiling info only if enabled
            if self.profiling:
                print(f"Profiling info:\n{self._format_profiling()}")

            # Export the mean of m1 over space and time
            if "m1_mean" in self.observables:
                print(f"m1_mean = {self.observables['m1_mean']:e}")

            print(f"Saving {self.result_file}")
            save_results(
                Path(dir_path) / self.result_file,
                self,
                self.metrics,
                observables=self.observables if self.observables else None,
                records_buffer=self.records if self.records else None,
            )
@@ -0,0 +1,9 @@
1
+ """
2
+ Experimental solvers live here.
3
+
4
+ Currently contains JAX-based solver implementations.
5
+ """
6
+
7
+ from .jax import JaxSolver
8
+
9
+ __all__ = ["JaxSolver"]
@@ -1,4 +1,12 @@
1
- """LLG3D solver using XLA compilation."""
1
+ """
2
+ LLG3D solver using XLA compilation.
3
+
4
+ .. warning::
5
+
6
+ It is experimental and not maintained.
7
+ """
8
+
9
+ from typing import ClassVar
2
10
 
3
11
  import os
4
12
  import time
@@ -7,9 +15,7 @@ import jax
7
15
  import jax.numpy as jnp
8
16
  from jax import random
9
17
 
10
- from ..output import progress_bar, get_output_files, close_output_files
11
- from ..grid import Grid
12
- from ..element import Element
18
+ from ..base import BaseSolver
13
19
 
14
20
 
15
21
  # JIT compile individual components for better performance and modularity
@@ -157,7 +163,7 @@ def compute_space_average_jax(m1: jnp.ndarray) -> float:
157
163
  # Compute ncell from the weights (this is the effective cell count)
158
164
  ncell = jnp.sum(weights)
159
165
 
160
- return weighted_sum / ncell
166
+ return weighted_sum / ncell # type: ignore
161
167
 
162
168
 
163
169
  @jax.jit
@@ -217,171 +223,139 @@ def compute_slope(
217
223
  return -(m_cross_R_eff + lambda_G * m_cross_m_cross_R_eff)
218
224
 
219
225
 
220
class JaxSolver(BaseSolver):
    """
    JAX-based LLG3D solver.

    The target device is read from ``self.device``: ``'auto'`` (let JAX
    choose), ``'cpu'``, ``'gpu'`` or ``'gpu:<id>'``. The last form selects
    a GPU through ``CUDA_VISIBLE_DEVICES`` unless that variable is already
    set externally.
    """

    solver_type: ClassVar[str] = "jax"  #: Solver type name

    def _xyz_average(self, m1: jnp.ndarray) -> float:
        """
        Compute the space average of m1 using JAX.

        The device scalar is converted to a Python float. Recorded values
        and the accumulated ``m1_mean`` are therefore host scalars rather
        than JAX arrays, matching the ``-> float`` contract.
        """
        return float(compute_space_average_jax(m1))

    def _configure_jax(self):
        """
        Configure the JAX platform and precision from the run parameters.

        Returns:
            The JAX float type (``jnp.float64`` or ``jnp.float32``) matching
            ``self.precision``
        """
        if self.device == "auto":
            # Let JAX choose the best available device
            pass
        elif self.device == "cpu":
            jax.config.update("jax_platform_name", "cpu")
        elif self.device == "gpu":
            jax.config.update("jax_platform_name", "gpu")
        elif self.device.startswith("gpu:"):
            # Select specific GPU using environment variable
            jax.config.update("jax_platform_name", "gpu")
            gpu_id = self.device.split(":")[1]
            # Check if CUDA_VISIBLE_DEVICES is already set externally
            if "CUDA_VISIBLE_DEVICES" not in os.environ:
                os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
                print(f"Set CUDA_VISIBLE_DEVICES={gpu_id}")
            else:
                cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
                print(f"Using external CUDA_VISIBLE_DEVICES={cuda_visible_devices}")

        # Set precision
        if self.precision == "double":
            jax.config.update("jax_enable_x64", True)
            return jnp.float64
        jax.config.update("jax_enable_x64", False)
        return jnp.float32

    def _simulate(self) -> float:
        """
        Simulates the system for N iterations using JAX.

        Returns:
            The time taken for the simulation (JIT warmup excluded)

        Raises:
            NotImplementedError: If n_profile is not zero
        """
        if self.n_profile != 0:
            raise NotImplementedError(
                "Saving x-profiles is not implemented for the JAX solver."
            )

        jnp_float = self._configure_jax()

        print(f"Available JAX devices: {jax.devices()}")
        print(f"Using JAX on device: {jax.devices()[0]}")
        print(f"Precision: {self.precision} ({jnp_float})")

        # Initialize random key for JAX
        key = random.PRNGKey(self.seed)

        # Prepare parameters for JIT compilation using to_dict methods
        g_params = self.grid.as_dict()
        e_params = self.elem.to_dict()
        field_shape = (3, *self.grid.dims)

        # --- Initialization ---
        # Reuse the shared initialization so that init_type ("0" or "dw")
        # is honored, as in the other solvers.
        m_n = jnp.asarray(self._init_m_n(), dtype=jnp_float)

        t = 0.0

        # === JIT WARMUP: Pre-compile all functions to exclude compilation time ===
        print("Warming up JIT compilation...")

        # Generate dummy random field for warmup
        warmup_key = random.PRNGKey(42)
        R_warmup = self.elem.coeff_4 * random.normal(
            warmup_key, field_shape, dtype=jnp_float
        )

        # Warmup all JIT functions with actual data shapes
        warmup_outputs = [compute_slope(g_params, e_params, m_n, R_warmup)]
        if self.n_mean != 0:
            warmup_outputs.append(compute_space_average_jax(m_n[0]))

        # Wait for the warmup computations themselves (JAX dispatches
        # asynchronously) so they do not overlap the timed region.
        jax.block_until_ready(warmup_outputs)
        print("JIT warmup completed.")

        start_time = time.perf_counter()

        for n in self._progress_bar():
            t += self.dt

            # Generate random field for temperature effect
            key, subkey = random.split(key)
            R_random = self.elem.coeff_4 * random.normal(
                subkey, field_shape, dtype=jnp_float
            )

            # Predictor-corrector (Heun) step with JIT-compiled slopes
            s_pre = compute_slope(g_params, e_params, m_n, R_random)
            m_pre = m_n + self.dt * s_pre
            s_cor = compute_slope(g_params, e_params, m_pre, R_random)

            # Update magnetization
            m_n = m_n + self.dt * 0.5 * (s_pre + s_cor)

            # Renormalize to unit sphere
            norm = jnp.sqrt(m_n[0] ** 2 + m_n[1] ** 2 + m_n[2] ** 2)
            m_n = m_n / norm

            self._record(m_n, t, n)

        # Without this, pending asynchronous work would be excluded from the
        # measured time (notably when n_mean == 0 and nothing forces a sync).
        jax.block_until_ready(m_n)
        total_time = time.perf_counter() - start_time

        self._finalize()

        return total_time