evoxels 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evoxels/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ """Public API for the evoxels package."""
2
+
3
+ from .voxelfields import VoxelFields
4
+ from .precompiled_solvers.cahn_hilliard import (run_cahn_hilliard_solver)
5
+ from .precompiled_solvers.allen_cahn import (run_allen_cahn_solver)
6
+ from .inversion import InversionModel
7
+
8
# Public names re-exported at package level (consumed by ``from evoxels import *``).
__all__ = [
    "VoxelFields",
    "run_cahn_hilliard_solver",
    "run_allen_cahn_solver",
    "InversionModel",
]
@@ -0,0 +1,138 @@
1
# Slicing shorthands shared by the padding helpers:
#   __  keeps every element along an axis            -> [:]
#   _i_ drops the first and last element of an axis  -> [1:-1]
__, _i_ = slice(None), slice(1, -1)
4
+
5
class CellCenteredBCs:
    """Ghost-node padding and trimming for fields on a cell-centred grid.

    Fields are 4D arrays indexed as ``field[c, x, y, z]``: one leading
    non-spatial axis followed by the three spatial axes (``trim_ghost_nodes``
    compares ``field[0, ...]`` against the 3D ``vg.shape``). ``vg`` is the
    voxel-grid backend wrapper; this class uses its ``pad_periodic``,
    ``pad_zeros``, ``set``, ``concatenate``, ``shape`` and ``lib`` members.
    """

    def __init__(self, vg):
        self.vg = vg

    def pad_periodic(self, field):
        """Pad with periodic ghost nodes in all directions.

        Consistent with the cell-centred grid.
        """
        return self.vg.pad_periodic(field)

    def pad_dirichlet_periodic(self, field, bc0=0, bc1=0):
        """Dirichlet conditions in x, periodic in y and z.

        Ghost values are set to ``2*bc - u_inner`` so that the mean of the
        ghost and the first interior cell, i.e. the value on the boundary
        face, equals ``bc0`` at the lower and ``bc1`` at the upper x boundary.
        Consistent with the cell-centred grid, but second-order convergence
        is lost.
        """
        padded = self.vg.pad_periodic(field)
        padded = self.vg.set(padded, (__, 0, __, __), 2.0 * bc0 - padded[:, 1, :, :])
        padded = self.vg.set(padded, (__, -1, __, __), 2.0 * bc1 - padded[:, -2, :, :])
        return padded

    def pad_zero_flux_periodic(self, field):
        """Zero-flux conditions in x, periodic in y and z.

        The x ghost layers copy the adjacent interior layer, so the
        face-centred x-gradient vanishes on both x boundaries.
        """
        padded = self.vg.pad_periodic(field)
        padded = self.vg.set(padded, (__, 0, __, __), padded[:, 1, :, :])
        padded = self.vg.set(padded, (__, -1, __, __), padded[:, -2, :, :])
        return padded

    def pad_zero_flux(self, field):
        """Zero-flux conditions in all three directions.

        Starts from zero padding and then mirrors the adjacent layer into
        every ghost layer, axis by axis (x, then y, then z), so edges and
        corners also receive copied values.
        """
        padded = self.vg.pad_zeros(field)
        padded = self.vg.set(padded, (__, 0, __, __), padded[:, 1, :, :])
        padded = self.vg.set(padded, (__, -1, __, __), padded[:, -2, :, :])
        padded = self.vg.set(padded, (__, __, 0, __), padded[:, :, 1, :])
        padded = self.vg.set(padded, (__, __, -1, __), padded[:, :, -2, :])
        padded = self.vg.set(padded, (__, __, __, 0), padded[:, :, :, 1])
        padded = self.vg.set(padded, (__, __, __, -1), padded[:, :, :, -2])
        return padded

    def pad_fft_periodic(self, field):
        """Periodic field needs no fft padding."""
        return field

    def pad_fft_dirichlet_periodic(self, field):
        """Odd extension in x: append the negated, x-reversed field.

        The result has twice the x extent. The reversal must be taken along
        the same axis (1, i.e. x) that is concatenated; axis 0 is the
        non-spatial leading axis.
        """
        return self.vg.concatenate((field, -self.vg.lib.flip(field, [1])), 1)

    def pad_fft_zero_flux_periodic(self, field):
        """Even extension in x: append the x-reversed field.

        The result has twice the x extent; the reversal is along axis 1 (x),
        the same axis that is concatenated.
        """
        return self.vg.concatenate((field, self.vg.lib.flip(field, [1])), 1)

    def trim_boundary_nodes(self, field):
        """Cell-centred fields carry no boundary nodes; return ``field`` as is."""
        return field

    def trim_ghost_nodes(self, field):
        """Strip one ghost layer on every side of each spatial axis.

        Raises:
            ValueError: If the interior of ``field`` does not match ``vg.shape``.
        """
        if field[0, _i_, _i_, _i_].shape == self.vg.shape:
            return field[:, _i_, _i_, _i_]
        raise ValueError(
            f"The provided field has the wrong shape {tuple(field.shape)}; "
            f"its interior must match the grid shape {self.vg.shape}."
        )
