llg3d 2.0.1__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. llg3d/__init__.py +2 -4
  2. llg3d/benchmarks/__init__.py +1 -0
  3. llg3d/benchmarks/compare_commits.py +321 -0
  4. llg3d/benchmarks/efficiency.py +451 -0
  5. llg3d/benchmarks/utils.py +25 -0
  6. llg3d/element.py +98 -17
  7. llg3d/grid.py +48 -58
  8. llg3d/io.py +395 -0
  9. llg3d/main.py +32 -35
  10. llg3d/parameters.py +159 -49
  11. llg3d/post/__init__.py +1 -1
  12. llg3d/post/extract.py +105 -0
  13. llg3d/post/info.py +178 -0
  14. llg3d/post/m1_vs_T.py +90 -0
  15. llg3d/post/m1_vs_time.py +56 -0
  16. llg3d/post/process.py +87 -85
  17. llg3d/post/utils.py +38 -0
  18. llg3d/post/x_profiles.py +141 -0
  19. llg3d/py.typed +1 -0
  20. llg3d/solvers/__init__.py +153 -0
  21. llg3d/solvers/base.py +345 -0
  22. llg3d/solvers/experimental/__init__.py +9 -0
  23. llg3d/{solver → solvers/experimental}/jax.py +117 -143
  24. llg3d/solvers/math_utils.py +41 -0
  25. llg3d/solvers/mpi.py +370 -0
  26. llg3d/solvers/numpy.py +126 -0
  27. llg3d/solvers/opencl.py +439 -0
  28. llg3d/solvers/profiling.py +38 -0
  29. {llg3d-2.0.1.dist-info → llg3d-3.0.0.dist-info}/METADATA +5 -2
  30. llg3d-3.0.0.dist-info/RECORD +36 -0
  31. {llg3d-2.0.1.dist-info → llg3d-3.0.0.dist-info}/WHEEL +1 -1
  32. llg3d-3.0.0.dist-info/entry_points.txt +9 -0
  33. llg3d/output.py +0 -107
  34. llg3d/post/plot_results.py +0 -61
  35. llg3d/post/temperature.py +0 -76
  36. llg3d/simulation.py +0 -95
  37. llg3d/solver/__init__.py +0 -45
  38. llg3d/solver/mpi.py +0 -450
  39. llg3d/solver/numpy.py +0 -207
  40. llg3d/solver/opencl.py +0 -330
  41. llg3d/solver/solver.py +0 -89
  42. llg3d-2.0.1.dist-info/RECORD +0 -25
  43. llg3d-2.0.1.dist-info/entry_points.txt +0 -4
  44. {llg3d-2.0.1.dist-info → llg3d-3.0.0.dist-info}/licenses/AUTHORS +0 -0
  45. {llg3d-2.0.1.dist-info → llg3d-3.0.0.dist-info}/licenses/LICENSE +0 -0
  46. {llg3d-2.0.1.dist-info → llg3d-3.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,451 @@
1
+ """Compare the efficiency of different solver for different domain sizes."""
2
+
3
+ import argparse
4
+ import platform
5
+ import subprocess
6
+ from itertools import cycle
7
+ from pathlib import Path
8
+
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ from tabulate import tabulate
12
+
13
+ from .utils import ChdirTemporaryDirectory
14
+ from ..post import extract
15
+
16
# Grid size used for both the y and z directions (--Jy and --Jz) of every run.
JyJz = 24
# Values of Jx benchmarked: the powers of two from 2**7 up to 2**12.
DOMAIN_SIZES: tuple[int, ...] = tuple(2**exponent for exponent in range(7, 13))
# Default CSV file used to store and reload the benchmark results.
CSV_FILENAME: str = "bench_efficiency.csv"
19
+
20
+
21
def get_num_iterations(Jx: int) -> int:
    """
    Return how many time iterations to run for a domain of size ``Jx``.

    A fixed budget of ``2**17`` is divided by ``Jx`` (integer division) so that
    ``Jx * N`` stays roughly constant; at least one iteration is always run.
    """
    iterations = (1 << 17) // Jx
    return iterations if iterations > 1 else 1
24
+
25
+
26
def get_gpu_name() -> str:
    """
    Return a human-readable name of the GPU selected for the OpenCL solver.

    The device is obtained from ``get_context_and_device("gpu")``. Its
    ``board_name_amd`` attribute is preferred when available; otherwise the
    generic ``name`` attribute is used.
    """
    from llg3d.solvers.opencl import get_context_and_device

    _context, device = get_context_and_device("gpu")
    # ``board_name_amd`` is an AMD-specific, friendlier board name; devices
    # that do not expose it raise AttributeError, hence the fallback.
    try:
        label = device.board_name_amd
    except AttributeError:
        label = device.name
    return str(label)
37
+
38
+
39
+ def get_cpu_name() -> str | None:
40
+ """Get the CPU name of the current machine."""
41
+ try:
42
+ os = platform.system()
43
+ if os == "Linux":
44
+ with open("/proc/cpuinfo") as f:
45
+ for line in f:
46
+ if "model name" in line:
47
+ return line.split(":", 1)[1].strip()
48
+ if os == "Darwin":
49
+ return subprocess.check_output(
50
+ ["sysctl", "-n", "machdep.cpu.brand_string"], text=True
51
+ ).strip()
52
+ if os == "Windows":
53
+ return (
54
+ subprocess.check_output(["wmic", "cpu", "get", "Name"], text=True)
55
+ .splitlines()[1]
56
+ .strip()
57
+ )
58
+ else:
59
+ return platform.processor()
60
+ except Exception:
61
+ return None
62
+
63
+
64
def get_cpu_gpu_legend() -> str:
    """
    Get a legend string with CPU and GPU names.

    Both names are purely informative. ``get_cpu_name`` already returns
    ``None`` on failure; any exception raised by ``get_gpu_name`` (e.g. when
    no OpenCL GPU is available, as when re-plotting results on another
    machine) is likewise reported as ``None`` instead of aborting the caller.

    Returns:
        A string of the form ``"CPU: <cpu name> | GPU: <gpu name>"``.
    """
    cpu_name = get_cpu_name()
    try:
        gpu_name: str | None = get_gpu_name()
    except Exception:
        gpu_name = None
    return f"CPU: {cpu_name} | GPU: {gpu_name}"
69
+
70
+
71
+ Results = dict[str, list[float]]
72
+
73
+
74
+ def run_benchmark(mpi_nprocs: int, repeats: int = 1) -> tuple[Results, Results | None]:
75
+ """
76
+ Run the benchmark for different solvers and domain sizes.
77
+
78
+ Args:
79
+ mpi_nprocs: Number of MPI processes to use for the MPI solver.
80
+ repeats: Number of times to repeat each measurement. If >1, function
81
+ returns both means and std-devs per solver/domain.
82
+
83
+ Returns:
84
+ (means, stds): two mappings from display-name (e.g. 'NumPy (1 CPU core)')
85
+ to lists of mean and std values per domain size. If `repeats == 1`,
86
+ `stds` is None.
87
+ """
88
+ solvers = {
89
+ "numpy": "NumPy (1 CPU core)",
90
+ "mpi": f"MPI ({mpi_nprocs} CPU cores)",
91
+ "opencl": "OpenCL (1 GPU)",
92
+ }
93
+ # Use display names as keys in results for simpler downstream handling
94
+ means: Results = {display: [] for display in solvers.values()}
95
+ stds: Results | None = (
96
+ {display: [] for display in solvers.values()} if repeats > 1 else None
97
+ )
98
+
99
+ for Jx in DOMAIN_SIZES:
100
+ N = get_num_iterations(Jx)
101
+ for solver in solvers:
102
+ # collect `repeats` measurements
103
+ values: list[float] = []
104
+ for r in range(repeats):
105
+ # Move to a temporary directory to avoid overwriting files
106
+ with ChdirTemporaryDirectory():
107
+ cmd = [
108
+ "llg3d",
109
+ "--N",
110
+ str(N),
111
+ "--Jx",
112
+ str(Jx),
113
+ "--Jy",
114
+ str(JyJz),
115
+ "--Jz",
116
+ str(JyJz),
117
+ "--n_mean",
118
+ "0",
119
+ "--precision",
120
+ "single",
121
+ "--result_file",
122
+ "run.npz",
123
+ "--solver",
124
+ solver,
125
+ ]
126
+ if solver == "mpi":
127
+ cmd = ["mpirun", "-np", str(mpi_nprocs)] + cmd
128
+ print(f"Running: {' '.join(cmd)} (run {r + 1}/{repeats})")
129
+ result = subprocess.run(cmd, capture_output=True, text=True)
130
+ if result.stderr:
131
+ # show output for debugging if there's an error
132
+ print("STDOUT:", result.stdout)
133
+ print("STDERR:", result.stderr)
134
+ (eff,) = extract.extract_values(
135
+ Path("run.npz"), "results/metrics/efficiency"
136
+ )
137
+ values.append(float(eff))
138
+
139
+ display_key = solvers[solver]
140
+ mean_val = float(np.mean(values))
141
+ means[display_key].append(mean_val)
142
+ if stds is not None:
143
+ stds[display_key].append(float(np.std(values, ddof=0)))
144
+
145
+ return means, stds
146
+
147
+
148
def save_as_csv(
    results: Results,
    filename: str,
    stds: Results | None = None,
    legend: str | None = None,
) -> None:
    """
    Save the benchmark results as a CSV file.

    When ``legend`` is given, the first line is a ``#META {json}`` comment
    holding it. The header row is ``Domain size`` followed by one column per
    solver, plus a ``<solver> std`` column after each solver when ``stds`` is
    given. There is one row per entry of ``DOMAIN_SIZES``, with values in
    ``%.6e`` format. Rows are written with :mod:`csv` so that solver names
    containing commas or quotes are quoted and read back correctly by
    ``load_csv`` (which parses with ``csv.reader``).

    Args:
        results: The benchmark results.
        filename: The output CSV filename.
        stds: Optional standard deviations for the results.
        legend: An optional legend string to include as metadata.
    """
    import csv
    import json

    solver_keys = list(results)
    # If stds provided, write two columns per solver: <solver>, <solver> std
    header = ["Domain size"]
    for solver in solver_keys:
        header.append(solver)
        if stds is not None:
            header.append(f"{solver} std")

    # newline="" lets the csv module control line endings; lineterminator="\n"
    # keeps the same line format as the previous hand-written output.
    with open(filename, "w", newline="") as f:
        if legend is not None:
            f.write("#META " + json.dumps({"legend": legend}) + "\n")
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(header)
        for i, J in enumerate(DOMAIN_SIZES):
            row = [str(J)]
            for solver in solver_keys:
                row.append(f"{results[solver][i]:.6e}")
                if stds is not None:
                    row.append(f"{stds[solver][i]:.6e}")
            writer.writerow(row)
187
+
188
+
189
def results_as_table(
    results: Results, stds: Results | None = None, legend: str | None = ""
) -> str:
    """
    Render the benchmark results as a plain-text table.

    The NumPy solver (the first key starting with ``"NumPy"``, or else the
    first key) is the reference column. Every other solver column also shows
    its acceleration, i.e. reference value divided by the solver value
    (``inf`` when the solver value is 0). When ``stds`` is given, each cell
    is rendered as ``mean ± std``.

    Args:
        results: The benchmark results.
        stds: Optional standard deviations for the results.
        legend: Text printed on the line above the table (``None`` behaves
            like an empty string).

    Returns:
        The legend line followed by the table.
    """
    keys = list(results)
    reference = next((k for k in keys if k.startswith("NumPy")), keys[0])
    others = [k for k in keys if k != reference]

    def cell(name: str, i: int) -> str:
        # "mean" or "mean ± std", both in 1-digit scientific notation
        text = f"{results[name][i]:.1e}"
        if stds is not None:
            text += f" ± {stds[name][i]:.1e}"
        return text

    rows = []
    for i, size in enumerate(DOMAIN_SIZES):
        ref_value = results[reference][i]
        row = [str(size), cell(reference, i)]
        for name in others:
            value = results[name][i]
            speedup = ref_value / value if value != 0 else float("inf")
            row.append(f"{cell(name, i)} ({speedup:5.1f}x)")
        rows.append(row)

    headers = ["Domain size", reference, *(name + " (Accel)" for name in others)]
    table = tabulate(
        rows, headers=headers, tablefmt="simple", numalign="right", stralign="right"
    )
    title = "" if legend is None else legend
    return f"{title}\n{table}"
240
+
241
+
242
def plot(
    results: Results,
    show: bool = False,
    legend: str | None = None,
    errorbars: bool = False,
    stds: Results | None = None,
    filename: str = "bench_efficiency.png",
) -> None:
    """
    Plot the benchmark results and save the figure.

    One curve per solver is drawn against ``DOMAIN_SIZES`` on log axes (base 2
    on x), and the figure is saved to ``filename`` at 300 dpi. The figure is
    always closed afterwards so repeated calls do not accumulate open
    pyplot figures.

    Args:
        results: The benchmark results.
        show: Whether to display the plot interactively.
        legend: An optional legend string used as the plot subtitle; when
            ``None``, it is built from the current machine with
            ``get_cpu_gpu_legend()``.
        errorbars: Whether to plot error bars using stds if provided.
        stds: Optional standard deviations for the results.
        filename: Output image path (defaults to ``bench_efficiency.png``).
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        fig.suptitle("LLG3D efficiency vs domain size")
        markers = cycle(["o", "s", "^", "D", "v", "x", "*", "+"])
        for solver, values in results.items():
            marker = next(markers)
            if errorbars and stds is not None and solver in stds:
                ax.errorbar(
                    DOMAIN_SIZES,
                    values,
                    yerr=stds[solver],
                    marker=marker,
                    label=solver,
                    capsize=3,
                    linestyle="-",
                )
            else:
                ax.plot(DOMAIN_SIZES, values, marker=marker, label=solver)
        ax.set_xlabel(f"Domain size $(J_x \\times {JyJz} \\times {JyJz})$")
        ax.set_ylabel("Efficiency [s/iteration/point]")
        ax.set_title(
            legend if legend is not None else get_cpu_gpu_legend(),
            fontsize=8,
        )
        ax.set_xscale("log", base=2)
        ax.set_yscale("log")
        ax.legend()
        ax.grid(True, which="both", ls=":")
        fig.tight_layout()
        fig.savefig(filename, dpi=300)
        if show:
            plt.show()
    finally:
        # pyplot keeps every figure alive until explicitly closed
        plt.close(fig)
289
+
290
+
291
def load_csv(filename: str) -> tuple[Results, Results | None, dict]:
    """
    Load benchmark results from a CSV file.

    Returns results keyed by display-name (same format as run_benchmark output).

    The file may start with a ``#META {json}`` line holding metadata (such as
    the plot legend). The next non-blank line is the header: ``Domain size``
    followed by one column per solver. A solver column immediately followed
    by a column whose name ends with ``" std"`` is paired with it as its
    standard deviation. Blank lines are ignored. Only the first
    whitespace-separated token of each cell is parsed as a float, so cells
    like ``"1.0e-09 ± 2.0e-10"`` are accepted. The domain-size column itself
    is not returned.

    Args:
        filename: The CSV filename to load.

    Returns:
        A tuple (results, stds, metadata) where `results` is the benchmark
        results, `stds` are the standard deviations of the solvers that have a
        std column (None if there is none; a missing std cell reads as 0.0),
        and `metadata` is the decoded ``#META`` dictionary (``{"raw": line}``
        if it is not a valid JSON object, ``{}`` if absent). A file without a
        header row yields empty results.
    """
    import csv
    import json

    def parse_cell(cell: str) -> float:
        # Keep only the leading number, e.g. "1.0e-09 ± 2.0e-10" -> 1.0e-09
        tokens = cell.split(maxsplit=1)
        return float(tokens[0] if tokens else cell)

    metadata: dict = {}
    with open(filename, "r") as f:
        first = f.readline()
        if first.startswith("#META "):
            try:
                decoded = json.loads(first[len("#META ") :])
            except ValueError:  # json.JSONDecodeError is a ValueError
                decoded = None
            # Callers use metadata.get(...), so only accept a JSON object
            metadata = decoded if isinstance(decoded, dict) else {"raw": first.strip()}
            lines = f.readlines()
        else:
            lines = [first, *f.readlines()]

    # csv.reader yields [] for blank lines (e.g. a trailing newline): skip them
    rows = [row for row in csv.reader(lines) if row]
    results: Results = {}
    stds: Results | None = None
    if not rows:
        return results, stds, metadata

    header, data_rows = rows[0], rows[1:]
    names = header[1:]
    # One (solver name, value column, std column or None) entry per solver
    layout: list[tuple[str, int, int | None]] = []
    i = 0
    while i < len(names):
        name = names[i]
        results[name] = []
        if i + 1 < len(names) and names[i + 1].strip().endswith(" std"):
            if stds is None:
                stds = {}
            stds[name] = []
            layout.append((name, 1 + i, 2 + i))
            i += 2
        else:
            layout.append((name, 1 + i, None))
            i += 1

    for row in data_rows:
        for name, value_col, std_col in layout:
            results[name].append(parse_cell(row[value_col]))
            if std_col is not None and stds is not None:
                # missing std value, append 0
                std = parse_cell(row[std_col]) if std_col < len(row) else 0.0
                stds[name].append(std)
    return results, stds, metadata
371
+
372
+
373
def main():
    """
    Command-line entry point of the efficiency benchmark.

    Sub-commands:

    * ``run``: run the benchmark, save it as CSV and print the results table;
    * ``plot``: plot the results stored in a CSV file;
    * ``report``: print the results table stored in a CSV file.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    commands = parser.add_subparsers(dest="mode", required=True)
    show_defaults = argparse.ArgumentDefaultsHelpFormatter

    run_cmd = commands.add_parser(
        "run",
        help="Run the benchmark and save the CSV file",
        formatter_class=show_defaults,
    )
    run_cmd.add_argument(
        "--mpi-nprocs", type=int, default=32, help="Number of MPI processes"
    )
    run_cmd.add_argument(
        "--repeats",
        type=int,
        default=1,
        help="Number of repetitions per measurement",
    )
    run_cmd.add_argument(
        "--csv", type=str, default=CSV_FILENAME, help="CSV file to save"
    )

    plot_cmd = commands.add_parser(
        "plot",
        help="Plot the graph from the CSV file",
        formatter_class=show_defaults,
    )
    plot_cmd.add_argument(
        "--csv", type=str, default=CSV_FILENAME, help="CSV file to load"
    )
    plot_cmd.add_argument(
        "--show",
        action="store_true",
        default=False,
        help="Show the plot interactively",
    )
    plot_cmd.add_argument(
        "--errorbars",
        action="store_true",
        default=False,
        help="Plot error bars using std columns if present in the CSV",
    )

    report_cmd = commands.add_parser(
        "report",
        help="Print a results table from a CSV file",
        formatter_class=show_defaults,
    )
    report_cmd.add_argument(
        "--csv", type=str, default=CSV_FILENAME, help="CSV file to load"
    )
    report_cmd.add_argument(
        "--show-std",
        action="store_true",
        default=False,
        help="Include std deviations in the report if present in the CSV",
    )

    args = parser.parse_args()

    if args.mode == "run":
        results, stds = run_benchmark(args.mpi_nprocs, repeats=args.repeats)
        legend = get_cpu_gpu_legend()
        save_as_csv(results, filename=args.csv, stds=stds, legend=legend)
        print(results_as_table(results, stds=stds, legend=legend))
        return

    # "plot" and "report" both start from a previously saved CSV file
    results, stds, metadata = load_csv(args.csv)
    legend = metadata.get("legend") if metadata else None
    if args.mode == "plot":
        plot(
            results, show=args.show, legend=legend, errorbars=args.errorbars, stds=stds
        )
    else:
        shown_stds = stds if args.show_std else None
        print(results_as_table(results, stds=shown_stds, legend=legend))
