evoxels 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evoxels/__init__.py +13 -0
- evoxels/boundary_conditions.py +138 -0
- evoxels/fd_stencils.py +103 -0
- evoxels/function_approximators.py +97 -0
- evoxels/inversion.py +233 -0
- evoxels/precompiled_solvers/__init__.py +1 -0
- evoxels/precompiled_solvers/allen_cahn.py +50 -0
- evoxels/precompiled_solvers/cahn_hilliard.py +42 -0
- evoxels/problem_definition.py +450 -0
- evoxels/profiler.py +94 -0
- evoxels/solvers.py +134 -0
- evoxels/timesteppers.py +119 -0
- evoxels/utils.py +124 -0
- evoxels/voxelfields.py +318 -0
- evoxels/voxelgrid.py +278 -0
- evoxels-0.1.0.dist-info/METADATA +171 -0
- evoxels-0.1.0.dist-info/RECORD +20 -0
- evoxels-0.1.0.dist-info/WHEEL +5 -0
- evoxels-0.1.0.dist-info/licenses/LICENSE +21 -0
- evoxels-0.1.0.dist-info/top_level.txt +1 -0
evoxels/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Public API for the evoxels package."""
|
|
2
|
+
|
|
3
|
+
from .voxelfields import VoxelFields
|
|
4
|
+
from .precompiled_solvers.cahn_hilliard import (run_cahn_hilliard_solver)
|
|
5
|
+
from .precompiled_solvers.allen_cahn import (run_allen_cahn_solver)
|
|
6
|
+
from .inversion import InversionModel
|
|
7
|
+
|
|
8
|
+
# Public names exported by ``from evoxels import *``; each one is imported
# at the top of this module.
__all__ = [
    "VoxelFields",
    "run_cahn_hilliard_solver",
    "run_allen_cahn_solver",
    "InversionModel"
]
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
# Shorthands in slicing logic
__ = slice(None)  # all elements [:]
_i_ = slice(1, -1)  # inner elements [1:-1], i.e. drop the first and last entry
|
|
4
|
+
|
|
5
|
+
class CellCenteredBCs:
    """Boundary-condition helpers for cell-centered grids.

    Fields are indexed as ``field[lead, x, y, z]`` (the leading axis is
    presumably a channel/batch axis). The ``pad_*`` methods add a ghost
    layer around the field, ``trim_ghost_nodes`` removes it again, and the
    ``pad_fft_*`` methods build the periodic extension in x used by
    FFT-based solvers.
    """

    def __init__(self, vg):
        # Voxel grid backend providing pad_periodic, pad_zeros, set,
        # concatenate, lib and shape.
        self.vg = vg

    def pad_periodic(self, field):
        """
        Periodic boundary conditions in all directions.
        Consistent with cell centered grid.
        """
        return self.vg.pad_periodic(field)

    def pad_dirichlet_periodic(self, field, bc0=0, bc1=0):
        """
        Dirichlet boundary conditions in x-direction, periodic in y- and
        z-direction.

        Ghost values ``2*bc - u_neighbour`` make the average of ghost and
        first interior cell equal ``bc0`` (lower x face) and ``bc1`` (upper
        x face). The defaults give homogeneous Dirichlet conditions.
        Consistent with cell centered grid, but loss of 2nd order
        convergence.
        """
        padded = self.vg.pad_periodic(field)
        padded = self.vg.set(padded, (__, 0, __, __), 2.0 * bc0 - padded[:, 1, :, :])
        padded = self.vg.set(padded, (__, -1, __, __), 2.0 * bc1 - padded[:, -2, :, :])
        return padded

    def pad_zero_flux_periodic(self, field):
        """
        Zero-flux (mirrored ghost cells) in x-direction, periodic in y- and
        z-direction.
        """
        padded = self.vg.pad_periodic(field)
        padded = self.vg.set(padded, (__, 0, __, __), padded[:, 1, :, :])
        padded = self.vg.set(padded, (__, -1, __, __), padded[:, -2, :, :])
        return padded

    def pad_zero_flux(self, field):
        """
        Zero-flux (mirrored ghost cells) in all directions.

        Faces are filled one axis after another, so edge and corner ghost
        cells copy already-filled neighbours.
        """
        padded = self.vg.pad_zeros(field)
        padded = self.vg.set(padded, (__, 0, __, __), padded[:, 1, :, :])
        padded = self.vg.set(padded, (__, -1, __, __), padded[:, -2, :, :])
        padded = self.vg.set(padded, (__, __, 0, __), padded[:, :, 1, :])
        padded = self.vg.set(padded, (__, __, -1, __), padded[:, :, -2, :])
        padded = self.vg.set(padded, (__, __, __, 0), padded[:, :, :, 1])
        padded = self.vg.set(padded, (__, __, __, -1), padded[:, :, :, -2])
        return padded

    def pad_fft_periodic(self, field):
        """Periodic field needs no fft padding."""
        return field

    def pad_fft_dirichlet_periodic(self, field):
        """Odd extension in x: append the negated, x-mirrored field.

        The mirror acts on the x axis (axis 1), the same axis the result is
        concatenated along. Mirroring the leading axis instead would leave a
        single-channel field unchanged and produce ``[f, -f]``, which is not
        an odd extension.
        """
        mirrored = self.vg.lib.flip(field, [1])
        return self.vg.concatenate((field, -mirrored), 1)

    def pad_fft_zero_flux_periodic(self, field):
        """Even extension in x: append the x-mirrored field (axis 1)."""
        mirrored = self.vg.lib.flip(field, [1])
        return self.vg.concatenate((field, mirrored), 1)

    def trim_boundary_nodes(self, field):
        """Cell-centered fields carry no boundary nodes; returned unchanged."""
        return field

    def trim_ghost_nodes(self, field):
        """Remove the ghost layer in x, y and z.

        Raises:
            ValueError: If the interior of ``field`` does not match
                ``self.vg.shape``.
        """
        if field[0, _i_, _i_, _i_].shape == self.vg.shape:
            return field[:, _i_, _i_, _i_]
        raise ValueError(
            f"The provided field has the wrong shape {tuple(field.shape)}; "
            f"its interior must match the grid shape {self.vg.shape}."
        )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class StaggeredXBCs:
    """Boundary-condition helpers for grids staggered in x.

    Fields are indexed as ``field[lead, x, y, z]``. In x the grid includes
    boundary nodes (``self.vg.shape[0]`` counts them), which
    ``trim_boundary_nodes`` strips and the ``pad_*`` methods refill.
    """

    def __init__(self, vg):
        # Voxel grid backend providing pad_periodic, set, concatenate, lib
        # and shape.
        self.vg = vg

    def pad_periodic_BC_staggered_x(self, field):
        """
        Not implemented: if a field is fully periodic it should use the
        cell center convention instead.
        """
        raise NotImplementedError

    def pad_dirichlet_periodic(self, field, bc0=0, bc1=0):
        """
        Dirichlet boundary conditions in x-direction, periodic in y- and
        z-direction.

        The boundary nodes in x are set directly to ``bc0`` / ``bc1``
        (defaults give homogeneous Dirichlet conditions). Consistent with
        staggered_x grid, maintains 2nd order convergence.
        """
        padded = self.vg.pad_periodic(field)
        padded = self.vg.set(padded, (__, 0, __, __), bc0)
        padded = self.vg.set(padded, (__, -1, __, __), bc1)
        return padded

    def pad_zero_flux_periodic(self, field):
        """
        Zero-flux in x-direction, periodic in y- and z-direction.

        The boundary value comes from the quadratic interpolation polynomial
        p with p'(0) = 0, p(dx) = f(dx, ...), p(2*dx) = f(2*dx, ...),
        evaluated at p(0) = 4/3 * f(dx) - 1/3 * f(2*dx).
        This is sufficiently accurate for f'(0) = 0, and even better if also
        f'''(0) = 0 (as holds for cos(k*pi*x)).
        """
        padded = self.vg.pad_periodic(field)
        fac1 = 4 / 3
        fac2 = 1 / 3
        padded = self.vg.set(padded, (__, 0, __, __), fac1 * padded[:, 1, :, :] - fac2 * padded[:, 2, :, :])
        padded = self.vg.set(padded, (__, -1, __, __), fac1 * padded[:, -2, :, :] - fac2 * padded[:, -3, :, :])
        return padded

    def pad_zero_flux(self, field):
        """Not implemented for the staggered-x grid."""
        raise NotImplementedError

    def pad_fft_periodic(self, field):
        """
        Not implemented: if a field is fully periodic it should use the
        cell center convention instead.
        """
        raise NotImplementedError

    def pad_fft_dirichlet_periodic(self, field):
        """Odd extension in x with explicit zero boundary nodes.

        Produces ``[f, 0, -mirror_x(f), 0]`` along axis 1. The mirror acts
        on the x axis (axis 1), the same axis the result is concatenated
        along; mirroring the leading axis would not give an odd extension.
        """
        bc = self.vg.lib.zeros_like(field[:, 0:1])
        mirrored = self.vg.lib.flip(field, [1])
        return self.vg.concatenate((field, bc, -mirrored, bc), 1)

    def pad_fft_zero_flux_periodic(self, field):
        """Not implemented (would pad with the x-flipped field)."""
        raise NotImplementedError

    def trim_boundary_nodes(self, field):
        """Trim boundary nodes of ``field`` for staggered grids."""
        if field.shape[1] == self.vg.shape[0]:
            return field[:, _i_, :, :]
        else:
            raise ValueError(
                f"The provided field must have the shape {self.vg.shape}."
            )

    def trim_ghost_nodes(self, field):
        """Remove the ghost layer in y and z (x holds boundary nodes).

        Raises:
            ValueError: If the trimmed field does not match ``self.vg.shape``.
        """
        if field[0, :, _i_, _i_].shape == self.vg.shape:
            return field[:, :, _i_, _i_]
        raise ValueError(
            f"The provided field has the wrong shape {tuple(field.shape)}; "
            f"after trimming it must match the grid shape {self.vg.shape}."
        )
|
evoxels/fd_stencils.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# Shorthands in slicing logic
__ = slice(None)  # all elements [:]
_i_ = slice(1, -1)  # inner elements [1:-1]

# Index tuples selecting the interior of a field padded by one layer per
# side (CENTER) and its neighbours shifted by one cell along x (LEFT/RIGHT),
# y (BOTTOM/TOP) and z (BACK/FRONT). The leading axis is left untouched.
CENTER = (__, _i_, _i_, _i_)
LEFT = (__, slice(None, -2), _i_, _i_)
RIGHT = (__, slice(2, None), _i_, _i_)
BOTTOM = (__, _i_, slice(None, -2), _i_)
TOP = (__, _i_, slice(2, None), _i_)
BACK = (__, _i_, _i_, slice(None, -2))
FRONT = (__, _i_, _i_, slice(2, None))


class FDStencils:
    """Class wrapper for finite difference stencils.

    Is inherited by the VoxelGrid to apply stencils to backend arrays.
    The host class must provide ``div_dx`` (expected to hold 1/dx per
    axis), ``div_dx2`` (expected to hold 1/dx**2 per axis), ``lib`` (the
    array namespace) and ``set(array, index, value)`` (functional index
    assignment returning the updated array).
    """

    def to_x_face(self, field):
        """Interpolate to face position staggered in x"""
        return 0.5 * (field[:, 1:, :, :] + field[:, :-1, :, :])

    def to_y_face(self, field):
        """Interpolate to face position staggered in y"""
        return 0.5 * (field[:, :, 1:, :] + field[:, :, :-1, :])

    def to_z_face(self, field):
        """Interpolate to face position staggered in z"""
        return 0.5 * (field[:, :, :, 1:] + field[:, :, :, :-1])

    def grad_x_face(self, field):
        """Gradient at face position staggered in x"""
        return (field[:, 1:, :, :] - field[:, :-1, :, :]) * self.div_dx[0]

    def grad_y_face(self, field):
        """Gradient at face position staggered in y"""
        return (field[:, :, 1:, :] - field[:, :, :-1, :]) * self.div_dx[1]

    def grad_z_face(self, field):
        """Gradient at face position staggered in z"""
        return (field[:, :, :, 1:] - field[:, :, :, :-1]) * self.div_dx[2]

    def grad_x_center(self, field):
        """Central-difference gradient in x at interior cell centers"""
        return 0.5 * (field[RIGHT] - field[LEFT]) * self.div_dx[0]

    def grad_y_center(self, field):
        """Central-difference gradient in y at interior cell centers"""
        return 0.5 * (field[TOP] - field[BOTTOM]) * self.div_dx[1]

    def grad_z_center(self, field):
        """Central-difference gradient in z at interior cell centers"""
        return 0.5 * (field[FRONT] - field[BACK]) * self.div_dx[2]

    def gradient_norm_squared(self, field):
        """Gradient norm squared at interior cell centers"""
        return self.grad_x_center(field)**2 + \
            self.grad_y_center(field)**2 + \
            self.grad_z_center(field)**2

    def laplace(self, field):
        r"""Calculate laplace based on compact 2nd order stencil.

        Laplace given as $\nabla\cdot(\nabla u)$ which in 3D is given by
        $\partial^2 u/\partial x^2 + \partial^2 u/\partial y^2 + \partial^2 u/\partial z^2$.
        The input must carry one ghost layer per side in x, y and z; the
        returned array covers only the interior, i.e. it is smaller than the
        input by 2 in each spatial dimension.
        """
        # Manual indexing is ~10x faster than conv3d with laplace kernel in torch
        laplace = \
            (field[RIGHT] + field[LEFT]) * self.div_dx2[0] + \
            (field[TOP] + field[BOTTOM]) * self.div_dx2[1] + \
            (field[FRONT] + field[BACK]) * self.div_dx2[2] - \
            2 * field[CENTER] * self.lib.sum(self.div_dx2)
        return laplace

    def normal_laplace(self, field):
        r"""Calculate the normal component of the laplacian.

        This is identical to the full laplacian minus curvature. It is
        defined as $\partial^2_n u = n\cdot(\nabla\nabla u)\,n$ where $n$
        denotes the surface normal, which in the context of phasefield
        models is $\frac{\nabla u}{|\nabla u|}$. Equivalently it is
        $\nabla u\cdot(\nabla\nabla u)\,\nabla u / |\nabla u|^2$.
        The calculation is based on a compact 2nd order stencil and returns
        the interior only (same shape as :meth:`laplace`). Where
        $|\nabla u|^2 \le 10^{-7}$ the denominator is replaced by 1.
        """
        # Each central gradient is needed several times below; compute once.
        gx = self.grad_x_center(field)
        gy = self.grad_y_center(field)
        gz = self.grad_z_center(field)
        center = field[CENTER]
        n_laplace = (
            # Diagonal Hessian terms g_i^2 * d^2u/dx_i^2
            gx**2 * (field[RIGHT] - 2 * center + field[LEFT]) * self.div_dx2[0]
            + gy**2 * (field[TOP] - 2 * center + field[BOTTOM]) * self.div_dx2[1]
            + gz**2 * (field[FRONT] - 2 * center + field[BACK]) * self.div_dx2[2]
            # Mixed terms 2 * g_i * g_j * d^2u/dx_i dx_j with the 4-point
            # cross stencil (its 1/4 factor folded into 0.5)
            + 0.5 * gx * gy
            * (field[:, 2:, 2:, 1:-1] + field[:, :-2, :-2, 1:-1]
               - field[:, :-2, 2:, 1:-1] - field[:, 2:, :-2, 1:-1])
            * self.div_dx[0] * self.div_dx[1]
            + 0.5 * gx * gz
            * (field[:, 2:, 1:-1, 2:] + field[:, :-2, 1:-1, :-2]
               - field[:, :-2, 1:-1, 2:] - field[:, 2:, 1:-1, :-2])
            * self.div_dx[0] * self.div_dx[2]
            + 0.5 * gy * gz
            * (field[:, 1:-1, 2:, 2:] + field[:, 1:-1, :-2, :-2]
               - field[:, 1:-1, :-2, 2:] - field[:, 1:-1, 2:, :-2])
            * self.div_dx[1] * self.div_dx[2]
        )
        norm2 = gx**2 + gy**2 + gz**2
        # Avoid division by ~0 in the bulk where the gradient vanishes.
        bulk = self.lib.where(norm2 <= 1e-7)
        norm2 = self.set(norm2, bulk, 1.0)
        return n_laplace / norm2
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
# Optional JAX support. Without JAX, minimal NumPy-backed stand-ins are
# installed under the names ``jax`` and ``jnp`` so the approximators below
# still run (without JIT compilation).
try:
    import jax
    import jax.numpy as jnp
    _HAS_JAX = True
except ImportError:
    _HAS_JAX = False

    class DummyJax:
        # Stand-in for the ``jax`` module: ``jit`` returns the function as-is.
        @staticmethod
        def jit(f):
            return f

    class DummyJnp:
        # Stand-in for ``jax.numpy``; provides only ``ones_like`` and ``exp``,
        # the two jnp functions used in this module, via NumPy.
        @staticmethod
        def ones_like(x):
            return np.ones_like(x)

        @staticmethod
        def exp(x):
            return np.exp(x)

    jax = DummyJax()
    jnp = DummyJnp()
|
|
23
|
+
|
|
24
|
+
import dataclasses
|
|
25
|
+
|
|
26
|
+
@dataclasses.dataclass
class DiffusionLegendrePolynomials:
    """Strictly positive function ``exp(sum_i params[i] * P_i(2*x - 1))``.

    ``P_i`` are Legendre polynomials up to ``max_degree``. Inputs are
    presumably in [0, 1], which the affine map sends to [-1, 1].
    """

    max_degree: int

    def __post_init__(self):
        # Exponentiated Legendre expansion on [-1, 1].
        self.leg_poly = ExpLegendrePolynomials(self.max_degree)

    def __call__(self, params, inputs):
        # Rescale [0, 1] -> [-1, 1] before evaluating the expansion.
        mapped = 2.0 * inputs - 1.0
        return self.leg_poly(params, mapped)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclasses.dataclass
class ChemicalPotentialLegendrePolynomials:
    """Legendre expansion ``sum_i params[i] * P_i(2*x - 1)``.

    ``P_i`` are Legendre polynomials up to ``max_degree``. Inputs are
    presumably in [0, 1], which the affine map sends to [-1, 1].
    """

    max_degree: int

    def __post_init__(self):
        # Plain (not exponentiated) Legendre expansion on [-1, 1].
        self.leg_poly = LegendrePolynomialRecurrence(self.max_degree)

    def __call__(self, params, inputs):
        # Rescale [0, 1] -> [-1, 1] before evaluating the expansion.
        shifted = 2.0 * inputs - 1.0
        return self.leg_poly(params, shifted)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclasses.dataclass
class ExpLegendrePolynomials:
    """Exponential of a Legendre expansion: ``exp(sum_i params[i] * P_i(x))``."""

    max_degree: int

    def __post_init__(self):
        poly = LegendrePolynomialRecurrence(self.max_degree)

        def _exp_of_poly(p, x):
            # exp keeps the result strictly positive.
            return jnp.exp(poly(p, x))

        self.func = jax.jit(_exp_of_poly)

    def __call__(self, params, inputs):
        return self.func(params, inputs)
|
|
58
|
+
|
|
59
|
+
@dataclasses.dataclass
class LegendrePolynomialRecurrence:
    """Evaluate ``sum_{i=0}^{max_degree} params[i] * P_i(x)``.

    ``P_i`` are Legendre polynomials generated with the three-term
    recurrence ``(i + 1) P_{i+1} = (2i + 1) x P_i - i P_{i-1}``.
    ``params`` must hold at least ``max_degree + 1`` coefficients.
    """

    max_degree: int

    def __post_init__(self):
        # Create a JIT-compiled function that computes the Legendre polynomial sum
        def compute_polynomial_sum(params, x):
            # Single pass of the recurrence: each P_i is built from the two
            # previous ones and accumulated immediately. This is O(max_degree)
            # instead of recomputing every P_i from scratch (O(max_degree^2)).
            p_prev = self.T0(x)  # P_0
            result = params[0] * p_prev
            if self.max_degree >= 1:
                p_curr = x  # P_1
                result += params[1] * p_curr
                for i in range(1, self.max_degree):
                    p_next = ((2 * i + 1) * x * p_curr - i * p_prev) / (i + 1)
                    p_prev = p_curr
                    p_curr = p_next
                    result += params[i + 1] * p_curr
            return result

        self.func = jax.jit(compute_polynomial_sum)

    def __call__(self, params, inputs):
        return self.func(params, inputs)

    def T0(self, x):
        """Zeroth Legendre polynomial: ones with the shape of ``x``."""
        return 1.0 * jnp.ones_like(x)

    def _compute_legendre(self, n, x):
        """Compute the nth Legendre polynomial using the three-term recurrence relation."""
        if n == 0:
            return self.T0(x)
        elif n == 1:
            return x

        # Initialize P_0 and P_1
        p_prev = self.T0(x)  # P_0
        p_curr = x  # P_1

        # Compute P_n using the recurrence relation
        for i in range(1, n):
            p_next = ((2 * i + 1) * x * p_curr - i * p_prev) / (i + 1)
            p_prev = p_curr
            p_curr = p_next

        return p_curr
|
evoxels/inversion.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from timeit import default_timer as timer
|
|
4
|
+
from typing import Any, Type, Optional
|
|
5
|
+
from evoxels.timesteppers import PseudoSpectralIMEX_dfx
|
|
6
|
+
|
|
7
|
+
# Optional JAX stack used for differentiable solves and parameter fitting.
# If any of these imports fails, all four names are rebound to None and
# DIFFRAX_AVAILABLE is False; InversionModel.__post_init__ checks the flag.
try:
    import diffrax as dfx
    import optimistix as optx
    import jax.numpy as jnp
    import jax
except ImportError:
    dfx = None
    optx = None
    jnp = None
    jax = None

DIFFRAX_AVAILABLE = dfx is not None
|
|
19
|
+
|
|
20
|
+
@dataclass
class InversionModel:
    """Inverse modeling using JAX and diffrax.

    This small helper class wraps the differentiable solver implementation and
    provides utilities to fit material parameters via gradient based
    optimization. It is intentionally lightweight so that new users can easily
    follow the individual steps: solving the PDE, computing residuals and
    running a least squares optimiser.

    Parameters named in ``pos_params`` are optimised in log-space so they
    stay positive; :meth:`solve` maps them back with ``exp``.
    """
    vf: Any  # VoxelFields object
    problem_cls: Type
    pos_params: Optional[list[str]] = None
    problem_kwargs: Optional[dict[str, Any]] = None
    backend: str = 'jax'

    def __post_init__(self):
        """Initialize backend specific components.

        Raises:
            ImportError: If the optional JAX dependencies are not installed.
            ValueError: If ``backend`` is not ``'jax'``.
        """
        self.problem_kwargs = self.problem_kwargs or {}
        # Every method needs jax/diffrax/optimistix: fail early with a clear
        # message before a backend-specific import fails less helpfully.
        if not DIFFRAX_AVAILABLE:
            raise ImportError(
                "InversionModel requires the optional JAX"
                " dependencies (jax, diffrax, optimistix)."
            )
        if self.backend != 'jax':
            raise ValueError(
                f"Unsupported backend '{self.backend}'; InversionModel"
                " only supports 'jax'."
            )
        from evoxels.voxelgrid import VoxelGridJax
        from .profiler import JAXMemoryProfiler
        self.vg = VoxelGridJax(self.vf.grid_info(), precision=self.vf.precision)
        self.profiler = JAXMemoryProfiler()

    def solve(self, parameters, y0, saveat, adjoint=None, dt0=0.1):
        """Integrate the problem defined by ``problem_cls`` for a parameter set.

        Args:
            parameters (dict): Dictionary containing the material parameters to
                solve with (entries in ``pos_params`` given in log-space).
            y0 (array-like): Initial concentration field.
            saveat (:class:`diffrax.SaveAt`): Time points at which the solution
                should be stored.
            adjoint: Differentiation mode used by :func:`diffrax.diffeqsolve`.
                ``None`` (default) selects :class:`diffrax.ForwardMode`.
            dt0 (float): Initial step size for the time integrator.

        Returns:
            jax.Array: Array of saved concentration fields with shape
            ``(len(saveat.subs.ts), Nx, Ny, Nz)``.
        """
        # Resolve the default lazily: evaluating dfx.ForwardMode() in the
        # signature breaks importing this module when diffrax is missing.
        if adjoint is None:
            adjoint = dfx.ForwardMode()
        u = self.vg.init_scalar_field(y0)
        u = self.vg.bc.trim_boundary_nodes(u)
        if self.pos_params:
            # Positive parameters arrive in log-space; map them back.
            parameters = {
                k: jnp.exp(v) if k in self.pos_params else v
                for k, v in parameters.items()
            }
        problem = self.problem_cls(self.vg, **self.problem_kwargs, **parameters)
        solver = PseudoSpectralIMEX_dfx(problem.fourier_symbol)

        solution = dfx.diffeqsolve(
            dfx.ODETerm(lambda t, y, args: problem.rhs(y, t)),
            solver,
            t0=saveat.subs.ts[0],
            t1=saveat.subs.ts[-1],
            dt0=dt0,
            y0=u,
            saveat=saveat,
            max_steps=100000,
            # With throw=False a failed solve is reported via
            # solution.result instead of raising.
            throw=False,
            adjoint=adjoint,
        )
        # Take index 0 of the axis after time so the time axis takes the
        # place of the leading axis expected by pad_bc / trim_ghost_nodes.
        # NOTE(review): assumes a single-channel state — confirm.
        padded = problem.pad_bc(solution.ys[:, 0])
        out = self.vg.bc.trim_ghost_nodes(padded)
        return out

    def forward_solve(self, parameters, fieldname, saveat, dt0=0.1, verbose=True):
        """Run a forward simulation from ``self.vf.fields[fieldname]``.

        The last saved state overwrites ``self.vf.fields[fieldname]``
        (converted with ``self.vg.to_numpy``).

        Args:
            parameters (dict): Material parameters in physical (not log) space.
            fieldname (str): Field in ``self.vf`` used as initial condition
                and updated with the final state.
            saveat (:class:`diffrax.SaveAt`): Time points to store.
            dt0 (float): Initial step size for the time integrator.
            verbose (bool): If ``True``, print timing and memory statistics.

        Returns:
            jax.Array: All saved states as returned by :meth:`solve`.
        """
        start = timer()
        u0 = self.vf.fields[fieldname]
        if self.pos_params:
            # solve() exponentiates positive parameters, so pass their logs.
            parameters = {
                k: jnp.log(v) if k in self.pos_params else v
                for k, v in parameters.items()
            }
        sol = self.solve(parameters, u0, saveat, dt0=dt0)
        end = timer()

        self.vf.fields[fieldname] = self.vg.to_numpy(sol[-1])
        if verbose:
            ts = saveat.subs.ts
            # Number of steps covers [t0, t1], not [0, t1].
            iterations = int((ts[-1] - ts[0]) // dt0)
            self.profiler.print_memory_stats(start, end, iterations)

        return sol

    def residuals(self, parameters, y0s__values__saveat, adjoint=None):
        """Calculate residuals between measured and simulated states.

        Args:
            parameters (dict): Current estimate of the model parameters.
            y0s__values__saveat (tuple): Tuple ``(y0s, values, saveat)`` where
                ``y0s`` contains the initial states for each sequence, ``values``
                contains the observed states and ``saveat`` specifies the time
                points of these observations.
            adjoint: Differentiation mode for :func:`solve` (``None`` selects
                :class:`diffrax.ForwardMode`).

        Returns:
            jax.Array: Array of residuals with shape matching ``values``.
        """
        y0s, values, saveat = y0s__values__saveat
        solve_ = partial(self.solve, adjoint=adjoint)
        batch_solve = jax.vmap(solve_, in_axes=(None, 0, None))
        pred_values = batch_solve(parameters, y0s, saveat)
        # The first saved state is the initial condition; compare the rest.
        residuals = values - pred_values[:, 1:]
        return residuals

    def train(
        self,
        initial_parameters,
        data,
        inds,
        adjoint=None,
        rtol=1e-6,
        atol=1e-6,
        verbose=True,
        max_steps=1000,
    ):
        """Fit ``parameters`` so that the model matches observed data.

        This method assembles the observed sequences into a format suitable for
        :func:`optimistix.least_squares` and then runs a Levenberg--Marquardt
        optimisation to minimise the residuals returned by :func:`residuals`.

        Args:
            initial_parameters (dict): Initial guess for the parameters to be
                optimised. Not modified.
            data (dict): Dictionary containing ``"ts"`` (time stamps) and
                ``"ys"`` (concentration fields) as produced by :func:`solve`.
            inds (list[list[int]]): For each sequence, the indices in ``data``
                that should be used for training. All sequences must have the
                same spacing.
            adjoint: Differentiation mode used when evaluating the residuals
                (``None`` selects :class:`diffrax.ForwardMode`).
            rtol, atol (float): Tolerances for the optimiser.
            verbose (bool): If ``True``, prints optimisation progress.
            max_steps (int): Maximum number of optimisation steps.

        Returns:
            dict: The fitted parameters (``pos_params`` entries mapped back
            from log-space), i.e. the ``value`` of the optimistix solution.

        Raises:
            ValueError: If ``inds`` sequences are too short or inconsistent,
                or a positive parameter is not positive.
        """
        # Get length of first sequence to use as reference
        ref_len = len(inds[0])
        if ref_len < 2:
            raise ValueError("Each sequence in inds must have at least 2 elements")

        # Get reference spacing from first sequence
        ref_spacing = [inds[0][i + 1] - inds[0][i] for i in range(ref_len - 1)]

        # Validate all other sequences
        for i, sequence in enumerate(inds):
            if len(sequence) != ref_len:
                raise ValueError(
                    f"Sequence {i} has different length than first sequence"
                )

            # Check spacing
            spacing = [sequence[j + 1] - sequence[j] for j in range(len(sequence) - 1)]
            if spacing != ref_spacing:
                raise ValueError(
                    f"Sequence {i} has different spacing than first sequence"
                )

        # TODO: make data a voxelgrid or voxelfield object
        y0s = jnp.array([data["ys"][ind[0]] for ind in inds])
        values = jnp.array(
            [
                jnp.array([data["ys"][ind[i]] for i in range(1, len(ind))])
                for ind in inds
            ]
        )
        # Save times relative to the start of the (first) sequence.
        saveat = dfx.SaveAt(
            ts=jnp.array(
                [0.0]
                + [
                    data["ts"][inds[0][i]] - data["ts"][inds[0][0]]
                    for i in range(1, len(inds[0]))
                ]
            )
        )

        args = (y0s, values, saveat)
        residuals_ = partial(self.residuals, adjoint=adjoint)

        if self.pos_params:
            # Work on a copy so the caller's dict is not modified in place.
            initial_parameters = dict(initial_parameters)
            # Ensure parameters are positive and take log
            for key in self.pos_params:
                if initial_parameters[key] <= 0:
                    raise ValueError(f"Parameter {key} must be positive")
                initial_parameters[key] = jnp.log(initial_parameters[key])

        # optimistix expects a frozenset; frozenset(None) would raise
        # TypeError, so an empty set disables verbose output.
        verbose_fields = (
            frozenset({"step", "accepted", "loss", "step_size"})
            if verbose
            else frozenset()
        )
        solver = optx.LevenbergMarquardt(
            rtol=rtol,
            atol=atol,
            verbose=verbose_fields,
        )

        sol = optx.least_squares(
            residuals_,
            solver,
            initial_parameters,
            args=args,
            max_steps=max_steps,
            throw=False,
        )

        res = sol.value

        if self.pos_params:
            # Map log-space parameters back to positive values.
            res = {
                k: jnp.exp(v) if k in self.pos_params else v
                for k, v in res.items()
            }

        return res
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Exports for the precompiled solver package."""
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from ..problem_definition import AllenCahnEquation
|
|
2
|
+
from ..solvers import TimeDependentSolver
|
|
3
|
+
from ..timesteppers import ForwardEuler
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
def run_allen_cahn_solver(
    voxelfields,
    fieldnames: str | list[str],
    backend: str,
    jit: bool = True,
    device: str = "cuda",
    time_increment: float = 0.1,
    frames: int = 10,
    max_iters: int = 100,
    eps: float = 2.0,
    gab: float = 1.0,
    M: float = 1.0,
    force: float = 0.0,
    curvature: float = 0.01,
    potential: Callable | None = None,
    vtk_out: bool = False,
    verbose: bool = True,
    plot_bounds=None,
):
    """
    Runs the Allen-Cahn solver with a predefined problem and timestepper.

    The problem is :class:`AllenCahnEquation`, advanced with the
    :class:`ForwardEuler` timestepper through :class:`TimeDependentSolver`.
    ``eps``, ``gab``, ``M``, ``force``, ``curvature`` and ``potential`` are
    forwarded unchanged as problem keyword arguments. ``device`` goes to the
    solver constructor; the remaining options go to ``solver.solve``.
    """
    solver = TimeDependentSolver(
        voxelfields,
        fieldnames,
        backend,
        problem_cls=AllenCahnEquation,
        timestepper_cls=ForwardEuler,
        device=device,
    )
    solver.solve(
        time_increment=time_increment,
        frames=frames,
        max_iters=max_iters,
        problem_kwargs={"eps": eps,
                        "gab": gab,
                        "M": M,
                        "force": force,
                        "curvature": curvature,
                        "potential": potential},
        jit=jit,
        verbose=verbose,
        vtk_out=vtk_out,
        plot_bounds=plot_bounds,
    )
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from ..problem_definition import PeriodicCahnHilliard
|
|
2
|
+
from ..solvers import TimeDependentSolver
|
|
3
|
+
from ..timesteppers import PseudoSpectralIMEX
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
def run_cahn_hilliard_solver(
    voxelfields,
    fieldnames: str | list[str],
    backend: str,
    jit: bool = True,
    device: str = "cuda",
    time_increment: float = 0.1,
    frames: int = 10,
    max_iters: int = 100,
    eps: float = 3.0,
    diffusivity: float = 1.0,
    mu_hom: Callable | None = None,
    vtk_out: bool = False,
    verbose: bool = True,
    plot_bounds=None,
):
    """
    Runs the Cahn-Hilliard solver with a predefined problem and timestepper.

    Uses :class:`PeriodicCahnHilliard` advanced by the
    :class:`PseudoSpectralIMEX` timestepper through
    :class:`TimeDependentSolver`. ``eps``, ``diffusivity`` (passed on as
    ``D``) and ``mu_hom`` become problem keyword arguments; ``device`` goes
    to the solver constructor and all other options to ``solver.solve``.
    """
    problem_settings = {"eps": eps, "D": diffusivity, "mu_hom": mu_hom}
    run_options = dict(
        time_increment=time_increment,
        frames=frames,
        max_iters=max_iters,
        problem_kwargs=problem_settings,
        jit=jit,
        verbose=verbose,
        vtk_out=vtk_out,
        plot_bounds=plot_bounds,
    )
    ch_solver = TimeDependentSolver(
        voxelfields, fieldnames, backend,
        problem_cls=PeriodicCahnHilliard,
        timestepper_cls=PseudoSpectralIMEX,
        device=device,
    )
    ch_solver.solve(**run_options)
|