65
+
66
+
67
class StaggeredXBCs:
    """Ghost-node padding and trimming for a grid staggered in x.

    Fields are 4D arrays indexed as ``field[c, x, y, z]`` (one leading
    non-spatial axis, then three spatial axes). In x the outermost padded
    layers hold boundary values directly; y and z use periodic ghost layers.
    ``vg`` is the voxel-grid backend wrapper; this class uses its
    ``pad_periodic``, ``set``, ``concatenate``, ``shape`` and ``lib`` members.
    """

    def __init__(self, vg):
        self.vg = vg

    def pad_periodic_BC_staggered_x(self, field):
        """Not supported.

        A fully periodic field should use the cell-centred convention instead.
        """
        raise NotImplementedError

    def pad_dirichlet_periodic(self, field, bc0=0, bc1=0):
        """Dirichlet conditions in x, periodic in y and z.

        The first and last x layers are set to ``bc0`` and ``bc1`` directly,
        which is consistent with the staggered-x grid and keeps second-order
        convergence.
        """
        padded = self.vg.pad_periodic(field)
        padded = self.vg.set(padded, (__, 0, __, __), bc0)
        padded = self.vg.set(padded, (__, -1, __, __), bc1)
        return padded

    def pad_zero_flux_periodic(self, field):
        """Zero-flux conditions in x, periodic in y and z.

        The boundary value is ``p(0)`` of the quadratic interpolant ``p`` with
        ``p'(0) = 0``, ``p(dx) = f(dx)`` and ``p(2*dx) = f(2*dx)``, which gives
        ``p(0) = 4/3 * f(dx) - 1/3 * f(2*dx)`` (mirrored at the upper end).
        Accuracy is best when additionally ``f'''(0) = 0``, as holds for
        ``cos(k*pi*x)``.
        """
        padded = self.vg.pad_periodic(field)
        fac1 = 4 / 3
        fac2 = 1 / 3
        padded = self.vg.set(
            padded, (__, 0, __, __), fac1 * padded[:, 1, :, :] - fac2 * padded[:, 2, :, :]
        )
        padded = self.vg.set(
            padded, (__, -1, __, __), fac1 * padded[:, -2, :, :] - fac2 * padded[:, -3, :, :]
        )
        return padded

    def pad_zero_flux(self, field):
        """Not implemented for the staggered-x grid."""
        raise NotImplementedError

    def pad_fft_periodic(self, field):
        """Not supported.

        A fully periodic field should use the cell-centred convention instead.
        """
        raise NotImplementedError

    def pad_fft_dirichlet_periodic(self, field):
        """Odd extension in x with explicit zero boundary layers.

        Builds ``[f, 0, -reverse_x(f), 0]`` along axis 1 (x). The reversal
        must be taken along the same axis that is concatenated; axis 0 is the
        non-spatial leading axis.
        """
        bc = self.vg.lib.zeros_like(field[:, 0:1])
        return self.vg.concatenate((field, bc, -self.vg.lib.flip(field, [1]), bc), 1)

    def pad_fft_zero_flux_periodic(self, field):
        """Not implemented yet for the staggered-x grid."""
        raise NotImplementedError

    def trim_boundary_nodes(self, field):
        """Drop the first and last x layer (the boundary nodes).

        Raises:
            ValueError: If the x extent of ``field`` differs from ``vg.shape[0]``.
        """
        if field.shape[1] == self.vg.shape[0]:
            return field[:, _i_, :, :]
        raise ValueError(
            f"The provided field must have the shape {self.vg.shape}."
        )

    def trim_ghost_nodes(self, field):
        """Strip the periodic ghost layers in y and z (x is kept whole).

        Raises:
            ValueError: If the trimmed field does not match ``vg.shape``.
        """
        if field[0, :, _i_, _i_].shape == self.vg.shape:
            return field[:, :, _i_, _i_]
        raise ValueError(
            f"The provided field has the wrong shape {tuple(field.shape)}; "
            f"after trimming y/z ghost layers it must match {self.vg.shape}."
        )