@@ -0,0 +1,25 @@
1
+ """Utilities for benchmarks."""
2
+
3
+ from pathlib import Path
4
+ import os
5
+ import tempfile
6
+
7
+
8
class ChdirTemporaryDirectory(tempfile.TemporaryDirectory):
    """
    Context manager to create and change to a temporary directory.

    ``with ChdirTemporaryDirectory() as path:`` creates the directory, makes
    it the current working directory and binds its :class:`~pathlib.Path`.
    On exit, returns to the previous working directory and deletes the
    temporary directory. Cleanup also happens when changing directory fails.
    """

    def __enter__(self) -> Path:
        """Create the temporary directory, change into it and return its path."""
        self._prev_cwd = os.getcwd()
        path = super().__enter__()
        try:
            os.chdir(path)
        except BaseException:
            # __exit__ is not called when __enter__ raises: clean up here
            self.cleanup()
            raise
        return Path(path)

    def __exit__(self, exc_type, exc_value, traceback):
        """Return to the previous directory, then remove the temporary one."""
        try:
            os.chdir(self._prev_cwd)
        finally:
            # Remove the directory even if the previous cwd is gone
            suppress = super().__exit__(exc_type, exc_value, traceback)
        return suppress
llg3d/element.py CHANGED
@@ -1,6 +1,7 @@
1
- """Module containing the definition of the chemical elements."""
1
+ """Define the chemical elements."""
2
2
 
