llg3d 2.0.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. llg3d/__init__.py +3 -3
  2. llg3d/__main__.py +2 -2
  3. llg3d/benchmarks/__init__.py +1 -0
  4. llg3d/benchmarks/compare_commits.py +321 -0
  5. llg3d/benchmarks/efficiency.py +451 -0
  6. llg3d/benchmarks/utils.py +25 -0
  7. llg3d/element.py +118 -31
  8. llg3d/grid.py +51 -64
  9. llg3d/io.py +395 -0
  10. llg3d/main.py +36 -38
  11. llg3d/parameters.py +159 -49
  12. llg3d/post/__init__.py +1 -1
  13. llg3d/post/extract.py +105 -0
  14. llg3d/post/info.py +178 -0
  15. llg3d/post/m1_vs_T.py +90 -0
  16. llg3d/post/m1_vs_time.py +56 -0
  17. llg3d/post/process.py +82 -75
  18. llg3d/post/utils.py +38 -0
  19. llg3d/post/x_profiles.py +141 -0
  20. llg3d/py.typed +1 -0
  21. llg3d/solvers/__init__.py +153 -0
  22. llg3d/solvers/base.py +345 -0
  23. llg3d/solvers/experimental/__init__.py +9 -0
  24. llg3d/solvers/experimental/jax.py +361 -0
  25. llg3d/solvers/math_utils.py +41 -0
  26. llg3d/solvers/mpi.py +370 -0
  27. llg3d/solvers/numpy.py +126 -0
  28. llg3d/solvers/opencl.py +439 -0
  29. llg3d/solvers/profiling.py +38 -0
  30. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/METADATA +6 -3
  31. llg3d-3.0.0.dist-info/RECORD +36 -0
  32. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/WHEEL +1 -1
  33. llg3d-3.0.0.dist-info/entry_points.txt +9 -0
  34. llg3d/output.py +0 -108
  35. llg3d/post/plot_results.py +0 -65
  36. llg3d/post/temperature.py +0 -83
  37. llg3d/simulation.py +0 -104
  38. llg3d/solver/__init__.py +0 -45
  39. llg3d/solver/jax.py +0 -383
  40. llg3d/solver/mpi.py +0 -449
  41. llg3d/solver/numpy.py +0 -210
  42. llg3d/solver/opencl.py +0 -329
  43. llg3d/solver/solver.py +0 -93
  44. llg3d-2.0.0.dist-info/RECORD +0 -25
  45. llg3d-2.0.0.dist-info/entry_points.txt +0 -4
  46. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/licenses/AUTHORS +0 -0
  47. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/licenses/LICENSE +0 -0
  48. {llg3d-2.0.0.dist-info → llg3d-3.0.0.dist-info}/top_level.txt +0 -0