evoxels/fd_stencils.py ADDED
@@ -0,0 +1,103 @@
1
# Shorthands in slicing logic
__ = slice(None)  # all elements [:]
_i_ = slice(1, -1)  # inner elements [1:-1]

# Index tuples for a ghost-padded 4D field ``field[c, x, y, z]``: the interior
# (CENTER) and the interior shifted by -1/+1 along x (LEFT/RIGHT),
# y (BOTTOM/TOP) and z (BACK/FRONT).
CENTER = (__, _i_, _i_, _i_)
LEFT = (__, slice(None, -2), _i_, _i_)
RIGHT = (__, slice(2, None), _i_, _i_)
BOTTOM = (__, _i_, slice(None, -2), _i_)
TOP = (__, _i_, slice(2, None), _i_)
BACK = (__, _i_, _i_, slice(None, -2))
FRONT = (__, _i_, _i_, slice(2, None))


class FDStencils:
    """Class wrapper for finite difference stencils.

    Is inherited by the VoxelGrid to apply stencils to backend arrays. The
    inheriting class must provide ``div_dx`` and ``div_dx2`` (per-axis
    multipliers; presumably ``1/dx`` and ``1/dx**2``) and the array library
    ``lib``. Fields are 4D, ``field[c, x, y, z]``.
    """

    def to_x_face(self, field):
        """Interpolate to face position staggered in x (x extent shrinks by 1)."""
        return 0.5 * (field[:, 1:, :, :] + field[:, :-1, :, :])

    def to_y_face(self, field):
        """Interpolate to face position staggered in y (y extent shrinks by 1)."""
        return 0.5 * (field[:, :, 1:, :] + field[:, :, :-1, :])

    def to_z_face(self, field):
        """Interpolate to face position staggered in z (z extent shrinks by 1)."""
        return 0.5 * (field[:, :, :, 1:] + field[:, :, :, :-1])

    def grad_x_face(self, field):
        """Gradient at face position staggered in x."""
        return (field[:, 1:, :, :] - field[:, :-1, :, :]) * self.div_dx[0]

    def grad_y_face(self, field):
        """Gradient at face position staggered in y."""
        return (field[:, :, 1:, :] - field[:, :, :-1, :]) * self.div_dx[1]

    def grad_z_face(self, field):
        """Gradient at face position staggered in z."""
        return (field[:, :, :, 1:] - field[:, :, :, :-1]) * self.div_dx[2]

    def grad_x_center(self, field):
        """Central-difference gradient in x at the interior cell centres."""
        return 0.5 * (field[RIGHT] - field[LEFT]) * self.div_dx[0]

    def grad_y_center(self, field):
        """Central-difference gradient in y at the interior cell centres."""
        return 0.5 * (field[TOP] - field[BOTTOM]) * self.div_dx[1]

    def grad_z_center(self, field):
        """Central-difference gradient in z at the interior cell centres."""
        return 0.5 * (field[FRONT] - field[BACK]) * self.div_dx[2]

    def gradient_norm_squared(self, field):
        """Gradient norm squared at the interior cell centres."""
        return self.grad_x_center(field)**2 + \
            self.grad_y_center(field)**2 + \
            self.grad_z_center(field)**2

    def laplace(self, field):
        r"""Calculate the Laplacian with the compact 2nd-order 7-point stencil.

        The Laplacian is $\nabla\cdot(\nabla u)$, which in 3D is
        $\partial^2 u/\partial x^2 + \partial^2 u/\partial y^2 + \partial^2 u/\partial z^2$.
        The input must include one ghost layer per side; the result covers
        only the interior, i.e. each spatial extent is reduced by 2.
        """
        # Manual indexing is ~10x faster than conv3d with laplace kernel in torch
        laplace = \
            (field[RIGHT] + field[LEFT]) * self.div_dx2[0] + \
            (field[TOP] + field[BOTTOM]) * self.div_dx2[1] + \
            (field[FRONT] + field[BACK]) * self.div_dx2[2] - \
            2 * field[CENTER] * self.lib.sum(self.div_dx2)
        return laplace

    def normal_laplace(self, field):
        r"""Calculate the normal component of the Laplacian.

        This is the full Laplacian minus the curvature contribution:
        $\partial^2_n u = n^\top (\nabla\nabla u)\, n$ with the surface normal
        $n = \frac{\nabla u}{|\nabla u|}$ as used in phase-field models.
        Second derivatives use compact 2nd-order central stencils. Where
        $|\nabla u|^2 \le 10^{-7}$ the denominator is replaced by 1 to avoid
        division by zero. The result covers only the interior cells.
        """
        # Evaluate each gradient once instead of once per term.
        gx = self.grad_x_center(field)
        gy = self.grad_y_center(field)
        gz = self.grad_z_center(field)
        center = field[CENTER]
        # Mixed second-difference stencils (scaled by 1/4 of the Hessian entry).
        dxy = (field[:, 2:, 2:, 1:-1] + field[:, :-2, :-2, 1:-1]
               - field[:, :-2, 2:, 1:-1] - field[:, 2:, :-2, 1:-1])
        dxz = (field[:, 2:, 1:-1, 2:] + field[:, :-2, 1:-1, :-2]
               - field[:, :-2, 1:-1, 2:] - field[:, 2:, 1:-1, :-2])
        dyz = (field[:, 1:-1, 2:, 2:] + field[:, 1:-1, :-2, :-2]
               - field[:, 1:-1, :-2, 2:] - field[:, 1:-1, 2:, :-2])
        n_laplace = \
            gx**2 * (field[RIGHT] - 2 * center + field[LEFT]) * self.div_dx2[0] + \
            gy**2 * (field[TOP] - 2 * center + field[BOTTOM]) * self.div_dx2[1] + \
            gz**2 * (field[FRONT] - 2 * center + field[BACK]) * self.div_dx2[2] + \
            0.5 * gx * gy * dxy * self.div_dx[0] * self.div_dx[1] + \
            0.5 * gx * gz * dxz * self.div_dx[0] * self.div_dx[2] + \
            0.5 * gy * gz * dyz * self.div_dx[1] * self.div_dx[2]
        norm2 = gx**2 + gy**2 + gz**2
        # Three-argument where keeps shapes static: the single-argument form
        # yields data-dependent indices, which jax.jit cannot trace.
        norm2 = self.lib.where(norm2 <= 1e-7, self.lib.ones_like(norm2), norm2)
        return n_laplace / norm2