3
3
  from abc import ABC
4
+ from typing import Literal
4
5
 
5
6
  import numpy as np
6
7
 
@@ -20,20 +21,29 @@ class Element(ABC):
20
21
  H_ext: External magnetic field strength
21
22
  g: Grid object representing the simulation grid
22
23
  dt: Time step for the simulation
24
+ dtype: Optional numpy dtype to cast scalar coefficients to
23
25
  """
24
26
 
25
- A = 0.0
26
- K = 0.0
27
- lambda_G = 0.0
28
- M_s = 0.0
29
- a_eff = 0.0
30
- anisotropy: str = ""
31
-
32
- def __init__(self, T: float, H_ext: float, g: Grid, dt: float) -> None:
27
+ A: float = 0.0 #: Exchange constant :math:`[J.m^{-1}]`
28
+ K: float = 0.0 #: Anisotropy constant :math:`[J.m^{-3}]`
29
+ lambda_G: float = 0.0 #: Damping parameter :math:`[1]`
30
+ M_s: float = 0.0 #: Saturation magnetization :math:`[A.m^{-1}]`
31
+ a_eff: float = 0.0 #: Effective lattice constant :math:`[m]`
32
+ anisotropy: Literal["uniaxial", "cubic"] #: Type of anisotropy
33
+
34
+ def __init__(
35
+ self,
36
+ T: float,
37
+ H_ext: float,
38
+ g: Grid,
39
+ dt: float,
40
+ dtype: np.dtype | None = None,
41
+ ) -> None:
33
42
  self.H_ext = H_ext
34
43
  self.g = g
35
44
  self.dt = dt
36
45
  self.gamma_0 = gamma * mu_0 #: Rescaled gyromagnetic ratio [mA^-1.s^-1]
46
+ self.d_0 = np.sqrt(self.A / np.abs(self.K)) #: Domain wall width [m]
37
47
 
38
48
  # --- Characteristic Scales ---
39
49
  self.coeff_1 = self.gamma_0 * 2.0 * self.A / (mu_0 * self.M_s)
@@ -44,16 +54,40 @@ class Element(ABC):
44
54
  T_simu = T * self.g.dx / self.a_eff
45
55
  # calculation of the random field related to temperature
46
56
  # (we only take the volume over one mesh)
47
- h_alea = np.sqrt(
57
+ h_random = np.sqrt(
48
58
  2 * self.lambda_G * k_B / (self.gamma_0 * mu_0 * self.M_s * self.g.dV)
49
59
  )
50
- H_alea = h_alea * np.sqrt(T_simu) * np.sqrt(1.0 / self.dt)
51
- self.coeff_4 = H_alea * self.gamma_0
60
+ H_random = h_random * np.sqrt(T_simu) * np.sqrt(1.0 / self.dt)
61
+ self.coeff_4 = H_random * self.gamma_0
52
62
 
53
- def get_CFL(self) -> float:
63
+ # If a dtype is provided, cast scalar coefficients to that dtype
64
+ if dtype is not None:
65
+ self._cast_to_dtype(dtype)
66
+
67
def _cast_to_dtype(self, dtype: np.dtype):
    """
    Cast all float attributes of the instance to the specified numpy dtype.

    Candidates are every instance attribute plus every class attribute whose
    name does not start with ``__``, taken from the whole MRO. This ensures
    that float defaults inherited from a base class are cast too, not only
    those defined on the concrete class. Each value that is a Python
    ``float`` or ``np.floating`` is replaced *on the instance* by
    ``np.asarray(value, dtype=dtype)`` (a 0-d array); class attributes
    themselves are left untouched.

    Args:
        dtype: The numpy dtype to cast to (e.g., np.float32, np.float64)
    """
    # Instance attributes, then non-dunder class attributes along the MRO
    names = set(vars(self))
    for klass in type(self).__mro__:
        names.update(k for k in vars(klass) if not k.startswith("__"))
    for nm in names:
        val = getattr(self, nm)
        if isinstance(val, (float, np.floating)):
            setattr(self, nm, np.asarray(val, dtype=dtype))
82
+
83
+ def get_CFL(self) -> float:
84
+ r"""
55
85
  Returns the value of the CFL.
