evoxels 0.1.1__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evoxels-0.1.1 → evoxels-1.0.0}/LICENSE +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/PKG-INFO +5 -4
- {evoxels-0.1.1 → evoxels-1.0.0}/README.md +3 -3
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/__init__.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/boundary_conditions.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/fd_stencils.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/function_approximators.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/inversion.py +1 -1
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/__init__.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/allen_cahn.py +1 -1
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/cahn_hilliard.py +1 -1
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/problem_definition.py +31 -29
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/profiler.py +27 -10
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/solvers.py +113 -46
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/timesteppers.py +6 -6
- evoxels-1.0.0/evoxels/utils.py +445 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/voxelfields.py +23 -13
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels/voxelgrid.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/PKG-INFO +5 -4
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/SOURCES.txt +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/dependency_links.txt +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/requires.txt +1 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/evoxels.egg-info/top_level.txt +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/pyproject.toml +3 -2
- {evoxels-0.1.1 → evoxels-1.0.0}/setup.cfg +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_fields.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_inversion.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_laplace.py +0 -0
- {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_rhs.py +1 -1
- {evoxels-0.1.1 → evoxels-1.0.0}/tests/test_solvers.py +1 -1
- evoxels-0.1.1/evoxels/utils.py +0 -124
{evoxels-0.1.1 → evoxels-1.0.0}/LICENSE: File without changes
{evoxels-0.1.1 → evoxels-1.0.0}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evoxels
-Version: 0.1.1
+Version: 1.0.0
 Summary: Differentiable physics framework for voxel-based microstructure simulations
 Author-email: Simon Daubner <s.daubner@imperial.ac.uk>
 License: MIT
@@ -40,6 +40,7 @@ Provides-Extra: notebooks
 Requires-Dist: ipywidgets; extra == "notebooks"
 Requires-Dist: ipympl; extra == "notebooks"
 Requires-Dist: notebook; extra == "notebooks"
+Requires-Dist: taufactor; extra == "notebooks"
 Dynamic: license-file
 
 [](https://github.com/daubners/evoxels/actions/workflows/python-package.yml)
@@ -81,7 +82,7 @@ TL;DR
 ```bash
 conda create --name voxenv python=3.12
 conda activate voxenv
-pip install evoxels[torch,jax,dev,notebooks]
+pip install "evoxels[torch,jax,dev,notebooks]"
 pip install --upgrade "jax[cuda12]"
 ```
 
@@ -103,7 +104,7 @@ Navigate to the evoxels folder, then
 ```
 pip install -e .[torch] # install with torch backend
 pip install -e .[jax] # install with jax backend
-pip install -e .[dev,
+pip install -e .[dev,notebooks] # install testing and notebooks
 ```
 Note that the default `[jax]` installation is only CPU compatible. To install the corresponding CUDA libraries check your CUDA version with
 ```bash
@@ -115,7 +116,7 @@ pip install -U "jax[cuda12]"
 ```
 To install both backends within one environment it is important to install torch first and then upgrade the `jax` installation e.g.
 ```bash
-pip install evoxels[torch,jax,dev,notebooks]
+pip install "evoxels[torch,jax,dev,notebooks]"
 pip install --upgrade "jax[cuda12]"
 ```
 To work with the example notebooks install Jupyter and all notebook related dependencies via
````
{evoxels-0.1.1 → evoxels-1.0.0}/README.md

````diff
@@ -37,7 +37,7 @@ TL;DR
 ```bash
 conda create --name voxenv python=3.12
 conda activate voxenv
-pip install evoxels[torch,jax,dev,notebooks]
+pip install "evoxels[torch,jax,dev,notebooks]"
 pip install --upgrade "jax[cuda12]"
 ```
 
@@ -59,7 +59,7 @@ Navigate to the evoxels folder, then
 ```
 pip install -e .[torch] # install with torch backend
 pip install -e .[jax] # install with jax backend
-pip install -e .[dev,
+pip install -e .[dev,notebooks] # install testing and notebooks
 ```
 Note that the default `[jax]` installation is only CPU compatible. To install the corresponding CUDA libraries check your CUDA version with
 ```bash
@@ -71,7 +71,7 @@ pip install -U "jax[cuda12]"
 ```
 To install both backends within one environment it is important to install torch first and then upgrade the `jax` installation e.g.
 ```bash
-pip install evoxels[torch,jax,dev,notebooks]
+pip install "evoxels[torch,jax,dev,notebooks]"
 pip install --upgrade "jax[cuda12]"
 ```
 To work with the example notebooks install Jupyter and all notebook related dependencies via
````
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/__init__.py: File without changes

{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/boundary_conditions.py: File without changes

{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/fd_stencils.py: File without changes

{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/function_approximators.py: File without changes
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/inversion.py

````diff
@@ -72,7 +72,7 @@ class InversionModel:
 solver = PseudoSpectralIMEX_dfx(problem.fourier_symbol)
 
 solution = dfx.diffeqsolve(
-dfx.ODETerm(lambda t, y, args: problem.rhs(
+dfx.ODETerm(lambda t, y, args: problem.rhs(t, y)),
 solver,
 t0=saveat.subs.ts[0],
 t1=saveat.subs.ts[-1],
````
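The wrapped right-hand side now follows the time-first calling convention. Below is a minimal, self-contained sketch of that wrapping pattern with diffrax; `ToyProblem` and its `rhs` are illustrative stand-ins, not the evoxels `InversionModel` API.

```python
# Hedged sketch: the (t, y) ordering used when handing a problem's rhs to
# diffrax, mirroring the 1.0.0 change above. ToyProblem is a made-up stand-in.
import jax.numpy as jnp
import diffrax as dfx

class ToyProblem:
    def rhs(self, t, y):
        # simple linear decay; time comes first, state second
        return -0.5 * y

problem = ToyProblem()
term = dfx.ODETerm(lambda t, y, args: problem.rhs(t, y))
solution = dfx.diffeqsolve(
    term,
    dfx.Euler(),
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=jnp.ones(4),
)
print(solution.ys)
```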
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/__init__.py: File without changes
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/allen_cahn.py

````diff
@@ -23,7 +23,7 @@ def run_allen_cahn_solver(
 plot_bounds = None,
 ):
 """
-
+Solves time-dependent Allen-Cahn problem with ForwardEuler timestepper.
 """
 solver = TimeDependentSolver(
 voxelfields,
````
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/precompiled_solvers/cahn_hilliard.py

````diff
@@ -20,7 +20,7 @@ def run_cahn_hilliard_solver(
 plot_bounds = None,
 ):
 """
-
+Solves time-dependent Cahn-Hilliard problem with PseudoSpectralIMEX timestepper.
 """
 solver = TimeDependentSolver(
 voxelfields,
````
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/problem_definition.py

````diff
@@ -13,17 +13,17 @@ _i_ = slice(1, -1) # inner elements [1:-1]
 class ODE(ABC):
 @property
 @abstractmethod
-def order(self):
+def order(self) -> int:
 """Spatial order of convergence for numerical right-hand side."""
 pass
 
 @abstractmethod
-def rhs_analytic(self,
+def rhs_analytic(self, t, u):
 """Sympy expression of the problem right-hand side.
 
 Args:
-u : Sympy function of current state.
 t (float): Current time.
+u : Sympy function of current state.
 
 Returns:
 Sympy function of problem right-hand side.
@@ -31,12 +31,12 @@ class ODE(ABC):
 pass
 
 @abstractmethod
-def rhs(self,
+def rhs(self, t, u):
 """Numerical right-hand side of the ODE system.
 
 Args:
-u (array): Current state.
 t (float): Current time.
+u (array): Current state.
 
 Returns:
 Same type as ``u`` containing the time derivative.
````
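The abstract `ODE` interface now puts time before the state in both `rhs` and `rhs_analytic`. The sketch below shows a consumer-side view of that convention with a hypothetical toy problem and a hand-rolled forward Euler loop; it does not subclass the evoxels `ODE` base class.

```python
# Hedged sketch of the time-first convention adopted in 1.0.0: rhs(t, u)
# instead of rhs(u, t). ToyDecay is a made-up stand-in, not an evoxels class.
import numpy as np

class ToyDecay:
    def rhs(self, t, u):
        # du/dt = -0.1 * u; t is unused here but always passed first
        return -0.1 * u

def forward_euler(problem, u0, dt, n_steps):
    u = u0
    for i in range(n_steps):
        u = u + dt * problem.rhs(i * dt, u)
    return u

u_final = forward_euler(ToyDecay(), np.ones((8, 8, 8)), dt=0.01, n_steps=100)
print(u_final.mean())  # close to exp(-0.1) ≈ 0.905
```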
````diff
@@ -120,9 +120,11 @@ class ReactionDiffusion(SemiLinearODE):
 bc_fun = self.vg.bc.pad_zero_flux_periodic
 self.pad_boundary = lambda field, bc0, bc1: bc_fun(field)
 k_squared = self.vg.fft_k_squared_nonperiodic()
+else:
+raise ValueError(f"Unsupported BC type: {self.BC_type}")
 
 self._fourier_symbol = -self.D * self.A * k_squared
-
+
 @property
 def order(self):
 return 2
@@ -130,14 +132,14 @@ class ReactionDiffusion(SemiLinearODE):
 @property
 def fourier_symbol(self):
 return self._fourier_symbol
-
-def _eval_f(self,
+
+def _eval_f(self, t, c, lib):
 """Evaluate source/forcing term using ``self.f``."""
 try:
-return self.f(
+return self.f(t, c, lib)
 except TypeError:
-return self.f(
-
+return self.f(t, c)
+
 @property
 def bc_type(self):
 return self.BC_type
````
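The forcing-term helper now tries the three-argument form `f(t, c, lib)` first and falls back to `f(t, c)`. A small stand-alone sketch of that dispatch pattern, with hypothetical forcing functions, is shown below.

```python
# Sketch of the try/except dispatch visible in _eval_f: user callables may or
# may not accept the backend array library as a third argument. Names here are
# illustrative, not the evoxels API.
import numpy as np

def eval_forcing(f, t, c, lib):
    try:
        return f(t, c, lib)
    except TypeError:
        return f(t, c)

f_simple = lambda t, c: 0.1 * c * (1 - c)       # ignores the backend library
f_with_lib = lambda t, c, lib: lib.sin(t) * c   # uses numpy/torch/jax explicitly

c = np.full((4, 4, 4), 0.3)
print(eval_forcing(f_simple, 0.0, c, np).mean())
print(eval_forcing(f_with_lib, 0.5, c, np).mean())
```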
````diff
@@ -145,12 +147,12 @@ class ReactionDiffusion(SemiLinearODE):
 def pad_bc(self, u):
 return self.pad_boundary(u, self.bcs[0], self.bcs[1])
 
-def rhs_analytic(self,
-return self.D*spv.laplacian(u) + self._eval_f(
-
-def rhs(self,
+def rhs_analytic(self, t, u):
+return self.D*spv.laplacian(u) + self._eval_f(t, u, sp)
+
+def rhs(self, t, u):
 laplace = self.vg.laplace(self.pad_bc(u))
-update = self.D * laplace + self._eval_f(
+update = self.D * laplace + self._eval_f(t, u, self.vg.lib)
 return update
 
 @dataclass
@@ -185,15 +187,15 @@ class ReactionDiffusionSBM(ReactionDiffusion, SmoothedBoundaryODE):
 def pad_bc(self, u):
 return self.pad_boundary(u, self.bcs[0], self.bcs[1])
 
-def rhs_analytic(self,
-
-
+def rhs_analytic(self, t, u, mask):
+grad_m = spv.gradient(mask)
+norm_grad_m = sp.sqrt(grad_m.dot(grad_m))
 
-divergence = spv.divergence(self.D*(
-du = divergence +
+divergence = spv.divergence(self.D*(spv.gradient(u) - u/mask*grad_m))
+du = divergence + norm_grad_m*self.bc_flux + mask*self._eval_f(t, u/mask, sp)
 return du
 
-def rhs(self,
+def rhs(self, t, u):
 z = self.pad_bc(u)
 divergence = self.vg.grad_x_face(self.vg.grad_x_face(z) -\
 self.vg.to_x_face(z/self.mask) * self.vg.grad_x_face(self.mask)
@@ -207,7 +209,7 @@ class ReactionDiffusionSBM(ReactionDiffusion, SmoothedBoundaryODE):
 
 update = self.D * divergence + \
 self.norm*self.bc_flux + \
-self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(u/self.mask[:,1:-1,1:-1,1:-1],
+self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(t, u/self.mask[:,1:-1,1:-1,1:-1], self.vg.lib)
 return update
 
 
@@ -249,13 +251,13 @@ class PeriodicCahnHilliard(SemiLinearODE):
 except TypeError:
 return self.mu_hom(c)
 
-def rhs_analytic(self,
+def rhs_analytic(self, t, c):
 mu = self._eval_mu(c, sp) - 2*self.eps*spv.laplacian(c)
 fluxes = self.D*c*(1-c)*spv.gradient(mu)
 rhs = spv.divergence(fluxes)
 return rhs
 
-def rhs(self,
+def rhs(self, t, c):
 r"""Evaluate :math:`\partial c / \partial t` for the CH equation.
 
 Numerical computation of
@@ -341,7 +343,7 @@ class AllenCahnEquation(SemiLinearODE):
 except TypeError:
 return self.potential(phi)
 
-def rhs_analytic(self,
+def rhs_analytic(self, t, phi):
 grad = spv.gradient(phi)
 laplace = spv.laplacian(phi)
 norm_grad = sp.sqrt(grad.dot(grad))
@@ -354,7 +356,7 @@ class AllenCahnEquation(SemiLinearODE):
 + 3/self.eps * phi * (1-phi) * self.force
 return self.M * df_dphi
 
-def rhs(self,
+def rhs(self, t, phi):
 r"""Two-phase Allen-Cahn equation
 
 Microstructural evolution of the order parameter ``\phi``
@@ -422,13 +424,13 @@ class CoupledReactionDiffusion(SemiLinearODE):
 except TypeError:
 return self.interaction(u)
 
-def rhs_analytic(self,
+def rhs_analytic(self, t, u):
 interaction = self._eval_interaction(u, sp)
 dc_A = self.D_A*spv.laplacian(u[0]) - interaction + self.feed * (1-u[0])
 dc_B = self.D_B*spv.laplacian(u[1]) + interaction - self.kill * u[1]
 return (dc_A, dc_B)
 
-def rhs(self,
+def rhs(self, t, u):
 r"""Two-component reaction-diffusion system
 
 Use batch channels for multiple species:
````
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/profiler.py

````diff
@@ -3,12 +3,20 @@ import psutil
 import os
 import subprocess
 import tracemalloc
+import shutil
 from abc import ABC, abstractmethod
 
 class MemoryProfiler(ABC):
 """Base interface for tracking host and device memory usage."""
+def __init__(self):
+self.max_used_cpu = 0.0
+self.max_used_gpu = 0.0
+self.track_gpu = False # subclasses set this
+
 def get_cuda_memory_from_nvidia_smi(self):
 """Return currently used CUDA memory in megabytes."""
+if shutil.which("nvidia-smi") is None:
+return None
 try:
 output = subprocess.check_output(
 ['nvidia-smi', '--query-gpu=memory.used',
````
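The new `shutil.which` check avoids shelling out to `nvidia-smi` on machines where it is not installed. The sketch below reproduces that guard in isolation; the `--format` flag is an assumption, since the diff truncates the full query command.

```python
# Stand-alone sketch of the guarded nvidia-smi query. The query flags beyond
# --query-gpu=memory.used are assumptions; the diff only shows the first flag.
import shutil
import subprocess

def cuda_memory_used_mb():
    if shutil.which("nvidia-smi") is None:
        return None  # no NVIDIA CLI on this machine
    try:
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used",
             "--format=csv,noheader,nounits"],
            encoding="utf-8",
        )
        return int(output.strip().splitlines()[0])
    except (subprocess.CalledProcessError, ValueError):
        return None

print(cuda_memory_used_mb())
```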
````diff
@@ -23,8 +31,10 @@ class MemoryProfiler(ABC):
 process = psutil.Process(os.getpid())
 used_cpu = process.memory_info().rss / 1024**2
 self.max_used_cpu = np.max((self.max_used_cpu, used_cpu))
-
-
+if self.track_gpu:
+used = self.get_cuda_memory_from_nvidia_smi()
+if used is not None:
+self.max_used_gpu = np.max((self.max_used_gpu, used))
 
 @abstractmethod
 def print_memory_stats(self, start: float, end: float, iters: int):
@@ -35,13 +45,14 @@ class TorchMemoryProfiler(MemoryProfiler):
 def __init__(self, device):
 """Initialize the profiler for a given torch device."""
 import torch
+super().__init__()
 self.torch = torch
 self.device = device
+self.track_gpu = (device.type == 'cuda')
+
 tracemalloc.start()
-if
+if self.track_gpu:
 torch.cuda.reset_peak_memory_stats(device=device)
-self.max_used_gpu = 0
-self.max_used_cpu = 0
 
 def print_memory_stats(self, start, end, iters):
 """Print usage statistics for the Torch backend."""
@@ -60,7 +71,10 @@ class TorchMemoryProfiler(MemoryProfiler):
 elif self.device.type == 'cuda':
 self.update_memory_stats()
 used = self.get_cuda_memory_from_nvidia_smi()
-
+if used is None:
+print("GPU-RAM (nvidia-smi) unavailable.")
+else:
+print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
 print(f"GPU-RAM (torch) current: "
 f"{self.torch.cuda.memory_allocated(self.device) / 1024**2:.2f} MB "
 f"({self.torch.cuda.max_memory_allocated(self.device) / 1024**2:.2f} MB max, "
@@ -70,9 +84,9 @@ class JAXMemoryProfiler(MemoryProfiler):
 def __init__(self):
 """Initialize the profiler for JAX."""
 import jax
+super().__init__()
 self.jax = jax
-self.
-self.max_used_cpu = 0
+self.track_gpu = any(d.platform == "gpu" for d in jax.devices())
 tracemalloc.start()
 
 def print_memory_stats(self, start, end, iters):
@@ -88,7 +102,10 @@ class JAXMemoryProfiler(MemoryProfiler):
 current = process.memory_info().rss / 1024**2
 print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")
 
-if self.
+if self.track_gpu:
 self.update_memory_stats()
 used = self.get_cuda_memory_from_nvidia_smi()
-
+if used is None:
+print("GPU-RAM (nvidia-smi) unavailable.")
+else:
+print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
````
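Both profilers now initialise their counters in the shared base class and decide up front whether GPU memory should be tracked. The snippet below illustrates the two detection ideas in isolation: peak host memory via psutil and GPU availability via the JAX device list. It is a simplified stand-in, not the evoxels profiler.

```python
# Sketch of the profiler behaviour after this change: host memory is tracked
# with psutil, and GPU tracking is switched on only if JAX reports a "gpu"
# platform among its visible devices. Illustrative only.
import os
import psutil
import jax

process = psutil.Process(os.getpid())
max_used_cpu = 0.0
track_gpu = any(d.platform == "gpu" for d in jax.devices())

def update_memory_stats():
    global max_used_cpu
    used_cpu = process.memory_info().rss / 1024**2   # resident set size in MB
    max_used_cpu = max(max_used_cpu, used_cpu)

update_memory_stats()
print(f"CPU-RAM peak so far: {max_used_cpu:.2f} MB, GPU tracking: {track_gpu}")
```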
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/solvers.py

````diff
@@ -1,13 +1,14 @@
 from IPython.display import clear_output
 from dataclasses import dataclass
 from typing import Callable, Any, Type
+from abc import ABC, abstractmethod
 from timeit import default_timer as timer
 import sys
 from .problem_definition import ODE
 from .timesteppers import TimeStepper
 
 @dataclass
-class TimeDependentSolver:
+class BaseSolver(ABC):
 """Generic wrapper for solving one or more fields with a time stepper."""
 vf: Any # VoxelFields object
 fieldnames: str | list[str]
````
````diff
@@ -33,46 +34,24 @@ class TimeDependentSolver:
 self.profiler = JAXMemoryProfiler()
 else:
 raise ValueError(f"Unsupported backend: {self.backend}")
-
-def solve(
-self,
-time_increment=0.1,
-frames=10,
-max_iters=100,
-problem_kwargs=None,
-jit=True,
-verbose=True,
-vtk_out=False,
-plot_bounds=None,
-colormap='viridis'
-):
-"""Run the time integration loop.
-
-Args:
-time_increment (float): Size of a single time step.
-frames (int): Number of output frames (for plotting, vtk, checks).
-max_iters (int): Number of time steps to compute.
-problem_kwargs (dict | None): Problem-specific input arguments.
-jit (bool): Create just-in-time compiled kernel if ``True``
-verbose (bool | str): If ``True`` prints memory stats, ``'plot'``
-updates an interactive plot.
-vtk_out (bool): Write VTK files for each frame if ``True``.
-plot_bounds (tuple | None): Optional value range for plots.
-"""
-
-problem_kwargs = problem_kwargs or {}
+
 if isinstance(self.fieldnames, str):
 self.fieldnames = [self.fieldnames]
 else:
 self.fieldnames = list(self.fieldnames)
 
+def _init_fields(self):
+"""Initialize fields in the voxel grid."""
 u_list = [self.vg.init_scalar_field(self.vf.fields[name]) for name in self.fieldnames]
 u = self.vg.concatenate(u_list, 0)
 u = self.vg.bc.trim_boundary_nodes(u)
-
+return u
+
+def _init_stepper(self, time_increment, problem_kwargs, jit):
+problem_kwargs = problem_kwargs or {}
 if self.step_fn is not None:
-step = self.step_fn
 self.problem = None
+step = self.step_fn
 else:
 if self.problem_cls is None or self.timestepper_cls is None:
 raise ValueError("Either provide step_fn or both problem_cls and timestepper_cls")
````
````diff
@@ -88,23 +67,47 @@ class TimeDependentSolver:
 import torch
 step = torch.compile(step)
 
-
-
-
+return step
+
+@abstractmethod
+def _run_loop(self, u, step, time_increment, frames, max_iters,
+vtk_out, verbose, plot_bounds, colormap):
+"""Abstract method for running the time integration loop."""
+raise NotImplementedError("Subclasses must implement _run_loop method.")
 
-
-
-
-
-
-
+def solve(
+self,
+time_increment=0.1,
+frames=10,
+max_iters=100,
+problem_kwargs=None,
+jit=True,
+verbose=True,
+vtk_out=False,
+plot_bounds=None,
+colormap='viridis'
+):
+"""Run the time integration loop.
 
-
+Args:
+time_increment (float): Size of a single time step.
+frames (int): Number of output frames (for plotting, vtk, checks).
+max_iters (int): Number of time steps to compute.
+problem_kwargs (dict | None): Problem-specific input arguments.
+jit (bool): Create just-in-time compiled kernel if ``True``
+verbose (bool | str): If ``True`` prints memory stats, ``'plot'``
+updates an interactive plot.
+vtk_out (bool): Write VTK files for each frame if ``True``.
+plot_bounds (tuple | None): Optional value range for plots.
+"""
+u = self._init_fields()
+step = self._init_stepper(time_increment, problem_kwargs, jit)
 
+start = timer()
+u = self._run_loop(u, step, time_increment, frames, max_iters,
+vtk_out, verbose, plot_bounds, colormap)
 end = timer()
-
-self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)
-
+self.computation_time = end - start
 if verbose:
 self.profiler.print_memory_stats(start, end, max_iters)
 
````
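The refactor turns `solve` into a template method: the base class handles field setup, stepper creation and timing, while subclasses only implement `_run_loop`. A compact sketch of that structure, with toy stand-ins for the fields and the stepper, follows.

```python
# Hedged sketch of the refactor's shape: a base class owns setup and timing,
# subclasses supply _run_loop (template-method pattern). This is a simplified
# stand-in, not the evoxels BaseSolver API.
from abc import ABC, abstractmethod
from timeit import default_timer as timer

class MiniSolver(ABC):
    def solve(self, dt=0.1, max_iters=100):
        u = self._init_fields()
        step = self._init_stepper(dt)
        start = timer()
        u = self._run_loop(u, step, dt, max_iters)
        self.computation_time = timer() - start
        return u

    def _init_fields(self):
        return 1.0                                  # toy scalar "field"

    def _init_stepper(self, dt):
        return lambda t, u: u - dt * 0.5 * u        # forward Euler on du/dt = -0.5 u

    @abstractmethod
    def _run_loop(self, u, step, dt, max_iters):
        ...

class MiniTimeDependent(MiniSolver):
    def _run_loop(self, u, step, dt, max_iters):
        for i in range(max_iters):
            u = step(i * dt, u)
        return u

print(MiniTimeDependent().solve())
```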
````diff
@@ -113,10 +116,10 @@ class TimeDependentSolver:
 if getattr(self, 'problem', None) is not None:
 u_out = self.vg.bc.trim_ghost_nodes(self.problem.pad_bc(u))
 else:
-u_out = u
+u_out = self.vg.bc.trim_ghost_nodes(self.vg.pad_zeros(u))
 
 for i, name in enumerate(self.fieldnames):
-self.vf.
+self.vf.set_field(name, self.vg.export_scalar_field_to_numpy(u_out[i:i+1]))
 
 if verbose:
 self.profiler.update_memory_stats()
````
````diff
@@ -129,6 +132,70 @@ class TimeDependentSolver:
 filename = self.problem_cls.__name__ + "_" +\
 self.fieldnames[0] + f"_{frame:03d}.vtk"
 self.vf.export_to_vtk(filename=filename, field_names=self.fieldnames)
+
 if verbose == 'plot':
 clear_output(wait=True)
 self.vf.plot_slice(self.fieldnames[0], slice_idx, time=time, colormap=colormap, value_bounds=plot_bounds)
+
+@dataclass
+class TimeDependentSolver(BaseSolver):
+"""Solver for time-dependent problems."""
+def _run_loop(self, u, step, time_increment, frames, max_iters,
+vtk_out, verbose, plot_bounds, colormap):
+n_out = max_iters // frames
+frame = 0
+slice_idx = self.vf.Nz // 2
+
+for i in range(max_iters):
+time = i * time_increment
+if i % n_out == 0:
+self._handle_outputs(u, frame, time, slice_idx, vtk_out,
+verbose, plot_bounds, colormap)
+frame += 1
+
+u = step(time, u)
+time = max_iters * time_increment
+self._handle_outputs(u, frame, time, slice_idx, vtk_out,
+verbose, plot_bounds, colormap)
+return u
+
+@dataclass
+class SteadyStatePseudoTimeSolver(BaseSolver):
+"""Solver for steady-state problems."""
+conv_crit: float = 1e-6
+check_freq: int = 10
+
+def _run_loop(self, u, step, time_increment, frames, max_iters,
+vtk_out, verbose, plot_bounds, colormap):
+slice_idx = self.vf.Nz // 2
+self.converged = False
+self.iter = 0
+
+while not self.converged and self.iter < max_iters:
+time = self.iter * time_increment
+diff = u - step(time, u)
+u = step(time, u)
+
+if self.iter % self.check_freq == 0:
+self.converged = self.check_convergence(diff, verbose)
+
+self._handle_outputs(u, 0, time, slice_idx, vtk_out,
+verbose, plot_bounds, colormap)
+return u
+
+def check_convergence(self, diff, verbose):
+"""Check for convergence based on relative change in fields."""
+converged = True
+for i, name in enumerate(self.fieldnames):
+# Check if Frobenius norm of change is below threshold
+rel_change = self.vg.lib.linalg.norm(diff[i]) / \
+self.vg.lib.sqrt(self.vf.Nx * self.vf.Ny * self.vf.Nz)
+if rel_change > self.conv_crit:
+converged = False
+if verbose:
+print(f"Iter {self.iter}: Field '{name}' relative change: {rel_change:.2e}")
+
+if converged and verbose:
+print(f"Converged after {self.iter} iterations.")
+
+return converged
````
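The new `SteadyStatePseudoTimeSolver` stops once the per-field change is small: the Frobenius norm of the update, normalised by the square root of the voxel count, must drop below `conv_crit`. A NumPy-only sketch of that criterion:

```python
# Sketch of the convergence test: Frobenius norm of the change per field,
# normalised by sqrt(number of voxels), compared against a threshold.
import numpy as np

def converged(diff, conv_crit=1e-6):
    n_voxels = diff.size
    rel_change = np.linalg.norm(diff) / np.sqrt(n_voxels)
    return rel_change < conv_crit

diff = 1e-8 * np.random.rand(32, 32, 32)   # pretend change between iterations
print(converged(diff))                      # True for such a tiny update
```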
{evoxels-0.1.1 → evoxels-1.0.0}/evoxels/timesteppers.py

````diff
@@ -16,13 +16,13 @@ class TimeStepper(ABC):
 pass
 
 @abstractmethod
-def step(self,
+def step(self, t: float, u: State) -> State:
 """
 Take one timestep from t to (t+dt).
 
 Args:
-u : Current state
 t : Current time
+u : Current state
 Returns:
 Updated state at t + dt.
 """
@@ -39,8 +39,8 @@ class ForwardEuler(TimeStepper):
 def order(self) -> int:
 return 1
 
-def step(self,
-return u + self.dt * self.problem.rhs(
+def step(self, t: float, u: State) -> State:
+return u + self.dt * self.problem.rhs(t, u)
 
 
 @dataclass
@@ -68,8 +68,8 @@ class PseudoSpectralIMEX(TimeStepper):
 def order(self) -> int:
 return 1
 
-def step(self,
-dc = self.pad(self.problem.rhs(
+def step(self, t: float, u: State) -> State:
+dc = self.pad(self.problem.rhs(t, u))
 dc_fft = self._fft_prefac * self.problem.vg.rfftn(dc, dc.shape)
 update = self.problem.vg.irfftn(dc_fft, dc.shape)[:,:u.shape[1]]
 return u + update
````
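For reference, the `PseudoSpectralIMEX` stepper named above combines an explicit evaluation of the right-hand side with an implicit treatment of the linear part in Fourier space. The toy below applies the same first-order IMEX idea to a 1-D periodic reaction-diffusion equation; it is a generic illustration of the scheme, not the evoxels implementation, and `symbol` stands in for the problem's Fourier symbol.

```python
# Generic first-order IMEX pseudo-spectral step for u_t = D u_xx + f(u):
# the nonlinearity f is explicit, the diffusion term is solved implicitly
# in Fourier space via its symbol -D k^2. Illustrative only.
import numpy as np

N, L, dt, D = 128, 2 * np.pi, 1e-3, 1.0
x = np.linspace(0.0, L, N, endpoint=False)
k = 2 * np.pi * np.fft.rfftfreq(N, d=L / N)
symbol = -D * k**2                      # Fourier symbol of D * d^2/dx^2

def rhs_explicit(t, u):
    return u - u**3                     # explicit (nonlinear) part

u = 0.1 * np.cos(x)
for i in range(1000):
    t = i * dt
    u_hat = np.fft.rfft(u + dt * rhs_explicit(t, u))
    u = np.fft.irfft(u_hat / (1.0 - dt * symbol), n=N)

print(u.min(), u.max())
```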
|