@@ -0,0 +1,97 @@
1
+ import numpy as np
2
+
3
try:
    import jax
    import jax.numpy as jnp
    _HAS_JAX = True
except ImportError:
    # JAX is optional: fall back to tiny NumPy-backed stand-ins that expose
    # only the functions this module calls (jit, ones_like, exp).
    _HAS_JAX = False

    class DummyJax:
        """Stand-in for ``jax`` whose ``jit`` returns the function unchanged."""

        jit = staticmethod(lambda f: f)

    class DummyJnp:
        """Stand-in for ``jax.numpy`` delegating to NumPy."""

        ones_like = staticmethod(lambda x: np.ones_like(x))
        exp = staticmethod(lambda x: np.exp(x))

    jax, jnp = DummyJax(), DummyJnp()
23
+
24
+ import dataclasses
25
+
26
@dataclasses.dataclass
class DiffusionLegendrePolynomials:
    """Exponentiated Legendre series of an input rescaled by ``2*x - 1``.

    The affine map sends [0, 1] onto the Legendre interval [-1, 1]
    (presumably inputs are expected in [0, 1] — confirm with callers). The
    series is evaluated by ``ExpLegendrePolynomials``, so the result is
    ``exp(...)`` and therefore strictly positive.
    """

    max_degree: int  # highest Legendre degree in the series

    def __post_init__(self):
        self.leg_poly = ExpLegendrePolynomials(self.max_degree)

    def __call__(self, params, inputs):
        shifted = 2.0 * inputs - 1.0
        return self.leg_poly(params, shifted)
35
+
36
+
37
@dataclasses.dataclass
class ChemicalPotentialLegendrePolynomials:
    """Plain Legendre series of an input rescaled by ``2*x - 1``.

    The affine map sends [0, 1] onto the Legendre interval [-1, 1]
    (presumably inputs are expected in [0, 1] — confirm with callers). The
    series comes from ``LegendrePolynomialRecurrence`` and is not
    exponentiated, so it may take either sign.
    """

    max_degree: int  # highest Legendre degree in the series

    def __post_init__(self):
        self.leg_poly = LegendrePolynomialRecurrence(self.max_degree)

    def __call__(self, params, inputs):
        shifted = 2.0 * inputs - 1.0
        return self.leg_poly(params, shifted)
