llg3d 2.0.1__py3-none-any.whl → 3.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. llg3d/__init__.py +2 -4
  2. llg3d/benchmarks/__init__.py +1 -0
  3. llg3d/benchmarks/compare_commits.py +321 -0
  4. llg3d/benchmarks/efficiency.py +451 -0
  5. llg3d/benchmarks/utils.py +25 -0
  6. llg3d/element.py +98 -17
  7. llg3d/grid.py +48 -58
  8. llg3d/io.py +395 -0
  9. llg3d/main.py +32 -35
  10. llg3d/parameters.py +159 -49
  11. llg3d/post/__init__.py +1 -1
  12. llg3d/post/extract.py +112 -0
  13. llg3d/post/info.py +192 -0
  14. llg3d/post/m1_vs_T.py +107 -0
  15. llg3d/post/m1_vs_time.py +81 -0
  16. llg3d/post/process.py +87 -85
  17. llg3d/post/utils.py +38 -0
  18. llg3d/post/x_profiles.py +161 -0
  19. llg3d/py.typed +1 -0
  20. llg3d/solvers/__init__.py +153 -0
  21. llg3d/solvers/base.py +345 -0
  22. llg3d/solvers/experimental/__init__.py +9 -0
  23. llg3d/{solver → solvers/experimental}/jax.py +117 -143
  24. llg3d/solvers/math_utils.py +41 -0
  25. llg3d/solvers/mpi.py +370 -0
  26. llg3d/solvers/numpy.py +126 -0
  27. llg3d/solvers/opencl.py +439 -0
  28. llg3d/solvers/profiling.py +38 -0
  29. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/METADATA +5 -2
  30. llg3d-3.1.0.dist-info/RECORD +36 -0
  31. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/WHEEL +1 -1
  32. llg3d-3.1.0.dist-info/entry_points.txt +9 -0
  33. llg3d/output.py +0 -107
  34. llg3d/post/plot_results.py +0 -61
  35. llg3d/post/temperature.py +0 -76
  36. llg3d/simulation.py +0 -95
  37. llg3d/solver/__init__.py +0 -45
  38. llg3d/solver/mpi.py +0 -450
  39. llg3d/solver/numpy.py +0 -207
  40. llg3d/solver/opencl.py +0 -330
  41. llg3d/solver/solver.py +0 -89
  42. llg3d-2.0.1.dist-info/RECORD +0 -25
  43. llg3d-2.0.1.dist-info/entry_points.txt +0 -4
  44. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/licenses/AUTHORS +0 -0
  45. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/licenses/LICENSE +0 -0
  46. {llg3d-2.0.1.dist-info → llg3d-3.1.0.dist-info}/top_level.txt +0 -0
llg3d/grid.py CHANGED
@@ -1,9 +1,10 @@
-"""Module to define the grid for the simulation."""
+"""Define the computational grid for the simulation."""

-from dataclasses import dataclass
+from typing import ClassVar
+from dataclasses import dataclass, field, asdict
 import numpy as np

-from .solver import rank, size
+from . import solvers


 @dataclass
@@ -15,109 +16,98 @@ class Grid:
     Jy: int  #: number of points in y direction
     Jz: int  #: number of points in z direction
     dx: float  #: grid spacing in x direction
+    dy: float = field(init=False)  #: grid spacing in y direction
+    dz: float = field(init=False)  #: grid spacing in z direction
+    dV: float = field(init=False)  #: elemental volume
+    Lx: float = field(init=False)  #: physical length in x direction
+    Ly: float = field(init=False)  #: physical length in y direction
+    Lz: float = field(init=False)  #: physical length in z direction
+    dims: tuple[int, int, int] = field(init=False)  #: local grid dimensions
+    V: float = field(init=False)  #: total volume
+    ntot: int = field(init=False)  #: total number of points
+    ncell: int = field(init=False)  #: total number of cells
+    inv_dx2: float = field(init=False)  #: :math:`1/dx^2`
+    inv_dy2: float = field(init=False)  #: :math:`1/dy^2`
+    inv_dz2: float = field(init=False)  #: :math:`1/dz^2`
+    center_coeff: float = field(init=False)  #: center coefficient for Laplacian
+    uniform: ClassVar[bool] = True  #: whether the grid is uniform

     def __post_init__(self) -> None:
         """Compute grid characteristics."""
-        self.dy = self.dz = self.dx  # Setting dx = dy = dz
+        self.dy = self.dz = self.dx  # Enforce dx = dy = dz
         self.Lx = (self.Jx - 1) * self.dx
         self.Ly = (self.Jy - 1) * self.dy
         self.Lz = (self.Jz - 1) * self.dz
-        # shape of the local array to the process
-        self.dims = self.Jx // size, self.Jy, self.Jz
-        # elemental volume of a cell
+        self.dims = self.Jx // solvers.size, self.Jy, self.Jz
         self.dV = self.dx * self.dy * self.dz
-        # total volume
         self.V = self.Lx * self.Ly * self.Lz
-        # total number of points
         self.ntot = self.Jx * self.Jy * self.Jz
         self.ncell = (self.Jx - 1) * (self.Jy - 1) * (self.Jz - 1)
+        # precompute the Laplacian coefficients (uniform grid spacing)
+        self.inv_dx2 = self.inv_dy2 = self.inv_dz2 = 1 / self.dx**2
+        self.center_coeff = -6.0 * self.inv_dx2

     def __str__(self) -> str:
-        """Print grid information."""
+        """Return grid information."""
         header = "\t\t".join(("x", "y", "z"))
         s = f"""\
+---
 \t{header}
 J\t{self.Jx}\t\t{self.Jy}\t\t{self.Jz}
 L\t{self.Lx:.08e}\t{self.Ly:.08e}\t{self.Lz:.08e}
 d\t{self.dx:.08e}\t{self.dy:.08e}\t{self.dz:.08e}
-
+---
 dV = {self.dV:.08e}
 V = {self.V:.08e}
 ntot = {self.ntot:d}
 ncell = {self.ncell:d}
-"""
+---"""
         return s

-    def get_filename(
-        self, T: float, name: str = "m1_mean", extension: str = "txt"
-    ) -> str:
+    def get_x_coords(
+        self, local: bool = True, dtype: np.dtype = np.dtype(np.float64)
+    ) -> np.ndarray:
         """
-        Returns the output file name for a given temperature.
+        Returns the x coordinates.

         Args:
-            T: temperature
-            name: file name
-            extension: file extension
+            local: if True, returns the local coordinates,
+                otherwise the global coordinates
+            dtype: data type of the coordinates

         Returns:
-            file name
-
-        >>> g = Grid(Jx=300, Jy=21, Jz=21, dx=1.e-9)
-        >>> g.get_filename(1100)
-        'm1_mean_T1100_300x21x21.txt'
+            1D array with the x coordinates
         """
-        suffix = f"T{int(T)}_{self.Jx}x{self.Jy}x{self.Jz}"
-        return f"{name}_{suffix}.{extension}"
+        x_global = np.linspace(0, self.Lx, self.Jx, dtype=dtype)  # global coordinates
+        # Split x into local parts if needed
+        return x_global if not local else np.split(x_global, solvers.size)[solvers.rank]

     def get_mesh(
-        self, local: bool = True, dtype=np.float64
+        self, local: bool = True, dtype: np.dtype = np.dtype(np.float64)
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Returns a meshgrid of the coordinates.

+        Use ij indexing.
+
         Args:
             local: if True, returns the local coordinates,
                 otherwise the global coordinates
-            dtype: data type of the coordinates)
+            dtype: data type of the coordinates

         Returns:
-            tuple of 3D arrays with the coordinates
+            Tuple of 3D arrays with the coordinates
         """
-        x_global = np.linspace(0, self.Lx, self.Jx, dtype=dtype)  # global coordinates
+        x = self.get_x_coords(local=local, dtype=dtype)
         y = np.linspace(0, self.Ly, self.Jy, dtype=dtype)
         z = np.linspace(0, self.Lz, self.Jz, dtype=dtype)
-        if local:
-            x_local = np.split(x_global, size)[rank]  # local coordinates
-            return np.meshgrid(x_local, y, z, indexing="ij")
-        else:
-            return np.meshgrid(x_global, y, z, indexing="ij")
+        return np.meshgrid(x, y, z, indexing="ij")

-    def to_dict(self) -> dict:
+    def as_dict(self) -> dict:
         """
         Export grid parameters to a dictionary for JAX JIT compatibility.

         Returns:
             Dictionary containing grid parameters needed for computations
         """
-        return {
-            "dx": self.dx,
-            "dy": self.dy,
-            "dz": self.dz,
-            "Jx": self.Jx,
-            "Jy": self.Jy,
-            "Jz": self.Jz,
-            "dV": self.dV,
-        }
-
-    def get_laplacian_coeff(self) -> tuple[float, float, float, float]:
-        """
-        Returns the coefficients for the laplacian computation.
-
-        Returns:
-            Tuple of coefficients (dx2_inv, dy2_inv, dz2_inv, center_coeff)
-        """
-        dx2_inv = 1 / self.dx**2
-        dy2_inv = 1 / self.dy**2
-        dz2_inv = 1 / self.dz**2
-        center_coeff = -2 * (dx2_inv + dy2_inv + dz2_inv)
-        return dx2_inv, dy2_inv, dz2_inv, center_coeff
+        return asdict(self)
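Note on the Laplacian refactor above: the removed get_laplacian_coeff() returned the center coefficient -2 * (dx2_inv + dy2_inv + dz2_inv), which under the uniform spacing the class enforces (dx = dy = dz) equals the -6.0 * self.inv_dx2 now cached in __post_init__. Below is a minimal sketch of how a solver might consume the cached attributes in a 7-point stencil; the laplacian helper is illustrative only, not a function shipped by the package.

import numpy as np

from llg3d.grid import Grid


def laplacian(m: np.ndarray, g: Grid) -> np.ndarray:
    """Illustrative 7-point stencil using Grid's precomputed coefficients."""
    lap = g.center_coeff * m[1:-1, 1:-1, 1:-1]  # center: -6 / dx**2
    lap += g.inv_dx2 * (m[2:, 1:-1, 1:-1] + m[:-2, 1:-1, 1:-1])  # x neighbors
    lap += g.inv_dy2 * (m[1:-1, 2:, 1:-1] + m[1:-1, :-2, 1:-1])  # y neighbors
    lap += g.inv_dz2 * (m[1:-1, 1:-1, 2:] + m[1:-1, 1:-1, :-2])  # z neighbors
    return lap


g = Grid(Jx=300, Jy=21, Jz=21, dx=1.0e-9)  # same grid as the removed doctest
m = np.ones(g.dims)
assert np.allclose(laplacian(m, g), 0.0)  # a constant field has zero Laplacian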
llg3d/io.py ADDED
@@ -0,0 +1,395 @@
+"""
+Input/Output functions.
+
+Extensibility Guide
+===================
+
+The I/O system is designed to be fully extensible. You can add arbitrary fields
+to `metrics` and `records` without modifying this module.
+
+Adding custom metrics
+---------------------
+
+>>> solver.metrics['custom_value'] = 42.0
+>>> solver.metrics['convergence_rate'] = 0.95
+
+Adding custom records
+---------------------
+
+>>> # Simple array
+>>> solver.records['energy_evolution'] = np.array([...])
+
+>>> # Nested structure (automatically flattened in .npz)
+>>> solver.records['field_snapshots'] = {
+...     'times': np.array([0, 1, 2]),
+...     'data': np.array([...])
+... }
+
+>>> # Deeply nested (unlimited depth)
+>>> solver.records['analysis'] = {
+...     'spectra': {
+...         'fourier': np.array([...]),
+...         'wavelets': np.array([...])
+...     }
+... }
+
+TypedDicts (`Metrics`, `Records`, etc.) are documentation only. Runtime accepts
+any dict[str, Any] structure.
+"""
+
+from __future__ import annotations
+
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import (
+    Any,
+    NotRequired,
+    TextIO,
+    TypedDict,
+    cast,
+)
+
+import numpy as np
+
+from .parameters import RunParameters
+from .solvers import rank, size
+from .solvers.profiling import ProfilingStats
+
+
+def get_tqdm_file() -> TextIO | None:
+    """Get a TQDM-compatible file for progress bar output in MPI."""
+    if size == 1:
+        return sys.stdout  # single process uses default
+    if rank != 0:
+        return None  # other ranks disable progress bar
+    try:
+        # For MPI, try to open /dev/tty for direct terminal output
+        return open("/dev/tty", "w")
+    except OSError:
+        return sys.stdout  # fallback
+
+
+class MetricsRequired(TypedDict):
+    """Required fields for simulation metrics."""
+
+    total_time: float  #: Total wall-clock time of the simulation
+    time_per_ite: float  #: Average time per iteration
+    efficiency: float  #: Time per iteration per cell
+    CFL: float  #: CFL condition value
+
+
+class Metrics(MetricsRequired, total=False):
+    """
+    Structure for simulation metrics (extensible).
+
+    Required fields: total_time, time_per_ite, efficiency, CFL
+    All other fields are optional and can be added dynamically.
+    """
+
+    #: Profiling statistics
+    profiling_stats: NotRequired[ProfilingStats]
+
+
+class Observables(TypedDict, total=False):
+    """
+    Physical observables from the simulation (extensible).
+
+    All fields are optional and can be added dynamically.
+    """
+
+    m1_mean: float  #: Time-averaged magnetization in x direction
+
+
+class XProfiles(TypedDict):
+    """Structure for cross-sectional profiles (arrays in final records)."""
+
+    t: np.ndarray  #: Time points for profiles
+    m1: np.ndarray  #: m1 component profiles
+    m2: np.ndarray  #: m2 component profiles
+    m3: np.ndarray  #: m3 component profiles
+
+
+class XProfilesBuffer(TypedDict):
+    """Structure for cross-sectional profiles during accumulation (lists)."""
+
+    t: list[float]  #: Time points for profiles
+    m1: list[np.ndarray]  #: m1 component profiles
+    m2: list[np.ndarray]  #: m2 component profiles
+    m3: list[np.ndarray]  #: m3 component profiles
+
+
+class RecordsBuffer(TypedDict, total=False):
+    """
+    Records during simulation (accumulation phase with lists).
+
+    BaseSolver.records uses this structure during simulation, accumulating
+    data as lists. Before saving, BaseSolver.save() converts lists to arrays.
+    """
+
+    xyz_average: list[tuple[float, float]]  #: Accumulated (t, value) pairs
+    x_profiles: XProfilesBuffer  #: Profiles during accumulation
+
+
+class Records(TypedDict, total=False):
+    """
+    Records after saving (finalized with numpy arrays).
+
+    Returned by load_results() with all data as read-only numpy arrays.
+    """
+
+    xyz_average: np.ndarray  #: Space-averaged magnetization over time (shape: (2, N))
+    x_profiles: XProfiles  #: Cross-sectional profiles in yz plane
+
+
+class SimulationResults(TypedDict):
+    """Structure for simulation results."""
+
+    metrics: Metrics  #: Simulation metrics (performance, numerical quality)
+    observables: NotRequired[Observables]  #: Physical observables (results)
+    records: NotRequired[Records]  #: Optional time-series records
+
+
+@dataclass
+class RunResults:
+    """Structure for loaded simulation results."""
+
+    params: RunParameters
+    results: SimulationResults
+    file: str | Path
+
+    def get_record(self, record_name: str) -> np.ndarray | dict[str, np.ndarray]:
+        """Return a record by name from results or raise a descriptive error."""
+        try:
+            if "records" not in self.results:
+                raise KeyError(f"RunResults from '{self.file}' has no records.")
+
+            records = self.results["records"]
+            # Cast to dict to allow dynamic access with a variable key for mypy
+            record = cast(dict, records)[record_name]
+            if isinstance(record, (np.ndarray, dict)):
+                return record
+            else:
+                raise TypeError(f"'{record_name}' is not in the expected format.")
+        except KeyError as e:
+            msg = (
+                f"RunResults from '{self.file}' does not contain the required record "
+                f"'{record_name}'."
+            )
+            raise KeyError(msg) from e
+
+
+def save_results(
+    output_file: str | Path,
+    params: RunParameters,
+    metrics: Metrics,
+    observables: Observables | None = None,
+    records_buffer: RecordsBuffer | None = None,
+) -> None:
+    """
+    Saves simulation results to a .npz file with hierarchical structure.
+
+    Args:
+        output_file: Path to the output .npz file.
+        params: Dataclass of simulation parameters.
+        metrics: Dictionary of simulation metrics. Must contain required fields:
+            total_time, time_per_ite, CFL. Additional fields are allowed.
+        observables: Dictionary of physical observables (results).
+        records_buffer: RecordsBuffer or custom dict of records
+            (accumulation phase with lists/dicts).
+
+    Note:
+        Both metrics and records accept any key-value pairs, allowing easy
+        extension without modifying this function. See Metrics and Records
+        TypedDict for recommended structure.
+
+    Example:
+        >>> records_buffer = {
+        ...     'xyz_average': arr1,
+        ...     'x_profiles': {'t': t, 'm1': m1, ...},
+        ...     'custom_data': arr2  # Any new field works automatically
+        ... }
+        >>> save_results('run.npz', params, metrics, records_buffer=records_buffer)
+    """
+    data_to_save: dict[str, Any] = {}
+
+    if records_buffer is None:
+        records_buffer = {}
+    if observables is None:
+        observables = {}
+
+    # Helper to convert lists to numpy arrays recursively
+    def _convert_lists_to_arrays(d: Any) -> Any:
+        """Recursively convert lists to numpy arrays."""
+        if isinstance(d, list):
+            return np.array(d)
+        elif isinstance(d, dict):
+            return {k: _convert_lists_to_arrays(v) for k, v in d.items()}
+        else:
+            return d
+
+    # Convert lists in records_buffer to arrays
+    records_buffer = _convert_lists_to_arrays(records_buffer)
+
+    # Helper to flatten nested dicts into NPZ keys
+    def _flatten_dict(d: dict | Any, prefix: str):
+        """Recursively flatten nested dicts into NPZ keys."""
+        if isinstance(d, dict):
+            for key, value in d.items():
+                full_key = f"{prefix}/{key}"
+                _flatten_dict(value, full_key)
+        else:
+            # Leaf value - save as array
+            data_to_save[prefix] = d
+
+    # Flatten params
+    _flatten_dict(params.as_dict(), "params")
+
+    # Flatten metrics
+    _flatten_dict(metrics, "results/metrics")
+
+    # Flatten observables
+    if observables:
+        _flatten_dict(observables, "results/observables")
+
+    # Flatten records
+    _flatten_dict(records_buffer, "results/records")
+
+    np.savez(output_file, **data_to_save, allow_pickle=False)
+
+
+def load_results(result_file: str | Path) -> RunResults:
+    """
+    Loads simulation results from a .npz file with hierarchical structure.
+
+    Numpy arrays are set to read-only to prevent accidental modification.
+
+    Example:
+        >>> run_results = load_results("run.npz")
+        >>> params = run_results["params"]
+        >>> metrics = run_results["results"]["metrics"]
+        >>> xyz_average = run_results["results"]["records"]["xyz_average"]
+        >>> m1_prof = run_results["results"]["records"]["x_profiles"]["m1"]
+
+    Args:
+        result_file: path to the .npz file.
+
+    Returns:
+        Dictionary containing the loaded data with hierarchical structure.
+    """
+    npz_data: np.lib.npyio.NpzFile = np.load(result_file, allow_pickle=False)
+
+    def _unflatten_dict(prefix: str) -> dict[str, Any]:
+        """Reconstruct nested dict from flat NPZ keys with given prefix."""
+        result: dict[str, Any] = {}
+
+        for key in npz_data.files:
+            if not key.startswith(prefix + "/"):
+                continue
+
+            # Extract relative path after prefix
+            rel_path = key[len(prefix) + 1 :]
+            parts = rel_path.split("/")
+
+            # Navigate/create nested structure
+            current = result
+            for part in parts[:-1]:
+                if part not in current:
+                    current[part] = {}
+                current = current[part]
+
+            # Store the value
+            arr = npz_data[key]
+            arr.flags.writeable = False
+            current[parts[-1]] = arr
+
+        return result
+
+    # Load params
+    params_dict = _unflatten_dict("params")
+    params = RunParameters(**params_dict)
+
+    # Load metrics
+    metrics: Metrics = cast(Metrics, _unflatten_dict("results/metrics"))
+
+    # Load observables
+    observables: Observables | None = None
+    observables_dict = _unflatten_dict("results/observables")
+    if observables_dict:
+        observables = cast(Observables, observables_dict)
+
+    # Load records
+    records: Records = cast(Records, _unflatten_dict("results/records"))
+
+    # Build the results dict
+    simulation_results: SimulationResults = {"metrics": metrics}
+    if observables:
+        simulation_results["observables"] = observables
+    if records:
+        simulation_results["records"] = records
+
+    run_results = RunResults(
+        params=params, results=simulation_results, file=result_file
+    )
+
+    return run_results
+
+
+def format_profiling_table(
+    profiling_dict: dict[str, dict[str, float | int]], total_time: float | None = None
+) -> str:
+    """
+    Format profiling statistics as a table.
+
+    Args:
+        profiling_dict: Dictionary with profiling stats, where each entry has
+            'time' (float) and 'calls' (int) keys
+        total_time: Total simulation time for percentage calculation.
+            If None, the percentage column is omitted.
+
+    Returns:
+        A formatted table string
+    """
+    if not profiling_dict:
+        return "(empty)"
+
+    # Determine the maximum width for alignment
+    col1_width = max(len(name) for name in profiling_dict.keys())
+
+    # Header
+    if total_time is not None:
+        s = (
+            f"{'Function':<{col1_width}} | {'Calls':>5} | {'total_time (s)':>14} "
+            f"| {'%':>6} | {'Avg Time (s)':>13}\n"
+        )
+    else:
+        s = (
+            f"{'Function':<{col1_width}} | {'Calls':>5} | {'total_time (s)':>14} "
+            f"| {'Avg Time (s)':>13}\n"
+        )
+    s += "-" * (len(s) - 1) + "\n"
+
+    # Sort by total time descending
+    sorted_profiling = dict(
+        sorted(profiling_dict.items(), key=lambda item: item[1]["time"], reverse=True)
+    )
+
+    # Data rows
+    for name, stats in sorted_profiling.items():
+        func_time = stats["time"]
+        calls = stats["calls"]
+        avg_time = func_time / calls if calls > 0 else 0.0
+
+        if total_time is not None:
+            percent = (func_time / total_time * 100) if total_time > 0 else 0.0
+            s += (
+                f"{name:<{col1_width}} | {calls:>5} | {func_time:>14.6f} "
+                f"| {percent:>5.1f}% | {avg_time:>13.6f}\n"
+            )
+        else:
+            s += (
+                f"{name:<{col1_width}} | {calls:>5} | {func_time:>14.6f} "
+                f"| {avg_time:>13.6f}\n"
+            )

+    return s
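The save/load pair above relies on a flatten/unflatten convention: nested dicts become slash-separated NPZ keys such as results/records/x_profiles/m1. A standalone sketch of that convention, mirroring _flatten_dict with placeholder data (runs without llg3d installed):

import numpy as np

# Placeholder nested results, shaped like the structures save_results() receives
nested = {
    "metrics": {"CFL": 0.4, "total_time": 12.5},
    "records": {"x_profiles": {"t": np.arange(3), "m1": np.zeros((3, 300))}},
}

flat = {}


def flatten(d, prefix):
    """Mirror of _flatten_dict: leaves are stored under slash-joined keys."""
    if isinstance(d, dict):
        for key, value in d.items():
            flatten(value, f"{prefix}/{key}")
    else:
        flat[prefix] = d


flatten(nested, "results")
np.savez("demo.npz", **flat)

data = np.load("demo.npz")
print(sorted(data.files))
# ['results/metrics/CFL', 'results/metrics/total_time',
#  'results/records/x_profiles/m1', 'results/records/x_profiles/t']

Keeping every leaf as a plain array is what lets load_results() pass allow_pickle=False and rebuild the hierarchy by splitting keys, as _unflatten_dict does.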
llg3d/main.py CHANGED
@@ -1,67 +1,64 @@
-"""Define a CLI for the llg3d package."""
+"""Define a CLI for running LLG3D simulations."""

 import argparse

-
-from . import rank, size, LIB_AVAILABLE
-from .parameters import parameters, get_parameter_list
-from .simulation import Simulation
-
-if LIB_AVAILABLE["mpi4py"]:
-    # Use the MPI version of the ArgumentParser
-    from .solver.mpi import ArgumentParser
-else:
-    # Use the original version of the ArgumentParser
-    from argparse import ArgumentParser
+from .parameters import arg_parameters
+from . import solvers


 def parse_args(args: list[str] | None) -> argparse.Namespace:
     """
-    Argument parser for llg3d.
+    Argument parser for LLG3D.

-    Automatically adds arguments from the parameter dictionary.
+    Automatically adds arguments from :class:`~llg3d.parameters.arg_parameters`.
+
+    Args:
+        args: List of command line arguments

     Returns:
         argparse.Namespace: Parsed arguments
     """
+    ArgumentParser: type[argparse.ArgumentParser]
+    if solvers.size > 1:
+        # Use the MPI version of the ArgumentParser
+        from .solvers.mpi import ArgumentParser
+
+    else:
+        # Use the original version of the ArgumentParser
+        from argparse import ArgumentParser
+
     parser = ArgumentParser(
         description=__doc__,
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )

-    if size > 1:
-        parameters["solver"]["default"] = "mpi"
-
     # Automatically add arguments from the parameter dictionary
-    for name, parameter in parameters.items():
+    for name, parameter in arg_parameters.items():
         if "action" not in parameter:
             parameter["type"] = type(parameter["default"])
-        parser.add_argument(f"--{name}", **parameter)
+        parser.add_argument(f"--{name}", **parameter)  # type: ignore[arg-type]

     return parser.parse_args(args)


-def main(arg_list: list[str] = None):
+def main(arg_list: list[str] | None = None):
     """
     Evaluates the command line and runs the simulation.

     Args:
         arg_list: List of command line arguments
     """
+    if solvers.size > 1:  # Ensure MPI global variables are initialized
+        from .solvers.mpi import initialize_mpi
+
+        initialize_mpi()
+
     args = parse_args(arg_list)

-    if size > 1 and args.solver != "mpi":
-        raise ValueError(f"Solver method {args.solver} is not compatible with MPI.")
-    if args.solver == "mpi" and not LIB_AVAILABLE["mpi4py"]:
-        raise ValueError(
-            "The MPI solver method requires to install the mpi4py package, "
-            "for example using pip: pip install mpi4py"
-        )
-
-    if rank == 0:
-        # Display parameters as a list
-        print(get_parameter_list(vars(args)))
-
-    simulation = Simulation(vars(args))
-    simulation.run()
-    simulation.save()
+    Solver = solvers.get_solver_class(args.solver)
+    s = Solver(**vars(args))
+    if solvers.rank == 0:
+        # Display the solver parameters
+        print(s)
+    s.run()
+    s.save()
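With the Simulation class removed, main() now wires the CLI directly to a solver class. A sketch of the two entry paths follows; the --solver flag is taken from the old args.solver handling, but the value "numpy" is an assumption inferred from the new llg3d/solvers/numpy.py module, and the remaining options fall back to the arg_parameters defaults:

from llg3d import solvers
from llg3d.main import main, parse_args

# One-shot, CLI-equivalent invocation (flag value assumed, see note above)
main(["--solver", "numpy"])

# The same flow, spelled out with the pieces main() uses internally
args = parse_args(["--solver", "numpy"])
Solver = solvers.get_solver_class(args.solver)  # map name -> solver class
s = Solver(**vars(args))  # the solver consumes the full argument namespace
s.run()
s.save()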