56
86
 
87
+ .. math::
88
+
89
+ CFL = \frac{dt \cdot 2 \gamma_0 A}{\mu_0 M_s dx^2}
90
+
57
91
  Returns:
58
92
  The CFL value
59
93
  """
@@ -79,6 +113,50 @@ class Element(ABC):
79
113
  "gamma_0": self.gamma_0,
80
114
  }
81
115
 
116
+ def compute_H_anisotropy(self, m: np.ndarray, H_aniso: np.ndarray):
117
+ r"""
118
+ Compute the anisotropy field.
119
+
120
+ For uniaxial anisotropy:
121
+
122
+ .. math::
123
+
124
+ \boldsymbol{H}_{\text{ani, uniaxial}}=\frac{2K}{\mu_0M_s}(\boldsymbol{e}_x
125
+ \cdot\boldsymbol{m})\boldsymbol{e}_x,\label{uniaxial}
126
+
127
+
128
+ For cubic anisotropy:
129
+
130
+ .. math::
131
+
132
+ \boldsymbol{H}_{\text{ani, cubic}}=-\frac{2K}{\mu_0M_s}\sum_{(i,j,k)\in I}
133
+ \left((\boldsymbol{e}_j\cdot\boldsymbol{m})^2+(\boldsymbol{e}_k\cdot
134
+ \boldsymbol{m})^2+(\boldsymbol{e}_j\cdot\boldsymbol{m})^2(\boldsymbol{e}_k
135
+ \cdot\boldsymbol{m})^2\right)(\boldsymbol{e}_i\cdot\boldsymbol{m})\boldsymbol{e}_i,\label{cubic}
136
+
137
+
138
+ Args:
139
+ m: Magnetization array (shape (3, nx, ny, nz)).
140
+ H_aniso: Pre-allocated output array (shape (3, nx, ny, nz)).
141
+
142
+ Raises:
143
+ ValueError: If the anisotropy type is unknown
144
+ """
145
+ if self.anisotropy == "uniaxial":
146
+ H_aniso[0] = self.coeff_2 * m[0]
147
+ H_aniso[1] = 0.0
148
+ H_aniso[2] = 0.0
149
+ elif self.anisotropy == "cubic":
150
+ m1, m2, m3 = m
151
+ m1m1 = m1 * m1
152
+ m2m2 = m2 * m2
153
+ m3m3 = m3 * m3
154
+ H_aniso[0] = -self.coeff_2 * (1 - m1m1 + m2m2 * m3m3) * m1
155
+ H_aniso[1] = -self.coeff_2 * (1 - m2m2 + m1m1 * m3m3) * m2
156
+ H_aniso[2] = -self.coeff_2 * (1 - m3m3 + m1m1 * m2m2) * m3
157
+ else:
158
+ raise ValueError(f"Unknown anisotropy type: {self.anisotropy}")
159
+
82
160
 
83
161
  class Cobalt(Element):
84
162
  """Cobalt element."""
@@ -113,12 +191,17 @@ class Nickel(Element):
113
191
  anisotropy = "cubic" #: Type of anisotropy (e.g., "uniaxial", "cubic")
114
192
 
115
193
 
116
- def get_element_class(element_name: str | type[Element]) -> type[Element]:
194
+ def get_element_class(element_name: str) -> type[Element]:
117
195
  """
118
196
  Get the class of the chemical element by its name.
119
197
 
198
+ Example:
199
+ >>> cls = get_element_class("Cobalt")
200
+ >>> cls.__name__
201
+ 'Cobalt'
202
+
120
203
  Args:
121
- element_name: The name of the element or its class
204
+ element_name: The name of the element
122
205
 
123
206
  Returns:
124
207
  The class of the element
@@ -126,8 +209,6 @@ def get_element_class(element_name: str | type[Element]) -> type[Element]:
126
209
  Raises:
127
210
  ValueError: If the element is not found
128
211
  """
129
- if isinstance(element_name, type):
130
- return element_name
131
212
  for cls in Element.__subclasses__():
132
213
  if cls.__name__ == element_name:
133
214
  return cls