46
+
47
+
48
@dataclasses.dataclass
class ExpLegendrePolynomials:
    """``exp`` of a Legendre series, wrapped in ``jax.jit``."""

    max_degree: int  # highest Legendre degree in the series

    def __post_init__(self):
        series = LegendrePolynomialRecurrence(self.max_degree)

        def exp_of_series(p, x):
            return jnp.exp(series(p, x))

        self.func = jax.jit(exp_of_series)

    def __call__(self, params, inputs):
        return self.func(params, inputs)
58
+
59
@dataclasses.dataclass
class LegendrePolynomialRecurrence:
    """Evaluate the Legendre series ``sum_{i=0}^{max_degree} params[i] * P_i(x)``.

    The sum is compiled with ``jax.jit``. ``params`` must provide at least
    ``max_degree + 1`` coefficients.
    """

    max_degree: int  # highest polynomial degree in the series

    def __post_init__(self):
        def compute_polynomial_sum(params, x):
            # One pass of the three-term recurrence: each P_i is built from
            # P_{i-1} and P_{i-2} and accumulated right away. This is
            # O(max_degree) instead of recomputing every P_i from scratch
            # (O(max_degree**2)), with the same arithmetic per polynomial.
            p_prev = self.T0(x)  # P_0
            result = params[0] * p_prev
            if self.max_degree >= 1:
                p_curr = x  # P_1
                result = result + params[1] * p_curr
                for i in range(1, self.max_degree):
                    p_prev, p_curr = p_curr, ((2 * i + 1) * x * p_curr - i * p_prev) / (i + 1)
                    result = result + params[i + 1] * p_curr
            return result

        self.func = jax.jit(compute_polynomial_sum)

    def __call__(self, params, inputs):
        return self.func(params, inputs)

    def T0(self, x):
        """Return P_0(x), an array of ones shaped like ``x``."""
        return 1.0 * jnp.ones_like(x)

    def _compute_legendre(self, n, x):
        """Compute the nth Legendre polynomial using the three-term recurrence relation."""
        if n == 0:
            return self.T0(x)
        elif n == 1:
            return x

        # Initialize P_0 and P_1
        p_prev = self.T0(x)
        p_curr = x

        # (i + 1) P_{i+1} = (2i + 1) x P_i - i P_{i-1}
        for i in range(1, n):
            p_next = ((2 * i + 1) * x * p_curr - i * p_prev) / (i + 1)
            p_prev = p_curr
            p_curr = p_next

        return p_curr