llg3d/solvers/base.py ADDED
@@ -0,0 +1,345 @@
1
+ """Define the base solver class."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, ClassVar
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from ..element import Element, get_element_class
13
+ from ..grid import Grid
14
+ from ..io import (
15
+ Metrics,
16
+ Observables,
17
+ RecordsBuffer,
18
+ format_profiling_table,
19
+ get_tqdm_file,
20
+ save_results,
21
+ )
22
+ from ..parameters import InitType, RunParameters
23
+ from . import rank, size
24
+ from .math_utils import normalize
25
+ from .profiling import ProfilingStats, timeit
26
+
27
+
28
@dataclass
class BaseSolver(ABC, RunParameters):
    """
    Abstract base dataclass for LLG3D solvers.

    Concrete solvers define :attr:`solver_type` and implement
    :meth:`_simulate`. This class provides the shared initialization
    (grid, element, random number generator, initial magnetization), the
    recording of averaged quantities, and the reporting/saving of results.
    """

    solver_type: ClassVar[str] = "base"  #: Solver type name

    def __post_init__(self) -> None:
        """
        Initialize the solver after dataclass creation.

        Raises:
            ValueError: If the ``solver`` run parameter differs from the
                class :attr:`solver_type`.
        """
        # Ensure solver name matches class solver_type
        if self.solver != self.solver_type:
            raise ValueError(
                f"Trying to initialize a {self.__class__.__name__}, but solver name "
                f'mismatch: expected "{self.solver_type}", got "{self.solver}"'
            )

        # Initialize results structure aligned with RunResults
        self.metrics: Metrics = {
            "total_time": 0.0,
            "time_per_ite": 0.0,
            "efficiency": 0.0,
            "CFL": 0.0,
        }
        # Physical observables
        self.observables: Observables = {}
        # Records with optional x_profiles and xyz_average
        self.records: RecordsBuffer = {}
        # Number of samples actually summed into observables["m1_mean"];
        # _finalize divides by it to turn the running sum into a mean.
        self._m1_mean_samples: int = 0

        self.np = size  # Number of processes

        self.np_float: np.dtype = np.dtype(
            np.float64 if self.precision == "double" else np.float32
        )
        self.grid: Grid = Grid(self.Jx, self.Jy, self.Jz, self.dx)
        if rank == 0:
            print(self.grid)
        # Reference the element class from the element string
        ElementClass: type[Element] = get_element_class(self.element)
        # Pass dtype to Element so its scalar coefficients have the correct
        # numpy dtype (`self.np_float`) and avoid implicit promotion.
        self.elem: Element = ElementClass(
            self.T, self.H_ext, self.grid, self.dt, dtype=self.np_float
        )

        self.rng: np.random.Generator = self._init_rng()
        # Accumulated time and call count per profiled function/kernel
        self.profiling_stats: ProfilingStats = defaultdict(
            lambda: {"time": 0.0, "calls": 0}
        )
        self._tqdm_file = get_tqdm_file()  #: File object for tqdm output

    def theta_init_0(self, t: float) -> np.ndarray:
        """
        Initialization of theta with 0.

        Args:
            t: Time (unused here; same signature as the other initializers)

        Returns:
            Zero array of shape ``self.grid.dims`` and dtype ``self.np_float``
        """
        return np.zeros(self.grid.dims, dtype=self.np_float)

    def theta_init_dw(self, t: float) -> np.ndarray:
        r"""
        Initialization of theta with a domain wall profile.

        .. math::

            \theta(x, t) = 2 \arctan\left(\exp\left(
            -\frac{x - L_x/2 + d_0 c_3 \lambda_G t}{d_0}\right)\right)

        where :math:`c_3` is ``self.elem.coeff_3``.

        Args:
            t: Time at which the profile is evaluated

        Returns:
            theta evaluated on the mesh (the local mesh when running on
            several processes)
        """
        x, _, _ = self.grid.get_mesh(local=size > 1, dtype=self.np_float)
        return 2.0 * np.arctan(
            np.exp(
                -(
                    x
                    - self.grid.Lx / 2
                    + self.elem.d_0 * self.elem.coeff_3 * self.elem.lambda_G * t
                )
                / self.elem.d_0
            )
        )

    def phi_init(self, t: float) -> np.ndarray:
        """
        Initialization of phi, uniform in space: ``gamma_0 * H_ext * t``.

        The array is created directly with dtype ``self.np_float`` so the
        scalar coefficients cannot promote it to another precision.

        Args:
            t: Time at which phi is evaluated

        Returns:
            Array of shape ``self.grid.dims`` filled with ``gamma_0 * H_ext * t``
        """
        return np.full(
            self.grid.dims,
            self.elem.gamma_0 * self.elem.H_ext * t,
            dtype=self.np_float,
        )

    def _init_m_n(self) -> np.ndarray:
        """
        Initialize the magnetization array at time step n.

        m is built from the spherical angles (theta, phi) evaluated at t = 0,
        then renormalized onto the unit sphere.

        Returns:
            Magnetization array of shape ``(3, *self.grid.dims)``

        Raises:
            ValueError: If ``self.init_type`` is neither ``"0"`` nor ``"dw"``.
        """
        if self.init_type == "0":
            theta = self.theta_init_0(0)
        elif self.init_type == "dw":
            theta = self.theta_init_dw(0)
        else:
            raise ValueError(
                f"Unknown initialization type: {self.init_type}, "
                f"should be in {InitType}"
            )
        phi = self.phi_init(0)

        m_n = np.zeros((3,) + self.grid.dims, dtype=self.np_float)
        sin_theta = np.sin(theta)
        m_n[0] = np.cos(theta)
        m_n[1] = sin_theta * np.cos(phi)
        m_n[2] = sin_theta * np.sin(phi)
        # renormalize to verify the constraint of being on the sphere
        normalize(m_n)
        return m_n

    def _init_rng(self) -> np.random.Generator:
        """
        Initialize a random number generator for temperature fluctuations.

        One child seed per process is spawned from ``self.seed``, and this
        process uses the one at index ``rank``. Each process therefore gets an
        independent, reproducible stream.

        Returns:
            A numpy random number generator
        """
        # See: https://numpy.org/doc/stable/reference/random/parallel.html
        child_seeds = np.random.SeedSequence(self.seed).spawn(size)
        # Only this rank's generator is needed: building the others (as a
        # list of generators) would be wasted work with the same result.
        return np.random.default_rng(child_seeds[rank])

    @timeit
    def _get_R_random(self) -> np.ndarray:
        """
        Generate the random field for temperature fluctuations.

        Standard normal samples scaled by ``self.elem.coeff_4``.

        Returns:
            Random field array (shape (3, nx, ny, nz))
        """
        R_random = self.elem.coeff_4 * self.rng.standard_normal(
            (3, *self.grid.dims), dtype=self.np_float
        )
        return R_random

    @timeit
    def _normalize(self, m_n: np.ndarray) -> None:
        r"""
        Normalize the magnetization array (in place).

        .. math::

            \mathbf{m}_n = \frac{\mathbf{m}_n}{|\mathbf{m}_n|}

        Args:
            m_n: Magnetization array at time step n (shape (3, nx, ny, nz)).
        """
        normalize(m_n)

    @timeit
    def _xyz_average(self, m: np.ndarray) -> float:
        """
        Return the spatial average of m (shape ``g.dims``), trapezoidal rule.

        Nodal values on each boundary face are weighted by 1/2 per direction
        (hence 1/4 on edges and 1/8 on corners). The weighted sum is then
        divided by ``self.grid.ncell``.

        Args:
            m: Array to be integrated

        Returns:
            Spatial average of m
        """
        mm = m.copy()  # copy m to avoid modifying its value

        # on the edges, we divide the contribution by 2
        # x
        mm[0, :, :] /= 2
        mm[-1, :, :] /= 2
        # y
        mm[:, 0, :] /= 2
        mm[:, -1, :] /= 2
        # z
        mm[:, :, 0] /= 2
        mm[:, :, -1] /= 2

        average = mm.sum() / self.grid.ncell
        return float(average)

    def _record_xyz_average(self, m_n: np.ndarray, t: float, n: int) -> None:
        """
        Record the spatial average of m1 and accumulate its time average.

        Args:
            m_n: Magnetization array at time step n (``m_n[0]`` is m1)
            t: Time of the sample
            n: Time step index; only steps ``n >= start_averaging``
                contribute to ``m1_mean``
        """
        # Computed on every rank, before the rank-0 check. This presumably
        # allows parallel subclasses to override _xyz_average with a
        # collective reduction (TODO confirm).
        xyz_average = self._xyz_average(m_n[0])
        if rank != 0:
            return
        # Record the mean value at time t
        self.records.setdefault("xyz_average", []).append((t, xyz_average))
        # Accumulate the time average of m1 (each sample contributes equally)
        if n >= self.start_averaging:
            self.observables["m1_mean"] = (
                self.observables.get("m1_mean", 0.0) + xyz_average
            )
            self._m1_mean_samples = getattr(self, "_m1_mean_samples", 0) + 1

    def _finalize(self) -> None:
        """
        Turn the accumulated sum of m1 samples into a time average (rank 0).

        The divisor is the number of samples actually accumulated by
        :meth:`_record_xyz_average`. The result is therefore correct
        whatever ``N``, ``start_averaging`` and ``n_mean`` are, including
        when ``start_averaging`` is not a multiple of ``n_mean``.
        """
        if rank != 0 or "m1_mean" not in self.observables:
            return
        num_samples = getattr(self, "_m1_mean_samples", 0)
        if num_samples == 0:
            # m1_mean was accumulated without going through
            # _record_xyz_average (e.g. by an overriding subclass). Fall back
            # to the nominal sample count, and never divide by zero.
            nominal = (
                (self.N - self.start_averaging) // self.n_mean if self.n_mean else 0
            )
            num_samples = max(nominal, 1)
        self.observables["m1_mean"] /= num_samples

    @timeit
    def _yz_average(self, m_i: np.ndarray) -> np.ndarray:
        """
        Return the average of m_i over y and z (trapezoidal rule).

        Nodal values on the y and z boundaries are weighted by 1/2 per
        direction. The weighted sum over (y, z) is divided by the number of
        (y, z) cells.

        Args:
            m_i: Array to be integrated

        Returns:
            Spatial average of m in y and z of shape (g.dims[0],)
        """
        # Make a copy of m to avoid modifying its value
        mm = m_i.copy()

        # On y and z edges, divide the contribution by 2
        mm[:, 0, :] /= 2
        mm[:, -1, :] /= 2
        mm[:, :, 0] /= 2
        mm[:, :, -1] /= 2

        n_cell_yz = (mm.shape[1] - 1) * (mm.shape[2] - 1)
        return mm.sum(axis=(1, 2)) / n_cell_yz

    @timeit
    def _update_x_profiles(self, m_n: np.ndarray, t: float) -> None:
        """
        Append the y-z averaged profiles of m1, m2, m3 at time t.

        Args:
            m_n: Magnetization array at time step n (shape (3, nx, ny, nz))
            t: Time of the sample
        """
        # Initialize x_profiles on first use
        if "x_profiles" not in self.records:
            self.records["x_profiles"] = {
                "t": [],
                "m1": [],
                "m2": [],
                "m3": [],
            }
        x_prof = self.records["x_profiles"]
        x_prof["t"].append(t)
        x_prof["m1"].append(self._yz_average(m_n[0]))
        x_prof["m2"].append(self._yz_average(m_n[1]))
        x_prof["m3"].append(self._yz_average(m_n[2]))

    def _record(self, m_n: Any, t: float, n: int) -> None:
        """
        Record simulation data at step n, according to the sampling periods.

        A period of 0 (``n_mean`` or ``n_profile``) disables that record.

        Args:
            m_n: Magnetization at time step n
            t: Time of step n
            n: Time step index
        """
        # Record the average of m1
        if self.n_mean != 0 and n % self.n_mean == 0:
            self._record_xyz_average(m_n, t, n)
        # Record the x profiles of the averaged m_i in y and z
        if self.n_profile != 0 and n % self.n_profile == 0:
            self._update_x_profiles(m_n, t)

    def _progress_bar(self):
        """
        Return an iterable over ``range(self.N)``, wrapped in tqdm if enabled.

        A plain ``range`` is returned when no tqdm output file is available.
        """
        if self._tqdm_file is None:
            # corresponds to rank != 0 in MPI
            return range(self.N)

        return tqdm(
            range(self.N),
            file=self._tqdm_file,
            dynamic_ncols=True,
            leave=True,
        )

    @abstractmethod
    def _simulate(self) -> float:
        """
        Simulates the system for N iterations.

        Returns:
            total_time: Total simulation time
        """

    def run(self) -> None:
        """Run the simulation and store the timing metrics and CFL number."""
        total_time = self._simulate()
        self.metrics["total_time"] = total_time
        time_per_ite = total_time / self.N if self.N > 0 else 0.0
        self.metrics["time_per_ite"] = time_per_ite
        # Time per iteration and per grid point
        self.metrics["efficiency"] = time_per_ite / self.grid.ntot
        self.metrics["CFL"] = float(self.elem.get_CFL())
        if rank == 0:
            # Store only profiling stats for functions/kernels that were actually called
            profiling_filtered = {
                k: v for k, v in self.profiling_stats.items() if v.get("calls", 0) > 0
            }
            if profiling_filtered:
                self.metrics["profiling_stats"] = profiling_filtered

    def _format_profiling(self) -> str:
        """Format the profiling information for display."""
        return format_profiling_table(self.profiling_stats, self.metrics["total_time"])

    def save(self, dir_path: str | Path = ".") -> None:
        """
        Print a run summary and save the results (rank 0 only).

        The results are written by ``save_results`` to
        ``Path(dir_path) / self.result_file``.

        Args:
            dir_path: Directory path to save the results
        """
        if rank != 0:
            return

        if self.metrics["total_time"] > 0:
            summary = "\n".join(
                [
                    f"N iterations = {self.N}",
                    f"total_time [s] = {self.metrics['total_time']:.03f}",
                    f"time/ite [s/ite] = {self.metrics['time_per_ite']:.03e}",
                    f"efficiency [s/ite/pt] = {self.metrics['efficiency']:.03e}",
                    f"CFL = {self.metrics['CFL']:.03e}",
                ]
            )
            print(summary)

        # Print profiling info only if enabled
        if self.profiling:
            print(f"Profiling info:\n{self._format_profiling()}")

        # Export the mean of m1 over space and time
        if "m1_mean" in self.observables:
            print(f"m1_mean = {self.observables['m1_mean']:e}")

        print(f"Saving {self.result_file}")
        save_results(
            Path(dir_path) / self.result_file,
            self,
            self.metrics,
            observables=self.observables if self.observables else None,
            records_buffer=self.records if self.records else None,
        )
llg3d/solvers/experimental/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """
2
+ Experimental solvers live here.
3
+
4
+ Currently contains JAX-based solver implementations.
5
+ """
6
+
7
+ from .jax import JaxSolver
8
+
9
+ __all__ = ["JaxSolver"]