evoxels 0.1.1__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {evoxels-0.1.1 → evoxels-1.0.0}/LICENSE +0 -0
  2. {evoxels-0.1.1 → evoxels-1.0.0}/PKG-INFO +5 -4
  3. {evoxels-0.1.1 → evoxels-1.0.0}/README.md +3 -3
  4. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/__init__.py +0 -0
  5. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/boundary_conditions.py +0 -0
  6. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/fd_stencils.py +0 -0
  7. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/function_approximators.py +0 -0
  8. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/inversion.py +1 -1
  9. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/__init__.py +0 -0
  10. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/allen_cahn.py +1 -1
  11. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/cahn_hilliard.py +1 -1
  12. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/problem_definition.py +31 -29
  13. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/profiler.py +27 -10
  14. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/solvers.py +113 -46
  15. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/timesteppers.py +6 -6
  16. evoxels-1.0.0/evoxels/utils.py +445 -0
  17. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/voxelfields.py +23 -13
  18. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/voxelgrid.py +0 -0
  19. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/PKG-INFO +5 -4
  20. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/SOURCES.txt +0 -0
  21. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/dependency_links.txt +0 -0
  22. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/requires.txt +1 -0
  23. {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/top_level.txt +0 -0
  24. {evoxels-0.1.1 → evoxels-1.0.0}/pyproject.toml +3 -2
  25. {evoxels-0.1.1 → evoxels-1.0.0}/setup.cfg +0 -0
  26. {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_fields.py +0 -0
  27. {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_inversion.py +0 -0
  28. {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_laplace.py +0 -0
  29. {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_rhs.py +1 -1
  30. {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_solvers.py +1 -1
  31. evoxels-0.1.1/evoxels/utils.py +0 -124
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evoxels
3
- Version: 0.1.1
3
+ Version: 1.0.0
4
4
  Summary: Differentiable physics framework for voxel-based microstructure simulations
5
5
  Author-email: Simon Daubner <s.daubner@imperial.ac.uk>
6
6
  License: MIT
@@ -40,6 +40,7 @@ Provides-Extra: notebooks
40
40
  Requires-Dist: ipywidgets; extra == "notebooks"
41
41
  Requires-Dist: ipympl; extra == "notebooks"
42
42
  Requires-Dist: notebook; extra == "notebooks"
43
+ Requires-Dist: taufactor; extra == "notebooks"
43
44
  Dynamic: license-file
44
45
 
45
46
  [![Python package](https://github.com/daubners/evoxels/actions/workflows/python-package.yml/badge.svg?branch=main)](https://github.com/daubners/evoxels/actions/workflows/python-package.yml)
@@ -81,7 +82,7 @@ TL;DR
81
82
  ```bash
82
83
  conda create --name voxenv python=3.12
83
84
  conda activate voxenv
84
- pip install evoxels[torch,jax,dev,notebooks]
85
+ pip install "evoxels[torch,jax,dev,notebooks]"
85
86
  pip install --upgrade "jax[cuda12]"
86
87
  ```
87
88
 
@@ -103,7 +104,7 @@ Navigate to the evoxels folder, then
103
104
  ```
104
105
  pip install -e .[torch] # install with torch backend
105
106
  pip install -e .[jax] # install with jax backend
106
- pip install -e .[dev, notebooks] # install testing and notebooks
107
+ pip install -e .[dev,notebooks] # install testing and notebooks
107
108
  ```
108
109
  Note that the default `[jax]` installation is only CPU compatible. To install the corresponding CUDA libraries check your CUDA version with
109
110
  ```bash
@@ -115,7 +116,7 @@ pip install -U "jax[cuda12]"
115
116
  ```
116
117
  To install both backends within one environment it is important to install torch first and then upgrade the `jax` installation e.g.
117
118
  ```bash
118
- pip install evoxels[torch, jax, dev, notebooks]
119
+ pip install "evoxels[torch,jax,dev,notebooks]"
119
120
  pip install --upgrade "jax[cuda12]"
120
121
  ```
121
122
  To work with the example notebooks install Jupyter and all notebook related dependencies via
@@ -37,7 +37,7 @@ TL;DR
37
37
  ```bash
38
38
  conda create --name voxenv python=3.12
39
39
  conda activate voxenv
40
- pip install evoxels[torch,jax,dev,notebooks]
40
+ pip install "evoxels[torch,jax,dev,notebooks]"
41
41
  pip install --upgrade "jax[cuda12]"
42
42
  ```
43
43
 
@@ -59,7 +59,7 @@ Navigate to the evoxels folder, then
59
59
  ```
60
60
  pip install -e .[torch] # install with torch backend
61
61
  pip install -e .[jax] # install with jax backend
62
- pip install -e .[dev, notebooks] # install testing and notebooks
62
+ pip install -e .[dev,notebooks] # install testing and notebooks
63
63
  ```
64
64
  Note that the default `[jax]` installation is only CPU compatible. To install the corresponding CUDA libraries check your CUDA version with
65
65
  ```bash
@@ -71,7 +71,7 @@ pip install -U "jax[cuda12]"
71
71
  ```
72
72
  To install both backends within one environment it is important to install torch first and then upgrade the `jax` installation e.g.
73
73
  ```bash
74
- pip install evoxels[torch, jax, dev, notebooks]
74
+ pip install "evoxels[torch,jax,dev,notebooks]"
75
75
  pip install --upgrade "jax[cuda12]"
76
76
  ```
77
77
  To work with the example notebooks install Jupyter and all notebook related dependencies via
File without changes
File without changes
@@ -72,7 +72,7 @@ class InversionModel:
72
72
  solver = PseudoSpectralIMEX_dfx(problem.fourier_symbol)
73
73
 
74
74
  solution = dfx.diffeqsolve(
75
- dfx.ODETerm(lambda t, y, args: problem.rhs(y, t)),
75
+ dfx.ODETerm(lambda t, y, args: problem.rhs(t, y)),
76
76
  solver,
77
77
  t0=saveat.subs.ts[0],
78
78
  t1=saveat.subs.ts[-1],
@@ -23,7 +23,7 @@ def run_allen_cahn_solver(
23
23
  plot_bounds = None,
24
24
  ):
25
25
  """
26
- Runs the Cahn-Hilliard solver with a predefined problem and timestepper.
26
+ Solves time-dependent Allen-Cahn problem with ForwardEuler timestepper.
27
27
  """
28
28
  solver = TimeDependentSolver(
29
29
  voxelfields,
@@ -20,7 +20,7 @@ def run_cahn_hilliard_solver(
20
20
  plot_bounds = None,
21
21
  ):
22
22
  """
23
- Runs the Cahn-Hilliard solver with a predefined problem and timestepper.
23
+ Solves time-dependent Cahn-Hilliard problem with PseudoSpectralIMEX timestepper.
24
24
  """
25
25
  solver = TimeDependentSolver(
26
26
  voxelfields,
@@ -13,17 +13,17 @@ _i_ = slice(1, -1) # inner elements [1:-1]
13
13
  class ODE(ABC):
14
14
  @property
15
15
  @abstractmethod
16
- def order(self):
16
+ def order(self) -> int:
17
17
  """Spatial order of convergence for numerical right-hand side."""
18
18
  pass
19
19
 
20
20
  @abstractmethod
21
- def rhs_analytic(self, u, t):
21
+ def rhs_analytic(self, t, u):
22
22
  """Sympy expression of the problem right-hand side.
23
23
 
24
24
  Args:
25
- u : Sympy function of current state.
26
25
  t (float): Current time.
26
+ u : Sympy function of current state.
27
27
 
28
28
  Returns:
29
29
  Sympy function of problem right-hand side.
@@ -31,12 +31,12 @@ class ODE(ABC):
31
31
  pass
32
32
 
33
33
  @abstractmethod
34
- def rhs(self, u, t):
34
+ def rhs(self, t, u):
35
35
  """Numerical right-hand side of the ODE system.
36
36
 
37
37
  Args:
38
- u (array): Current state.
39
38
  t (float): Current time.
39
+ u (array): Current state.
40
40
 
41
41
  Returns:
42
42
  Same type as ``u`` containing the time derivative.
@@ -120,9 +120,11 @@ class ReactionDiffusion(SemiLinearODE):
120
120
  bc_fun = self.vg.bc.pad_zero_flux_periodic
121
121
  self.pad_boundary = lambda field, bc0, bc1: bc_fun(field)
122
122
  k_squared = self.vg.fft_k_squared_nonperiodic()
123
+ else:
124
+ raise ValueError(f"Unsupported BC type: {self.BC_type}")
123
125
 
124
126
  self._fourier_symbol = -self.D * self.A * k_squared
125
-
127
+
126
128
  @property
127
129
  def order(self):
128
130
  return 2
@@ -130,14 +132,14 @@ class ReactionDiffusion(SemiLinearODE):
130
132
  @property
131
133
  def fourier_symbol(self):
132
134
  return self._fourier_symbol
133
-
134
- def _eval_f(self, c, t, lib):
135
+
136
+ def _eval_f(self, t, c, lib):
135
137
  """Evaluate source/forcing term using ``self.f``."""
136
138
  try:
137
- return self.f(c, t, lib)
139
+ return self.f(t, c, lib)
138
140
  except TypeError:
139
- return self.f(c, t)
140
-
141
+ return self.f(t, c)
142
+
141
143
  @property
142
144
  def bc_type(self):
143
145
  return self.BC_type
@@ -145,12 +147,12 @@ class ReactionDiffusion(SemiLinearODE):
145
147
  def pad_bc(self, u):
146
148
  return self.pad_boundary(u, self.bcs[0], self.bcs[1])
147
149
 
148
- def rhs_analytic(self, u, t):
149
- return self.D*spv.laplacian(u) + self._eval_f(u, t, sp)
150
-
151
- def rhs(self, u, t):
150
+ def rhs_analytic(self, t, u):
151
+ return self.D*spv.laplacian(u) + self._eval_f(t, u, sp)
152
+
153
+ def rhs(self, t, u):
152
154
  laplace = self.vg.laplace(self.pad_bc(u))
153
- update = self.D * laplace + self._eval_f(u, t, self.vg.lib)
155
+ update = self.D * laplace + self._eval_f(t, u, self.vg.lib)
154
156
  return update
155
157
 
156
158
  @dataclass
@@ -185,15 +187,15 @@ class ReactionDiffusionSBM(ReactionDiffusion, SmoothedBoundaryODE):
185
187
  def pad_bc(self, u):
186
188
  return self.pad_boundary(u, self.bcs[0], self.bcs[1])
187
189
 
188
- def rhs_analytic(self, mask, u, t):
189
- grad = spv.gradient(u)
190
- norm_grad = sp.sqrt(grad.dot(grad))
190
+ def rhs_analytic(self, t, u, mask):
191
+ grad_m = spv.gradient(mask)
192
+ norm_grad_m = sp.sqrt(grad_m.dot(grad_m))
191
193
 
192
- divergence = spv.divergence(self.D*(grad - u/mask*spv.gradient(mask)))
193
- du = divergence + norm_grad*self.bc_flux + mask*self._eval_f(u/mask, t, sp)
194
+ divergence = spv.divergence(self.D*(spv.gradient(u) - u/mask*grad_m))
195
+ du = divergence + norm_grad_m*self.bc_flux + mask*self._eval_f(t, u/mask, sp)
194
196
  return du
195
197
 
196
- def rhs(self, u, t):
198
+ def rhs(self, t, u):
197
199
  z = self.pad_bc(u)
198
200
  divergence = self.vg.grad_x_face(self.vg.grad_x_face(z) -\
199
201
  self.vg.to_x_face(z/self.mask) * self.vg.grad_x_face(self.mask)
@@ -207,7 +209,7 @@ class ReactionDiffusionSBM(ReactionDiffusion, SmoothedBoundaryODE):
207
209
 
208
210
  update = self.D * divergence + \
209
211
  self.norm*self.bc_flux + \
210
- self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(u/self.mask[:,1:-1,1:-1,1:-1], t, self.vg.lib)
212
+ self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(t, u/self.mask[:,1:-1,1:-1,1:-1], self.vg.lib)
211
213
  return update
212
214
 
213
215
 
@@ -249,13 +251,13 @@ class PeriodicCahnHilliard(SemiLinearODE):
249
251
  except TypeError:
250
252
  return self.mu_hom(c)
251
253
 
252
- def rhs_analytic(self, c, t):
254
+ def rhs_analytic(self, t, c):
253
255
  mu = self._eval_mu(c, sp) - 2*self.eps*spv.laplacian(c)
254
256
  fluxes = self.D*c*(1-c)*spv.gradient(mu)
255
257
  rhs = spv.divergence(fluxes)
256
258
  return rhs
257
259
 
258
- def rhs(self, c, t):
260
+ def rhs(self, t, c):
259
261
  r"""Evaluate :math:`\partial c / \partial t` for the CH equation.
260
262
 
261
263
  Numerical computation of
@@ -341,7 +343,7 @@ class AllenCahnEquation(SemiLinearODE):
341
343
  except TypeError:
342
344
  return self.potential(phi)
343
345
 
344
- def rhs_analytic(self, phi, t):
346
+ def rhs_analytic(self, t, phi):
345
347
  grad = spv.gradient(phi)
346
348
  laplace = spv.laplacian(phi)
347
349
  norm_grad = sp.sqrt(grad.dot(grad))
@@ -354,7 +356,7 @@ class AllenCahnEquation(SemiLinearODE):
354
356
  + 3/self.eps * phi * (1-phi) * self.force
355
357
  return self.M * df_dphi
356
358
 
357
- def rhs(self, phi, t):
359
+ def rhs(self, t, phi):
358
360
  r"""Two-phase Allen-Cahn equation
359
361
 
360
362
  Microstructural evolution of the order parameter ``\phi``
@@ -422,13 +424,13 @@ class CoupledReactionDiffusion(SemiLinearODE):
422
424
  except TypeError:
423
425
  return self.interaction(u)
424
426
 
425
- def rhs_analytic(self, u, t):
427
+ def rhs_analytic(self, t, u):
426
428
  interaction = self._eval_interaction(u, sp)
427
429
  dc_A = self.D_A*spv.laplacian(u[0]) - interaction + self.feed * (1-u[0])
428
430
  dc_B = self.D_B*spv.laplacian(u[1]) + interaction - self.kill * u[1]
429
431
  return (dc_A, dc_B)
430
432
 
431
- def rhs(self, u, t):
433
+ def rhs(self, t, u):
432
434
  r"""Two-component reaction-diffusion system
433
435
 
434
436
  Use batch channels for multiple species:
@@ -3,12 +3,20 @@ import psutil
3
3
  import os
4
4
  import subprocess
5
5
  import tracemalloc
6
+ import shutil
6
7
  from abc import ABC, abstractmethod
7
8
 
8
9
  class MemoryProfiler(ABC):
9
10
  """Base interface for tracking host and device memory usage."""
11
+ def __init__(self):
12
+ self.max_used_cpu = 0.0
13
+ self.max_used_gpu = 0.0
14
+ self.track_gpu = False # subclasses set this
15
+
10
16
  def get_cuda_memory_from_nvidia_smi(self):
11
17
  """Return currently used CUDA memory in megabytes."""
18
+ if shutil.which("nvidia-smi") is None:
19
+ return None
12
20
  try:
13
21
  output = subprocess.check_output(
14
22
  ['nvidia-smi', '--query-gpu=memory.used',
@@ -23,8 +31,10 @@ class MemoryProfiler(ABC):
23
31
  process = psutil.Process(os.getpid())
24
32
  used_cpu = process.memory_info().rss / 1024**2
25
33
  self.max_used_cpu = np.max((self.max_used_cpu, used_cpu))
26
- used = self.get_cuda_memory_from_nvidia_smi()
27
- self.max_used_gpu = np.max((self.max_used_gpu, used))
34
+ if self.track_gpu:
35
+ used = self.get_cuda_memory_from_nvidia_smi()
36
+ if used is not None:
37
+ self.max_used_gpu = np.max((self.max_used_gpu, used))
28
38
 
29
39
  @abstractmethod
30
40
  def print_memory_stats(self, start: float, end: float, iters: int):
@@ -35,13 +45,14 @@ class TorchMemoryProfiler(MemoryProfiler):
35
45
  def __init__(self, device):
36
46
  """Initialize the profiler for a given torch device."""
37
47
  import torch
48
+ super().__init__()
38
49
  self.torch = torch
39
50
  self.device = device
51
+ self.track_gpu = (device.type == 'cuda')
52
+
40
53
  tracemalloc.start()
41
- if device.type == 'cuda':
54
+ if self.track_gpu:
42
55
  torch.cuda.reset_peak_memory_stats(device=device)
43
- self.max_used_gpu = 0
44
- self.max_used_cpu = 0
45
56
 
46
57
  def print_memory_stats(self, start, end, iters):
47
58
  """Print usage statistics for the Torch backend."""
@@ -60,7 +71,10 @@ class TorchMemoryProfiler(MemoryProfiler):
60
71
  elif self.device.type == 'cuda':
61
72
  self.update_memory_stats()
62
73
  used = self.get_cuda_memory_from_nvidia_smi()
63
- print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
74
+ if used is None:
75
+ print("GPU-RAM (nvidia-smi) unavailable.")
76
+ else:
77
+ print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
64
78
  print(f"GPU-RAM (torch) current: "
65
79
  f"{self.torch.cuda.memory_allocated(self.device) / 1024**2:.2f} MB "
66
80
  f"({self.torch.cuda.max_memory_allocated(self.device) / 1024**2:.2f} MB max, "
@@ -70,9 +84,9 @@ class JAXMemoryProfiler(MemoryProfiler):
70
84
  def __init__(self):
71
85
  """Initialize the profiler for JAX."""
72
86
  import jax
87
+ super().__init__()
73
88
  self.jax = jax
74
- self.max_used_gpu = 0
75
- self.max_used_cpu = 0
89
+ self.track_gpu = any(d.platform == "gpu" for d in jax.devices())
76
90
  tracemalloc.start()
77
91
 
78
92
  def print_memory_stats(self, start, end, iters):
@@ -88,7 +102,10 @@ class JAXMemoryProfiler(MemoryProfiler):
88
102
  current = process.memory_info().rss / 1024**2
89
103
  print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")
90
104
 
91
- if self.jax.default_backend() == 'gpu':
105
+ if self.track_gpu:
92
106
  self.update_memory_stats()
93
107
  used = self.get_cuda_memory_from_nvidia_smi()
94
- print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
108
+ if used is None:
109
+ print("GPU-RAM (nvidia-smi) unavailable.")
110
+ else:
111
+ print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
@@ -1,13 +1,14 @@
1
1
  from IPython.display import clear_output
2
2
  from dataclasses import dataclass
3
3
  from typing import Callable, Any, Type
4
+ from abc import ABC, abstractmethod
4
5
  from timeit import default_timer as timer
5
6
  import sys
6
7
  from .problem_definition import ODE
7
8
  from .timesteppers import TimeStepper
8
9
 
9
10
  @dataclass
10
- class TimeDependentSolver:
11
+ class BaseSolver(ABC):
11
12
  """Generic wrapper for solving one or more fields with a time stepper."""
12
13
  vf: Any # VoxelFields object
13
14
  fieldnames: str | list[str]
@@ -33,46 +34,24 @@ class TimeDependentSolver:
33
34
  self.profiler = JAXMemoryProfiler()
34
35
  else:
35
36
  raise ValueError(f"Unsupported backend: {self.backend}")
36
-
37
- def solve(
38
- self,
39
- time_increment=0.1,
40
- frames=10,
41
- max_iters=100,
42
- problem_kwargs=None,
43
- jit=True,
44
- verbose=True,
45
- vtk_out=False,
46
- plot_bounds=None,
47
- colormap='viridis'
48
- ):
49
- """Run the time integration loop.
50
-
51
- Args:
52
- time_increment (float): Size of a single time step.
53
- frames (int): Number of output frames (for plotting, vtk, checks).
54
- max_iters (int): Number of time steps to compute.
55
- problem_kwargs (dict | None): Problem-specific input arguments.
56
- jit (bool): Create just-in-time compiled kernel if ``True``
57
- verbose (bool | str): If ``True`` prints memory stats, ``'plot'``
58
- updates an interactive plot.
59
- vtk_out (bool): Write VTK files for each frame if ``True``.
60
- plot_bounds (tuple | None): Optional value range for plots.
61
- """
62
-
63
- problem_kwargs = problem_kwargs or {}
37
+
64
38
  if isinstance(self.fieldnames, str):
65
39
  self.fieldnames = [self.fieldnames]
66
40
  else:
67
41
  self.fieldnames = list(self.fieldnames)
68
42
 
43
+ def _init_fields(self):
44
+ """Initialize fields in the voxel grid."""
69
45
  u_list = [self.vg.init_scalar_field(self.vf.fields[name]) for name in self.fieldnames]
70
46
  u = self.vg.concatenate(u_list, 0)
71
47
  u = self.vg.bc.trim_boundary_nodes(u)
72
-
48
+ return u
49
+
50
+ def _init_stepper(self, time_increment, problem_kwargs, jit):
51
+ problem_kwargs = problem_kwargs or {}
73
52
  if self.step_fn is not None:
74
- step = self.step_fn
75
53
  self.problem = None
54
+ step = self.step_fn
76
55
  else:
77
56
  if self.problem_cls is None or self.timestepper_cls is None:
78
57
  raise ValueError("Either provide step_fn or both problem_cls and timestepper_cls")
@@ -88,23 +67,47 @@ class TimeDependentSolver:
88
67
  import torch
89
68
  step = torch.compile(step)
90
69
 
91
- n_out = max_iters // frames
92
- frame = 0
93
- slice_idx = self.vf.Nz // 2
70
+ return step
71
+
72
+ @abstractmethod
73
+ def _run_loop(self, u, step, time_increment, frames, max_iters,
74
+ vtk_out, verbose, plot_bounds, colormap):
75
+ """Abstract method for running the time integration loop."""
76
+ raise NotImplementedError("Subclasses must implement _run_loop method.")
94
77
 
95
- start = timer()
96
- for i in range(max_iters):
97
- time = i * time_increment
98
- if i % n_out == 0:
99
- self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)
100
- frame += 1
78
+ def solve(
79
+ self,
80
+ time_increment=0.1,
81
+ frames=10,
82
+ max_iters=100,
83
+ problem_kwargs=None,
84
+ jit=True,
85
+ verbose=True,
86
+ vtk_out=False,
87
+ plot_bounds=None,
88
+ colormap='viridis'
89
+ ):
90
+ """Run the time integration loop.
101
91
 
102
- u = step(u, time)
92
+ Args:
93
+ time_increment (float): Size of a single time step.
94
+ frames (int): Number of output frames (for plotting, vtk, checks).
95
+ max_iters (int): Number of time steps to compute.
96
+ problem_kwargs (dict | None): Problem-specific input arguments.
97
+ jit (bool): Create just-in-time compiled kernel if ``True``
98
+ verbose (bool | str): If ``True`` prints memory stats, ``'plot'``
99
+ updates an interactive plot.
100
+ vtk_out (bool): Write VTK files for each frame if ``True``.
101
+ plot_bounds (tuple | None): Optional value range for plots.
102
+ """
103
+ u = self._init_fields()
104
+ step = self._init_stepper(time_increment, problem_kwargs, jit)
103
105
 
106
+ start = timer()
107
+ u = self._run_loop(u, step, time_increment, frames, max_iters,
108
+ vtk_out, verbose, plot_bounds, colormap)
104
109
  end = timer()
105
- time = max_iters * time_increment
106
- self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)
107
-
110
+ self.computation_time = end - start
108
111
  if verbose:
109
112
  self.profiler.print_memory_stats(start, end, max_iters)
110
113
 
@@ -113,10 +116,10 @@ class TimeDependentSolver:
113
116
  if getattr(self, 'problem', None) is not None:
114
117
  u_out = self.vg.bc.trim_ghost_nodes(self.problem.pad_bc(u))
115
118
  else:
116
- u_out = u
119
+ u_out = self.vg.bc.trim_ghost_nodes(self.vg.pad_zeros(u))
117
120
 
118
121
  for i, name in enumerate(self.fieldnames):
119
- self.vf.fields[name] = self.vg.export_scalar_field_to_numpy(u_out[i:i+1])
122
+ self.vf.set_field(name, self.vg.export_scalar_field_to_numpy(u_out[i:i+1]))
120
123
 
121
124
  if verbose:
122
125
  self.profiler.update_memory_stats()
@@ -129,6 +132,70 @@ class TimeDependentSolver:
129
132
  filename = self.problem_cls.__name__ + "_" +\
130
133
  self.fieldnames[0] + f"_{frame:03d}.vtk"
131
134
  self.vf.export_to_vtk(filename=filename, field_names=self.fieldnames)
135
+
132
136
  if verbose == 'plot':
133
137
  clear_output(wait=True)
134
138
  self.vf.plot_slice(self.fieldnames[0], slice_idx, time=time, colormap=colormap, value_bounds=plot_bounds)
139
+
140
@dataclass
class TimeDependentSolver(BaseSolver):
    """Solver for time-dependent problems.

    Advances the fields with the configured step function for a fixed number
    of iterations. Produces intermediate outputs every ``max_iters // frames``
    steps, plus one final output after the last step.
    """

    def _run_loop(self, u, step, time_increment, frames, max_iters,
                  vtk_out, verbose, plot_bounds, colormap):
        """Run ``max_iters`` time steps and emit periodic outputs.

        Args:
            u: Initial state (as produced by ``_init_fields``).
            step: Callable ``step(t, u)`` returning the state after one step.
            time_increment (float): Size of a single time step.
            frames (int): Requested number of intermediate outputs.
            max_iters (int): Number of time steps to compute.
            vtk_out, verbose, plot_bounds, colormap: Forwarded unchanged to
                ``_handle_outputs``.

        Returns:
            The state after ``max_iters`` steps.
        """
        # Output interval in iterations. Clamp to at least 1: with
        # ``frames > max_iters`` (or ``frames <= 0``) the plain integer
        # division yields 0 (or divides by 0), and ``i % n_out`` would raise
        # ZeroDivisionError.
        n_out = max(1, max_iters // max(1, frames))
        frame = 0
        slice_idx = self.vf.Nz // 2

        for i in range(max_iters):
            time = i * time_increment
            if i % n_out == 0:
                self._handle_outputs(u, frame, time, slice_idx, vtk_out,
                                     verbose, plot_bounds, colormap)
                frame += 1

            u = step(time, u)

        # Final output at the end time, i.e. after the last step was applied.
        time = max_iters * time_increment
        self._handle_outputs(u, frame, time, slice_idx, vtk_out,
                             verbose, plot_bounds, colormap)
        return u
161
+
162
@dataclass
class SteadyStatePseudoTimeSolver(BaseSolver):
    """Solver for steady-state problems via pseudo-time stepping.

    Repeatedly applies the step function until the normalized change of every
    field between consecutive iterates drops below ``conv_crit``, or until
    ``max_iters`` steps have been taken.

    Attributes:
        conv_crit (float): Threshold on ``||u_new - u|| / sqrt(Nx*Ny*Nz)``
            per field.
        check_freq (int): Convergence is evaluated every ``check_freq``
            iterations (values < 1 are treated as 1).
    """
    conv_crit: float = 1e-6
    check_freq: int = 10

    def _run_loop(self, u, step, time_increment, frames, max_iters,
                  vtk_out, verbose, plot_bounds, colormap):
        """Iterate ``step`` until convergence or ``max_iters``.

        Sets ``self.converged`` and ``self.iter`` (number of steps taken).
        ``frames`` is accepted for interface compatibility but unused: only a
        single output is produced after the loop.

        Returns:
            The last iterate.
        """
        slice_idx = self.vf.Nz // 2
        self.converged = False
        self.iter = 0
        check_freq = max(1, int(self.check_freq))

        while not self.converged and self.iter < max_iters:
            time = self.iter * time_increment
            # Evaluate the step exactly once per iteration; the difference is
            # then consistent with the update actually applied.
            u_new = step(time, u)
            # Advance the counter; without this the loop never terminates
            # unless convergence is reached.
            self.iter += 1

            if self.iter % check_freq == 0:
                self.converged = self.check_convergence(u - u_new, verbose)
            u = u_new

        # Pseudo-time reached after the steps taken (also defined when
        # max_iters == 0 and the loop body never ran).
        time = self.iter * time_increment
        self._handle_outputs(u, 0, time, slice_idx, vtk_out,
                             verbose, plot_bounds, colormap)
        return u

    def check_convergence(self, diff, verbose):
        """Check whether every field's normalized change is below ``conv_crit``.

        Args:
            diff: Difference between two consecutive iterates, indexed by
                field along the leading axis (one entry per name in
                ``self.fieldnames``).
            verbose: If truthy, print the per-field change and a message on
                convergence.

        Returns:
            bool: ``True`` if all fields satisfy the criterion.
        """
        # Use plain Python arithmetic: ``lib.sqrt`` may not accept a Python
        # int (e.g. ``torch.sqrt`` requires a Tensor argument).
        norm_factor = (self.vf.Nx * self.vf.Ny * self.vf.Nz) ** 0.5

        converged = True
        for i, name in enumerate(self.fieldnames):
            # Frobenius norm of the change, normalized by sqrt(#cells).
            rel_change = self.vg.lib.linalg.norm(diff[i]) / norm_factor
            if rel_change > self.conv_crit:
                converged = False
            if verbose:
                print(f"Iter {self.iter}: Field '{name}' relative change: {rel_change:.2e}")

        if converged and verbose:
            print(f"Converged after {self.iter} iterations.")

        return converged
@@ -16,13 +16,13 @@ class TimeStepper(ABC):
16
16
  pass
17
17
 
18
18
  @abstractmethod
19
- def step(self, u: State, t: float) -> State:
19
+ def step(self, t: float, u: State) -> State:
20
20
  """
21
21
  Take one timestep from t to (t+dt).
22
22
 
23
23
  Args:
24
- u : Current state
25
24
  t : Current time
25
+ u : Current state
26
26
  Returns:
27
27
  Updated state at t + dt.
28
28
  """
@@ -39,8 +39,8 @@ class ForwardEuler(TimeStepper):
39
39
  def order(self) -> int:
40
40
  return 1
41
41
 
42
- def step(self, u: State, t: float) -> State:
43
- return u + self.dt * self.problem.rhs(u, t)
42
+ def step(self, t: float, u: State) -> State:
43
+ return u + self.dt * self.problem.rhs(t, u)
44
44
 
45
45
 
46
46
  @dataclass
@@ -68,8 +68,8 @@ class PseudoSpectralIMEX(TimeStepper):
68
68
  def order(self) -> int:
69
69
  return 1
70
70
 
71
- def step(self, u: State, t: float) -> State:
72
- dc = self.pad(self.problem.rhs(u, t))
71
+ def step(self, t: float, u: State) -> State:
72
+ dc = self.pad(self.problem.rhs(t, u))
73
73
  dc_fft = self._fft_prefac * self.problem.vg.rfftn(dc, dc.shape)
74
74
  update = self.problem.vg.irfftn(dc_fft, dc.shape)[:,:u.shape[1]]
75
75
  return u + update