evoxels/inversion.py ADDED
@@ -0,0 +1,233 @@
1
+ from functools import partial
2
+ from dataclasses import dataclass
3
+ from timeit import default_timer as timer
4
+ from typing import Any, Type, Optional
5
+ from evoxels.timesteppers import PseudoSpectralIMEX_dfx
6
+
7
+ try:
8
+ import diffrax as dfx
9
+ import optimistix as optx
10
+ import jax.numpy as jnp
11
+ import jax
12
+ except ImportError:
13
+ dfx = None
14
+ optx = None
15
+ jnp = None
16
+ jax = None
17
+
18
+ DIFFRAX_AVAILABLE = dfx is not None
19
+
20
@dataclass
class InversionModel:
    """Inverse modeling using JAX and diffrax.

    This small helper class wraps the differentiable solver implementation and
    provides utilities to fit material parameters via gradient based
    optimization. It is intentionally lightweight so that new users can easily
    follow the individual steps: solving the PDE, computing residuals and
    running a least squares optimiser.

    Attributes:
        vf: VoxelFields object providing grid info, precision and fields.
        problem_cls: Problem class, instantiated as
            ``problem_cls(vg, **problem_kwargs, **parameters)``.
        pos_params: Names of parameters constrained to be positive. They are
            handled in log space internally and exponentiated before use.
        problem_kwargs: Fixed keyword arguments forwarded to ``problem_cls``.
        backend: Array backend; only ``'jax'`` is supported.
    """

    vf: Any  # VoxelFields object
    problem_cls: Type
    pos_params: Optional[list[str]] = None
    problem_kwargs: Optional[dict[str, Any]] = None
    backend: str = 'jax'

    def __post_init__(self):
        """Initialize backend specific components.

        Raises:
            ImportError: If the optional JAX dependencies are not installed.
            ValueError: If ``backend`` is not ``'jax'``.
        """
        if not DIFFRAX_AVAILABLE:
            raise ImportError(
                "InversionModel requires the optional JAX"
                " dependencies (jax, diffrax, optimistix)."
            )
        if self.backend != 'jax':
            raise ValueError(
                f"Unsupported backend {self.backend!r}; InversionModel only supports 'jax'."
            )
        self.problem_kwargs = self.problem_kwargs or {}
        from evoxels.voxelgrid import VoxelGridJax
        from .profiler import JAXMemoryProfiler
        self.vg = VoxelGridJax(self.vf.grid_info(), precision=self.vf.precision)
        self.profiler = JAXMemoryProfiler()

    def solve(self, parameters, y0, saveat, adjoint=None, dt0=0.1):
        """Integrate the problem's PDE for a given parameter set.

        Args:
            parameters (dict): Material parameters to solve with. Entries named
                in ``pos_params`` are expected in log space.
            y0 (array-like): Initial field.
            saveat (:class:`diffrax.SaveAt`): Time points at which the solution
                should be stored.
            adjoint: Differentiation mode used by :func:`diffrax.diffeqsolve`.
                Defaults to ``diffrax.ForwardMode()``.
            dt0 (float): Initial step size for the time integrator.

        Returns:
            jax.Array: Array of saved fields with shape
            ``(len(saveat.ts), Nx, Ny, Nz)``.
        """
        # Resolved here (not as a default argument) so the class can be
        # defined and imported even when diffrax is not installed.
        if adjoint is None:
            adjoint = dfx.ForwardMode()
        u = self.vg.init_scalar_field(y0)
        u = self.vg.bc.trim_boundary_nodes(u)
        if self.pos_params:
            parameters = {
                k: jnp.exp(v) if k in self.pos_params else v
                for k, v in parameters.items()
            }
        problem = self.problem_cls(self.vg, **self.problem_kwargs, **parameters)
        solver = PseudoSpectralIMEX_dfx(problem.fourier_symbol)

        solution = dfx.diffeqsolve(
            dfx.ODETerm(lambda t, y, args: problem.rhs(y, t)),
            solver,
            t0=saveat.subs.ts[0],
            t1=saveat.subs.ts[-1],
            dt0=dt0,
            y0=u,
            saveat=saveat,
            max_steps=100000,
            throw=False,
            adjoint=adjoint,
        )
        # Re-apply the problem's boundary padding, then strip the ghost nodes.
        padded = problem.pad_bc(solution.ys[:, 0])
        return self.vg.bc.trim_ghost_nodes(padded)

    def forward_solve(self, parameters, fieldname, saveat, dt0=0.1, verbose=True):
        """Run a forward simulation and store the final state in ``vf``.

        Args:
            parameters (dict): Material parameters in their natural scale
                (positive parameters are log-transformed here because
                :meth:`solve` exponentiates them).
            fieldname (str): Key in ``self.vf.fields`` used as the initial
                condition; it is overwritten with the final saved state.
            saveat (:class:`diffrax.SaveAt`): Output time points.
            dt0 (float): Initial step size for the time integrator.
            verbose (bool): If ``True``, print memory statistics.

        Returns:
            jax.Array: All saved states as returned by :meth:`solve`.
        """
        start = timer()
        u0 = self.vf.fields[fieldname]
        if self.pos_params:
            parameters = {
                k: jnp.log(v) if k in self.pos_params else v
                for k, v in parameters.items()
            }
        sol = self.solve(parameters, u0, saveat, dt0=dt0)
        end = timer()

        self.vf.fields[fieldname] = self.vg.to_numpy(sol[-1])
        if verbose:
            iterations = int(saveat.subs.ts[-1] // dt0)
            self.profiler.print_memory_stats(start, end, iterations)

        return sol

    def residuals(self, parameters, y0s__values__saveat, adjoint=None):
        """Calculate residuals between measured and simulated states.

        Args:
            parameters (dict): Current estimate of the model parameters.
            y0s__values__saveat (tuple): Tuple ``(y0s, values, saveat)`` where
                ``y0s`` contains the initial states for each sequence, ``values``
                contains the observed states and ``saveat`` specifies the time
                points of these observations.
            adjoint: Differentiation mode for :meth:`solve` (``None`` selects
                ``diffrax.ForwardMode()``).

        Returns:
            jax.Array: Array of residuals with shape matching ``values``.
        """
        y0s, values, saveat = y0s__values__saveat
        solve_ = partial(self.solve, adjoint=adjoint)
        batch_solve = jax.vmap(solve_, in_axes=(None, 0, None))
        pred_values = batch_solve(parameters, y0s, saveat)
        # The first saved state is the initial condition itself; skip it.
        return values - pred_values[:, 1:]

    @staticmethod
    def _check_inds(inds):
        """Raise ValueError unless all sequences share one length (>= 2) and spacing."""
        if len(inds) == 0:
            raise ValueError("inds must contain at least one sequence")

        def spacing_of(seq):
            return [seq[j + 1] - seq[j] for j in range(len(seq) - 1)]

        ref_len = len(inds[0])
        if ref_len < 2:
            raise ValueError("Each sequence in inds must have at least 2 elements")
        ref_spacing = spacing_of(inds[0])

        for i, sequence in enumerate(inds):
            if len(sequence) != ref_len:
                raise ValueError(
                    f"Sequence {i} has different length than first sequence"
                )
            if spacing_of(sequence) != ref_spacing:
                raise ValueError(
                    f"Sequence {i} has different spacing than first sequence"
                )

    def train(
        self,
        initial_parameters,
        data,
        inds,
        adjoint=None,
        rtol=1e-6,
        atol=1e-6,
        verbose=True,
        max_steps=1000,
    ):
        """Fit ``parameters`` so that the model matches observed data.

        This method assembles the observed sequences into a format suitable for
        :func:`optimistix.least_squares` and then runs a Levenberg--Marquardt
        optimisation to minimise the residuals returned by :meth:`residuals`.

        Args:
            initial_parameters (dict): Initial guess for the parameters to be
                optimised. Not modified.
            data (dict): Dictionary containing ``"ts"`` (time stamps) and
                ``"ys"`` (fields) as produced by :meth:`solve`.
            inds (list[list[int]]): For each sequence, the indices in ``data``
                that should be used for training. All sequences must have the
                same length (at least 2) and the same spacing.
            adjoint: Differentiation mode used when evaluating the residuals
                (``None`` selects ``diffrax.ForwardMode()``).
            rtol, atol (float): Tolerances for the optimiser.
            verbose (bool): If ``True``, prints optimisation progress.
            max_steps (int): Maximum number of optimisation steps.

        Returns:
            dict: Fitted parameters; entries in ``pos_params`` are mapped back
            from log space to their natural (positive) scale.

        Raises:
            ValueError: If ``inds`` is malformed or a positive parameter has a
                non-positive initial value.
        """
        self._check_inds(inds)
        pos_params = self.pos_params or []
        # Validate before any array work so bad input fails fast.
        for key in pos_params:
            if initial_parameters[key] <= 0:
                raise ValueError(f"Parameter {key} must be positive")

        # TODO: make data a voxelgrid or voxelfield object
        y0s = jnp.array([data["ys"][ind[0]] for ind in inds])
        values = jnp.array(
            [
                jnp.array([data["ys"][ind[i]] for i in range(1, len(ind))])
                for ind in inds
            ]
        )
        # Observation times relative to the start of the (first) sequence.
        t_start = data["ts"][inds[0][0]]
        saveat = dfx.SaveAt(
            ts=jnp.array(
                [0.0]
                + [data["ts"][inds[0][i]] - t_start for i in range(1, len(inds[0]))]
            )
        )

        # Optimise positive parameters in log space; build a new dict so the
        # caller's initial_parameters is left untouched.
        params0 = {
            k: jnp.log(v) if k in pos_params else v
            for k, v in initial_parameters.items()
        }

        if adjoint is None:
            adjoint = dfx.ForwardMode()
        residuals_ = partial(self.residuals, adjoint=adjoint)

        solver = optx.LevenbergMarquardt(
            rtol=rtol,
            atol=atol,
            # frozenset(None) would raise, so "quiet" must be an empty frozenset.
            verbose=(
                frozenset({"step", "accepted", "loss", "step_size"})
                if verbose
                else frozenset()
            ),
        )

        sol = optx.least_squares(
            residuals_,
            solver,
            params0,
            args=(y0s, values, saveat),
            max_steps=max_steps,
            throw=False,
        )

        return {
            k: jnp.exp(v) if k in pos_params else v
            for k, v in sol.value.items()
        }
@@ -0,0 +1 @@
1
+ """Exports for the precompiled solver package."""
@@ -0,0 +1,50 @@
1
+ from ..problem_definition import AllenCahnEquation
2
+ from ..solvers import TimeDependentSolver
3
+ from ..timesteppers import ForwardEuler
4
+ from typing import Callable
5
+
6
def run_allen_cahn_solver(
    voxelfields,
    fieldnames: str | list[str],
    backend: str,
    jit: bool = True,
    device: str = "cuda",
    time_increment: float = 0.1,
    frames: int = 10,
    max_iters: int = 100,
    eps: float = 2.0,
    gab: float = 1.0,
    M: float = 1.0,
    force: float = 0.0,
    curvature: float = 0.01,
    potential: Callable | None = None,
    vtk_out: bool = False,
    verbose: bool = True,
    plot_bounds=None,
):
    """Run the Allen-Cahn solver with a predefined problem and timestepper.

    Builds a ``TimeDependentSolver`` for ``AllenCahnEquation`` with
    ``ForwardEuler`` time stepping and runs it. Nothing is returned; results
    are presumably written back into ``voxelfields`` and/or VTK/plots by the
    solver (confirm in ``TimeDependentSolver.solve``).

    Args:
        voxelfields: VoxelFields container holding the field(s) to evolve.
        fieldnames: Name or list of names of the field(s) to evolve.
        backend: Array backend name passed to ``TimeDependentSolver``.
        jit: Forwarded to ``solve`` (presumably enables JIT compilation).
        device: Device string passed to ``TimeDependentSolver``.
        time_increment: Forwarded to ``solve`` as ``time_increment``.
        frames: Forwarded to ``solve`` as ``frames``.
        max_iters: Forwarded to ``solve`` as ``max_iters``.
        eps, gab, M, force, curvature, potential: Forwarded unchanged as
            keyword arguments of ``AllenCahnEquation`` (see that class for
            their meaning).
        vtk_out: Forwarded to ``solve`` (presumably toggles VTK output).
        verbose: Forwarded to ``solve``.
        plot_bounds: Forwarded to ``solve``.
    """
    solver = TimeDependentSolver(
        voxelfields,
        fieldnames,
        backend,
        problem_cls=AllenCahnEquation,
        timestepper_cls=ForwardEuler,
        device=device,
    )
    solver.solve(
        time_increment=time_increment,
        frames=frames,
        max_iters=max_iters,
        problem_kwargs={"eps": eps,
                        "gab": gab,
                        "M": M,
                        "force": force,
                        "curvature": curvature,
                        "potential": potential},
        jit=jit,
        verbose=verbose,
        vtk_out=vtk_out,
        plot_bounds=plot_bounds,
    )
@@ -0,0 +1,42 @@
1
+ from ..problem_definition import PeriodicCahnHilliard
2
+ from ..solvers import TimeDependentSolver
3
+ from ..timesteppers import PseudoSpectralIMEX
4
+ from typing import Callable
5
+
6
def run_cahn_hilliard_solver(
    voxelfields,
    fieldnames: str | list[str],
    backend: str,
    jit: bool = True,
    device: str = "cuda",
    time_increment: float = 0.1,
    frames: int = 10,
    max_iters: int = 100,
    eps: float = 3.0,
    diffusivity: float = 1.0,
    mu_hom: Callable | None = None,
    vtk_out: bool = False,
    verbose: bool = True,
    plot_bounds=None,
):
    """Run the Cahn-Hilliard solver with a predefined problem and timestepper.

    Uses ``PeriodicCahnHilliard`` as the problem and ``PseudoSpectralIMEX``
    as the timestepper inside a ``TimeDependentSolver``. ``eps``,
    ``diffusivity`` (passed as ``D``) and ``mu_hom`` are forwarded as
    problem keyword arguments; the remaining options go to the solver's
    constructor or its ``solve`` call. Nothing is returned.
    """
    problem_kwargs = {"eps": eps, "D": diffusivity, "mu_hom": mu_hom}
    solver = TimeDependentSolver(
        voxelfields,
        fieldnames,
        backend,
        problem_cls=PeriodicCahnHilliard,
        timestepper_cls=PseudoSpectralIMEX,
        device=device,
    )
    solver.solve(
        time_increment=time_increment,
        frames=frames,
        max_iters=max_iters,
        problem_kwargs=problem_kwargs,
        jit=jit,
        verbose=verbose,
        vtk_out=vtk_out,
        plot_bounds=plot_bounds,
    )
+ )