evoxels 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evoxels/__init__.py +13 -0
- evoxels/boundary_conditions.py +138 -0
- evoxels/fd_stencils.py +103 -0
- evoxels/function_approximators.py +97 -0
- evoxels/inversion.py +233 -0
- evoxels/precompiled_solvers/__init__.py +1 -0
- evoxels/precompiled_solvers/allen_cahn.py +50 -0
- evoxels/precompiled_solvers/cahn_hilliard.py +42 -0
- evoxels/problem_definition.py +450 -0
- evoxels/profiler.py +94 -0
- evoxels/solvers.py +134 -0
- evoxels/timesteppers.py +119 -0
- evoxels/utils.py +124 -0
- evoxels/voxelfields.py +318 -0
- evoxels/voxelgrid.py +278 -0
- evoxels-0.1.0.dist-info/METADATA +171 -0
- evoxels-0.1.0.dist-info/RECORD +20 -0
- evoxels-0.1.0.dist-info/WHEEL +5 -0
- evoxels-0.1.0.dist-info/licenses/LICENSE +21 -0
- evoxels-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
import sympy as sp
|
|
5
|
+
import sympy.vector as spv
|
|
6
|
+
import warnings
|
|
7
|
+
from .voxelgrid import VoxelGrid
|
|
8
|
+
|
|
9
|
+
# Shorthand slice objects used when indexing padded fields
__ = slice(None, None, None)  # every element along an axis, i.e. [:]
_i_ = slice(1, -1, None)  # interior elements without the outer layer, i.e. [1:-1]
|
|
12
|
+
|
|
13
|
+
class ODE(ABC):
    """Abstract interface shared by every problem definition.

    A concrete problem supplies a numerical right-hand side (``rhs``), its
    symbolic counterpart (``rhs_analytic``), boundary-condition handling
    (``bc_type`` and ``pad_bc``) and the spatial convergence order of the
    discretisation (``order``).
    """

    @property
    @abstractmethod
    def order(self):
        """Spatial order of convergence for numerical right-hand side."""
        ...

    @abstractmethod
    def rhs_analytic(self, u, t):
        """Build the problem right-hand side as a Sympy expression.

        Args:
            u: Sympy function describing the current state.
            t (float): Current time.

        Returns:
            Sympy expression of the right-hand side.
        """
        ...

    @abstractmethod
    def rhs(self, u, t):
        """Evaluate the numerical right-hand side of the ODE system.

        Args:
            u (array): Current state.
            t (float): Current time.

        Returns:
            Object of the same type as ``u`` holding the time derivative.
        """
        ...

    @property
    @abstractmethod
    def bc_type(self) -> str:
        """Name of the boundary condition, e.g. 'periodic', 'dirichlet' or 'neumann'."""
        ...

    @abstractmethod
    def pad_bc(self, u):
        """Pad a field and impose the boundary conditions on it.

        Exposed separately so that boundary conditions can be applied to
        ``u`` both inside and outside of the right-hand-side function.

        Args:
            u: Field to pad.

        Returns:
            The field padded with boundary values.
        """
        ...
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SemiLinearODE(ODE):
    """Problem whose highest-order spatial operator is linear.

    Subclasses expose the Fourier symbol of that operator, which is
    required by pseudo-spectral time steppers.
    """

    @property
    @abstractmethod
    def fourier_symbol(self):
        r"""Symbol of the highest-order spatial operator.

        The symbol of an operator is its representation in the
        Fourier (spectral) domain. For instance the:
          - Laplacian operator $\nabla^2$ has a symbol $-k^2$,
          - diffusion operator $D\nabla^2$ corresponds to $-k^2D$

        The symbol is required for pseudo-spectral timesteppers.
        """
        # Raw docstring: in a plain string "\nabla" would be parsed as a
        # newline followed by "abla".
        pass
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class SmoothedBoundaryODE(ODE):
    """Problem interface that additionally exposes a fixed ``mask`` field."""

    @property
    @abstractmethod
    def mask(self) -> Any | float:
        """Field with the same shape as the state that remains fixed."""
        ...
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class ReactionDiffusion(SemiLinearODE):
    """Reaction-diffusion problem ``du/dt = D * laplace(u) + f(u, t)``.

    Attributes:
        vg: Voxel grid providing the array backend, stencils and padding.
        D: Diffusion coefficient.
        BC_type: ``'periodic'``, ``'dirichlet'`` or ``'neumann'``.
        bcs: Boundary values ``(bc0, bc1)`` forwarded to the padding
            function; only the Dirichlet padding makes use of them.
        f: Optional source term, called as ``f(c, t, lib)`` or ``f(c, t)``.
            Defaults to a function returning ``0``.
        A: Scaling factor applied to the Fourier symbol.
    """
    vg: VoxelGrid
    D: float
    BC_type: str
    bcs: tuple = (0,0)
    f: Callable | None = None
    A: float = 0.25
    _fourier_symbol: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Select the BC padding and precompute factors for the spectral solver.

        Raises:
            ValueError: If ``BC_type`` is not a supported boundary condition.
        """
        if self.f is None:
            self.f = lambda c=None, t=None, lib=None: 0

        if self.BC_type == 'periodic':
            bc_fun = self.vg.bc.pad_periodic
            # Periodic padding needs no boundary values; bc0/bc1 are ignored.
            self.pad_boundary = lambda arr, bc0, bc1: bc_fun(arr)
            k_squared = self.vg.rfft_k_squared()
        elif self.BC_type == 'dirichlet':
            self.pad_boundary = self.vg.bc.pad_dirichlet_periodic
            if self.vg.convention == 'cell_center':
                warnings.warn(
                    "Applying Dirichlet BCs on a cell_center grid "
                    "reduces the spatial order of convergence to 0.5!"
                )
            k_squared = self.vg.fft_k_squared_nonperiodic()
        elif self.BC_type == 'neumann':
            bc_fun = self.vg.bc.pad_zero_flux_periodic
            # Zero-flux padding needs no boundary values; bc0/bc1 are ignored.
            self.pad_boundary = lambda arr, bc0, bc1: bc_fun(arr)
            k_squared = self.vg.fft_k_squared_nonperiodic()
        else:
            # Without this branch ``k_squared`` would be unbound below and the
            # user would get an unhelpful UnboundLocalError.
            raise ValueError(
                f"Unsupported BC_type {self.BC_type!r}; expected "
                "'periodic', 'dirichlet' or 'neumann'."
            )

        self._fourier_symbol = -self.D * self.A * k_squared

    @property
    def order(self):
        """Spatial order of convergence of the numerical right-hand side."""
        return 2

    @property
    def fourier_symbol(self):
        """Precomputed symbol ``-D * A * k^2`` of the diffusion operator."""
        return self._fourier_symbol

    def _eval_f(self, c, t, lib):
        """Evaluate the source/forcing term ``self.f``.

        ``f`` is first called as ``f(c, t, lib)``; on ``TypeError`` it is
        retried as ``f(c, t)`` to support two-argument callables.
        NOTE(review): a ``TypeError`` raised inside ``f`` itself also
        triggers the retry.
        """
        try:
            return self.f(c, t, lib)
        except TypeError:
            return self.f(c, t)

    @property
    def bc_type(self):
        """The configured boundary-condition type (``BC_type``)."""
        return self.BC_type

    def pad_bc(self, u):
        """Pad ``u`` using the padding selected in ``__post_init__``."""
        return self.pad_boundary(u, self.bcs[0], self.bcs[1])

    def rhs_analytic(self, u, t):
        """Symbolic right-hand side ``D * laplacian(u) + f(u, t)``."""
        return self.D*spv.laplacian(u) + self._eval_f(u, t, sp)

    def rhs(self, u, t):
        """Numerical right-hand side ``D * laplace(u) + f(u, t)``.

        Args:
            u: Current state without ghost nodes.
            t (float): Current time.

        Returns:
            Backend array with the time derivative of ``u``.
        """
        laplace = self.vg.laplace(self.pad_bc(u))
        update = self.D * laplace + self._eval_f(u, t, self.vg.lib)
        return update
|
|
155
|
+
|
|
156
|
+
@dataclass
class ReactionDiffusionSBM(ReactionDiffusion, SmoothedBoundaryODE):
    """Reaction-diffusion with a smoothed boundary described by ``mask``.

    The state ``u`` is handled as ``mask * c``. The numerical right-hand side
    evaluates the source term on ``u / mask``, subtracts
    ``(u / mask) * grad(mask)`` from the diffusive flux, and adds a boundary
    flux ``bc_flux`` weighted by ``norm = |grad(mask)|``.

    Attributes:
        mask: Optional field defining the domain; all ones if omitted. A
            given mask is clipped to ``[1e-4, 1]`` after initialisation,
            which keeps the divisions by ``mask`` in ``rhs`` finite.
        bc_flux: Boundary flux multiplied by ``norm`` in ``rhs``.
            NOTE(review): typed as ``Callable | float`` but only used
            multiplicatively; a callable is not evaluated here.
    """
    mask: Any | None = None
    bc_flux: Callable | float = 0.0

    def __post_init__(self):
        """Initialise the base problem, then prepare ``mask``, ``norm`` and ``bcs``."""
        super().__post_init__()
        if self.mask is None:
            self.mask = self.vg.lib.ones(self.vg.shape)
            self.mask = self.vg.init_scalar_field(self.mask)
            self.mask = self.vg.pad_periodic(
                self.vg.bc.trim_boundary_nodes(self.mask))
            # NOTE(review): a uniform mask has zero gradient, yet ``norm`` is
            # 1.0 here, so a non-zero ``bc_flux`` acts as a uniform source.
            # Confirm this is intended.
            self.norm = 1.0
        else:
            self.mask = self.vg.init_scalar_field(self.mask)
            # Remember the supplied first/last planes along axis 1 so that
            # they can be restored after the periodic re-padding below.
            mask_0 = self.mask[:,0,:,:]
            mask_1 = self.mask[:,-1,:,:]
            self.mask = self.vg.pad_periodic(
                self.vg.bc.trim_boundary_nodes(self.mask))
            if self.BC_type != 'periodic':
                self.mask = self.vg.set(self.mask, (__, 0,_i_,_i_), mask_0)
                self.mask = self.vg.set(self.mask, (__,-1,_i_,_i_), mask_1)

            self.norm = self.vg.lib.sqrt(self.vg.gradient_norm_squared(self.mask))
            # Clip after computing ``norm``; the lower bound keeps u / mask finite.
            self.mask = self.vg.lib.clip(self.mask, 1e-4, 1)

        # Scale the boundary values by the mask on the first/last plane of axis 1.
        self.bcs = (self.bcs[0] * self.mask[:,0,:,:],
                    self.bcs[1] * self.mask[:,-1,:,:])

    def pad_bc(self, u):
        """Pad ``u`` with the (mask-scaled) boundary values ``bcs``."""
        return self.pad_boundary(u, self.bcs[0], self.bcs[1])

    def rhs_analytic(self, mask, u, t):
        """Symbolic right-hand side of the smoothed-boundary problem.

        Args:
            mask: Sympy expression of the mask.
            u: Sympy expression of the state (``mask * c``).
            t: Current time.

        Returns:
            ``div(D*(grad(u) - u/mask*grad(mask))) + |grad(mask)|*bc_flux
            + mask*f(u/mask, t)``.
        """
        grad = spv.gradient(u)
        grad_mask = spv.gradient(mask)
        # Weight the boundary flux by |grad(mask)|, matching ``self.norm`` used
        # in the numerical ``rhs``. Using |grad(u)| here made the analytic and
        # numerical right-hand sides disagree whenever bc_flux != 0.
        norm_grad_mask = sp.sqrt(grad_mask.dot(grad_mask))

        divergence = spv.divergence(self.D*(grad - u/mask*grad_mask))
        du = divergence + norm_grad_mask*self.bc_flux + mask*self._eval_f(u/mask, t, sp)
        return du

    def rhs(self, u, t):
        """Numerical right-hand side of the smoothed-boundary problem.

        Args:
            u: Current state (``mask * c``) without ghost nodes.
            t (float): Current time.

        Returns:
            Backend array with the time derivative of ``u``.
        """
        z = self.pad_bc(u)
        # Face-centred flux grad(z) - (z/mask) * grad(mask) per axis. Each
        # directional divergence is trimmed to the interior on the other axes.
        divergence = self.vg.grad_x_face(self.vg.grad_x_face(z) -
            self.vg.to_x_face(z/self.mask) * self.vg.grad_x_face(self.mask)
            )[:,:,1:-1,1:-1]
        divergence += self.vg.grad_y_face(self.vg.grad_y_face(z) -
            self.vg.to_y_face(z/self.mask) * self.vg.grad_y_face(self.mask)
            )[:,1:-1,:,1:-1]
        divergence += self.vg.grad_z_face(self.vg.grad_z_face(z) -
            self.vg.to_z_face(z/self.mask) * self.vg.grad_z_face(self.mask)
            )[:,1:-1,1:-1,:]

        update = self.D * divergence + \
            self.norm*self.bc_flux + \
            self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(u/self.mask[:,1:-1,1:-1,1:-1], t, self.vg.lib)
        return update
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@dataclass
class PeriodicCahnHilliard(SemiLinearODE):
    r"""Cahn-Hilliard equation on a fully periodic domain.

    Evolves :math:`\partial_t c = \nabla\cdot(D\,c(1-c)\,\nabla\mu)` with
    :math:`\mu = \mu_{hom}(c) - 2\,\mathrm{eps}\,\nabla^2 c`.

    Attributes:
        vg: Voxel grid providing the backend, stencils and periodic padding.
        eps: Parameter entering the gradient-energy coefficient (``2*eps``)
            and the default ``mu_hom``.
        D: Prefactor of the mobility ``D*c*(1-c)``.
        mu_hom: Homogeneous chemical potential, called as ``mu_hom(c, lib)``
            or ``mu_hom(c)``. Defaults to ``18/eps * c*(1-c)*(1-2c)``.
        A: Scaling factor applied to the Fourier symbol.
    """
    vg: VoxelGrid
    eps: float = 3.0
    D: float = 1.0
    mu_hom: Callable | None = None
    A: float = 0.25
    _fourier_symbol: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Precompute factors required by the spectral solver."""
        k_squared = self.vg.rfft_k_squared()
        self._fourier_symbol = -2 * self.eps * self.D * self.A * k_squared**2
        if self.mu_hom is None:
            self.mu_hom = lambda c, lib=None: 18 / self.eps * c * (1 - c) * (1 - 2 * c)

    @property
    def order(self):
        """Spatial order of convergence of the numerical right-hand side."""
        return 2

    @property
    def fourier_symbol(self):
        """Precomputed symbol ``-2 * eps * D * A * k^4``."""
        return self._fourier_symbol

    @property
    def bc_type(self):
        """This problem is always periodic."""
        return 'periodic'

    def pad_bc(self, u):
        """Pad ``u`` periodically."""
        return self.vg.bc.pad_periodic(u)

    def _eval_mu(self, c, lib):
        """Evaluate the homogeneous chemical potential ``self.mu_hom``.

        ``mu_hom`` is first called as ``mu_hom(c, lib)``; on ``TypeError``
        it is retried as ``mu_hom(c)`` to support one-argument callables.
        """
        try:
            return self.mu_hom(c, lib)
        except TypeError:
            return self.mu_hom(c)

    def rhs_analytic(self, c, t):
        """Symbolic right-hand side ``div(D*c*(1-c)*grad(mu))``."""
        mu = self._eval_mu(c, sp) - 2*self.eps*spv.laplacian(c)
        fluxes = self.D*c*(1-c)*spv.gradient(mu)
        rhs = spv.divergence(fluxes)
        return rhs

    def rhs(self, c, t):
        r"""Evaluate :math:`\partial c / \partial t` for the CH equation.

        Numerical computation of

        .. math::
            \frac{\partial c}{\partial t}
            = \nabla \cdot \bigl( M \, \nabla \mu \bigr),
            \quad
            \mu = \frac{\delta F}{\delta c}
            = f'(c) - \kappa \, \nabla^2 c

        with mobility :math:`M = D\,c(1-c)` (evaluated on cell faces),
        :math:`f'(c)` given by ``mu_hom`` and gradient energy coefficient
        :math:`\kappa = 2\,\mathrm{eps}`. The concentration is clipped to
        :math:`[0, 1]` before evaluation.

        Args:
            c (array-like): Concentration field.
            t (float): Current time.

        Returns:
            Backend array of the same shape as ``c`` containing ``dc/dt``.
        """
        c = self.vg.lib.clip(c, 0, 1)
        c_BC = self.pad_bc(c)
        laplace = self.vg.laplace(c_BC)
        mu = self._eval_mu(c, self.vg.lib) - 2*self.eps*laplace
        mu = self.pad_bc(mu)

        # Interpolate c to the faces once per axis and reuse it for c*(1-c).
        c_x = self.vg.to_x_face(c_BC)
        divergence = self.vg.grad_x_face(
            c_x * (1-c_x) * self.vg.grad_x_face(mu)
        )[:,:,1:-1,1:-1]

        c_y = self.vg.to_y_face(c_BC)
        divergence += self.vg.grad_y_face(
            c_y * (1-c_y) * self.vg.grad_y_face(mu)
        )[:,1:-1,:,1:-1]

        c_z = self.vg.to_z_face(c_BC)
        divergence += self.vg.grad_z_face(
            c_z * (1-c_z) * self.vg.grad_z_face(mu)
        )[:,1:-1,1:-1,:]

        return self.D * divergence
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@dataclass
class AllenCahnEquation(SemiLinearODE):
    """Two-phase Allen-Cahn equation for an order parameter ``phi``.

    Attributes:
        vg: Voxel grid providing the backend, stencils and padding.
        eps: Divides the potential and driving-force terms; also used by the
            default potential.
        gab: Prefactor of the Laplacian and potential terms.
        M: Mobility, the overall prefactor of the right-hand side.
        force: Prefactor of the ``3/eps * phi*(1-phi)`` driving-force term.
        curvature: Weight in ``[0, 1]`` blending the full Laplacian
            (``curvature``) with the normal Laplacian (``1 - curvature``).
        potential: Potential term, called as ``potential(phi, lib)`` or
            ``potential(phi)``. Defaults to ``18/eps * u*(1-u)*(1-2u)``.
    """
    vg: VoxelGrid
    eps: float = 2.0
    gab: float = 1.0
    M: float = 1.0
    force: float = 0.0
    curvature: float = 0.01
    potential: Callable | None = None
    _fourier_symbol: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Build the spectral symbol and install the default potential if needed."""
        # NOTE(review): this uses the periodic ``rfft_k_squared`` although
        # ``bc_type`` is 'neumann'. ReactionDiffusion uses
        # ``fft_k_squared_nonperiodic`` for Neumann. Confirm this is intended.
        self._fourier_symbol = -2 * self.M * self.gab* self.vg.rfft_k_squared()
        if self.potential is None:
            self.potential = lambda u, lib=None: 18 / self.eps * u * (1-u) * (1-2*u)

    @property
    def order(self):
        """Spatial order of convergence of the numerical right-hand side."""
        return 2

    @property
    def fourier_symbol(self):
        """Precomputed symbol ``-2 * M * gab * k^2``."""
        return self._fourier_symbol

    @property
    def bc_type(self):
        """This problem always uses zero-flux (Neumann) boundaries."""
        return 'neumann'

    def pad_bc(self, u):
        """Pad ``u`` with zero-flux boundary values."""
        return self.vg.bc.pad_zero_flux(u)

    def _eval_potential(self, phi, lib):
        """Call ``self.potential`` as ``(phi, lib)``, retrying as ``(phi)`` on TypeError."""
        try:
            value = self.potential(phi, lib)
        except TypeError:
            value = self.potential(phi)
        return value

    def rhs_analytic(self, phi, t):
        """Symbolic right-hand side of the Allen-Cahn equation."""
        gradient = spv.gradient(phi)
        magnitude = sp.sqrt(gradient.dot(gradient))
        # Curvature term |grad phi| * div(grad phi / |grad phi|)
        curv = magnitude * spv.divergence(gradient / magnitude)
        n_laplace = spv.laplacian(phi) - (1-self.curvature)*curv
        df_dphi = self.gab * (2*n_laplace - self._eval_potential(phi, sp)/self.eps) \
            + 3/self.eps * phi * (1-phi) * self.force
        return self.M * df_dphi

    def rhs(self, phi, t):
        r"""Numerical right-hand side of the two-phase Allen-Cahn equation.

        Evolves the order parameter ``\phi`` (a phase fraction) with mobility
        ``M``, interface parameter ``eps`` (:math:`\epsilon`) and interfacial
        energy ``gab`` (:math:`\gamma`). The Laplacian contribution is split
        into ``curvature * laplace`` plus ``(1 - curvature) * normal_laplace``.
        Choosing ``curvature`` in :math:`[0,1]` therefore controls how
        strongly the evolution is driven by curvature minimization.

        Args:
            phi (array-like): order parameter (clipped to ``[0, 1]``).
            t (float): Current time.

        Returns:
            Backend array of the same shape as ``\phi`` containing ``d\phi/dt``.
        """
        lib = self.vg.lib
        phi = lib.clip(phi, 0, 1)
        pot = self._eval_potential(phi, lib)
        padded = self.pad_bc(phi)
        weight = self.curvature
        lap_term = weight*self.vg.laplace(padded)
        normal_term = (1-weight) * self.vg.normal_laplace(padded)
        driving = 3/self.eps * phi * (1-phi) * self.force
        df_dphi = self.gab * (2.0 * (lap_term+normal_term) - pot/self.eps) + driving
        return self.M * df_dphi
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
@dataclass
class CoupledReactionDiffusion(SemiLinearODE):
    """Two-species reaction-diffusion system on a periodic domain.

    Species A and B live in channels 0 and 1 of the state ``u``.

    Attributes:
        vg: Voxel grid providing the backend, stencils and periodic padding.
        D_A: Diffusion coefficient of species A.
        D_B: Diffusion coefficient of species B.
        feed: Rate of the ``feed * (1 - u[0])`` term supplying species A.
        kill: Rate of the linear ``kill * u[1]`` removal of species B.
        interaction: Coupling term, called as ``interaction(u, lib)`` or
            ``interaction(u)``. Defaults to ``u[0] * u[1]**2``.
    """
    vg: VoxelGrid
    D_A: float = 1.0
    D_B: float = 0.5
    feed: float = 0.055
    kill: float = 0.117
    interaction: Callable | None = None
    _fourier_symbol: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Build the spectral symbol from the larger diffusivity and set defaults."""
        self._fourier_symbol = - max(self.D_A, self.D_B) * self.vg.rfft_k_squared()
        if self.interaction is None:
            self.interaction = lambda u, lib=None: u[0] * u[1]**2

    @property
    def order(self):
        """Spatial order of convergence of the numerical right-hand side."""
        return 2

    @property
    def fourier_symbol(self):
        """Precomputed symbol ``-max(D_A, D_B) * k^2``."""
        return self._fourier_symbol

    @property
    def bc_type(self):
        """This problem is always periodic."""
        return 'periodic'

    def pad_bc(self, u):
        """Pad ``u`` periodically."""
        return self.vg.bc.pad_periodic(u)

    def _eval_interaction(self, u, lib):
        """Call ``self.interaction`` as ``(u, lib)``, retrying as ``(u)`` on TypeError."""
        try:
            value = self.interaction(u, lib)
        except TypeError:
            value = self.interaction(u)
        return value

    def rhs_analytic(self, u, t):
        """Symbolic right-hand sides ``(dc_A, dc_B)`` of both species."""
        reaction = self._eval_interaction(u, sp)
        return (
            self.D_A*spv.laplacian(u[0]) - reaction + self.feed * (1-u[0]),
            self.D_B*spv.laplacian(u[1]) + reaction - self.kill * u[1],
        )

    def rhs(self, u, t):
        r"""Numerical right-hand side of the two-component system.

        The species are stored as batch channels:
          - species A with concentration c_A = u[0]
          - species B with concentration c_B = u[1]

        Args:
            u (array-like): stacked species concentrations.
            t (float): Current time.

        Returns:
            Backend array of the same shape as ``u`` containing ``du/dt``.
        """
        reaction = self._eval_interaction(u, self.vg.lib)
        lap = self.vg.laplace(self.pad_bc(u))
        c_a, c_b = u[0], u[1]
        dc_a = self.D_A*lap[0] - reaction + self.feed * (1-c_a)
        dc_b = self.D_B*lap[1] + reaction - self.kill * c_b
        return self.vg.lib.stack((dc_a, dc_b), 0)
|
evoxels/profiler.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import psutil
|
|
3
|
+
import os
|
|
4
|
+
import subprocess
|
|
5
|
+
import tracemalloc
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
|
|
8
|
+
class MemoryProfiler(ABC):
    """Base interface for tracking host and device memory usage.

    Subclasses normally initialise ``max_used_cpu`` / ``max_used_gpu``
    (peak usage in MB) and implement :meth:`print_memory_stats`.
    """

    def get_cuda_memory_from_nvidia_smi(self):
        """Return currently used CUDA memory in megabytes.

        Queries ``nvidia-smi`` and parses the value of the first listed GPU.

        Returns:
            int | None: Used memory in MB, or ``None`` if the query failed.
            When the ``nvidia-smi`` executable is missing, the error is
            printed once and later calls return ``None`` without retrying.
        """
        if getattr(self, '_nvidia_smi_missing', False):
            return None
        try:
            output = subprocess.check_output(
                ['nvidia-smi', '--query-gpu=memory.used',
                 '--format=csv,nounits,noheader'],
                encoding='utf-8'
            )
            return int(output.strip().split('\n')[0])
        except FileNotFoundError as e:
            # No nvidia-smi on this host: avoid spawning a process and
            # re-printing the same error on every subsequent call.
            self._nvidia_smi_missing = True
            print(f"Error tracking memory with nvidia-smi: {e}")
        except (OSError, subprocess.SubprocessError, ValueError) as e:
            print(f"Error tracking memory with nvidia-smi: {e}")
        return None

    def update_memory_stats(self):
        """Update the peak host RSS and peak GPU memory usage (both in MB)."""
        process = psutil.Process(os.getpid())
        used_cpu = process.memory_info().rss / 1024**2
        self.max_used_cpu = np.max((getattr(self, 'max_used_cpu', 0), used_cpu))
        self._record_gpu_peak(self.get_cuda_memory_from_nvidia_smi())

    def _record_gpu_peak(self, used):
        """Fold ``used`` MB into ``max_used_gpu``; a failed query (``None``) is ignored."""
        # Without this guard np.max((peak, None)) raises TypeError whenever
        # nvidia-smi is unavailable.
        if used is None:
            return
        self.max_used_gpu = np.max((getattr(self, 'max_used_gpu', 0), used))

    @abstractmethod
    def print_memory_stats(self, start: float, end: float, iters: int):
        """Print profiling summary after a simulation run."""
        pass
|
|
33
|
+
|
|
34
|
+
class TorchMemoryProfiler(MemoryProfiler):
    """Memory profiler for the PyTorch backend (CPU and CUDA devices)."""

    def __init__(self, device):
        """Initialize the profiler for a given torch device.

        Args:
            device: ``torch.device`` of the simulation; its ``type`` selects
                which statistics are collected and reported.
        """
        import torch
        self.torch = torch
        self.device = device
        # tracemalloc statistics are only reported for CPU devices (and only
        # stopped there). Tracing every Python allocation during a GPU run
        # would add overhead that is never switched off.
        if device.type == 'cpu':
            tracemalloc.start()
        if device.type == 'cuda':
            torch.cuda.reset_peak_memory_stats(device=device)
        self.max_used_gpu = 0
        self.max_used_cpu = 0

    def print_memory_stats(self, start, end, iters):
        """Print wall time and memory usage for the Torch backend.

        Args:
            start: Timer value at the start of the run (seconds).
            end: Timer value at the end of the run (seconds).
            iters: Number of iterations performed (``0`` is allowed).
        """
        elapsed = end - start
        # Guard against iters == 0 (e.g. a run with max_iters=0).
        per_iter = elapsed / max(iters, 1)
        print(f'Wall time: {np.around(elapsed, 4)} s after {iters} iterations '
              f'({np.around(per_iter, 4)} s/iter)')

        if self.device.type == 'cpu':
            current, peak = tracemalloc.get_traced_memory()
            print(f"CPU-RAM (tracemalloc) current: {current / 1024**2:.2f} MB ({peak / 1024**2:.2f} MB max)")
            tracemalloc.stop()

            process = psutil.Process(os.getpid())
            current = process.memory_info().rss / 1024**2
            print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")

        elif self.device.type == 'cuda':
            self.update_memory_stats()
            used = self.get_cuda_memory_from_nvidia_smi()
            print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
            print(f"GPU-RAM (torch) current: "
                  f"{self.torch.cuda.memory_allocated(self.device) / 1024**2:.2f} MB "
                  f"({self.torch.cuda.max_memory_allocated(self.device) / 1024**2:.2f} MB max, "
                  f"{self.torch.cuda.max_memory_reserved(self.device) / 1024**2:.2f} MB reserved)")
|
|
68
|
+
|
|
69
|
+
class JAXMemoryProfiler(MemoryProfiler):
    """Memory profiler for the JAX backend."""

    def __init__(self):
        """Initialize the profiler for JAX and start tracemalloc."""
        import jax
        self.jax = jax
        self.max_used_gpu = 0
        self.max_used_cpu = 0
        tracemalloc.start()

    def print_memory_stats(self, start, end, iters):
        """Print wall time and memory usage for the JAX backend.

        Host memory is always reported; GPU memory only when JAX's default
        backend is ``'gpu'``.

        Args:
            start: Timer value at the start of the run (seconds).
            end: Timer value at the end of the run (seconds).
            iters: Number of iterations performed (``0`` is allowed).
        """
        elapsed = end - start
        # Guard against iters == 0 (e.g. a run with max_iters=0).
        per_iter = elapsed / max(iters, 1)
        print(f'Wall time: {np.around(elapsed, 4)} s after {iters} iterations '
              f'({np.around(per_iter, 4)} s/iter)')

        current, peak = tracemalloc.get_traced_memory()
        print(f"CPU-RAM (tracemalloc) current: {current / 1024**2:.2f} MB ({peak / 1024**2:.2f} MB max)")
        tracemalloc.stop()

        process = psutil.Process(os.getpid())
        current = process.memory_info().rss / 1024**2
        print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")

        if self.jax.default_backend() == 'gpu':
            self.update_memory_stats()
            used = self.get_cuda_memory_from_nvidia_smi()
            print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
|
evoxels/solvers.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
from IPython.display import clear_output
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Callable, Any, Type
|
|
4
|
+
from timeit import default_timer as timer
|
|
5
|
+
import sys
|
|
6
|
+
from .problem_definition import ODE
|
|
7
|
+
from .timesteppers import TimeStepper
|
|
8
|
+
|
|
9
|
+
@dataclass
class TimeDependentSolver:
    """Generic wrapper for solving one or more fields with a time stepper.

    Attributes:
        vf: VoxelFields object; fields are read from and written back to
            ``vf.fields``.
        fieldnames: Name or list of names of the fields to evolve; they are
            concatenated along axis 0 into one state array.
        backend: ``'torch'`` or ``'jax'``.
        problem_cls: Problem class, instantiated as
            ``problem_cls(vg, **problem_kwargs)``.
        timestepper_cls: Time stepper class, instantiated as
            ``timestepper_cls(problem, time_increment)``.
        step_fn: Optional custom ``step(u, t)``; when given it is used
            instead of ``problem_cls``/``timestepper_cls``.
        device: Torch device string (not passed to the JAX backend).
    """
    vf: Any # VoxelFields object
    fieldnames: str | list[str]
    backend: str
    problem_cls: Type[ODE] | None = None
    timestepper_cls: Type[TimeStepper] | None = None
    step_fn: Callable | None = None
    device: str='cuda'

    def __post_init__(self):
        """Initialize backend specific components.

        Raises:
            ValueError: If ``backend`` is neither ``'torch'`` nor ``'jax'``.
        """
        if self.backend == 'torch':
            from .voxelgrid import VoxelGridTorch
            from .profiler import TorchMemoryProfiler
            grid = self.vf.grid_info()
            self.vg = VoxelGridTorch(grid, precision=self.vf.precision, device=self.device)
            self.profiler = TorchMemoryProfiler(self.vg.device)

        elif self.backend == 'jax':
            from .voxelgrid import VoxelGridJax
            from .profiler import JAXMemoryProfiler
            self.vg = VoxelGridJax(self.vf.grid_info(), precision=self.vf.precision)
            self.profiler = JAXMemoryProfiler()
        else:
            raise ValueError(f"Unsupported backend: {self.backend}")

    def solve(
        self,
        time_increment=0.1,
        frames=10,
        max_iters=100,
        problem_kwargs=None,
        jit=True,
        verbose=True,
        vtk_out=False,
        plot_bounds=None,
        colormap='viridis'
    ):
        """Run the time integration loop.

        Args:
            time_increment (float): Size of a single time step.
            frames (int): Number of output frames (for plotting, vtk, checks).
                If ``max_iters < frames``, output happens at every step.
            max_iters (int): Number of time steps to compute.
            problem_kwargs (dict | None): Problem-specific input arguments.
            jit (bool): Create just-in-time compiled kernel if ``True``
            verbose (bool | str): If ``True`` prints memory stats, ``'plot'``
                updates an interactive plot.
            vtk_out (bool): Write VTK files for each frame if ``True``.
            plot_bounds (tuple | None): Optional value range for plots.
            colormap (str): Colormap used when ``verbose == 'plot'``.

        Raises:
            ValueError: If neither ``step_fn`` nor both ``problem_cls`` and
                ``timestepper_cls`` are provided.
        """
        problem_kwargs = problem_kwargs or {}
        if isinstance(self.fieldnames, str):
            self.fieldnames = [self.fieldnames]
        else:
            self.fieldnames = list(self.fieldnames)

        u_list = [self.vg.init_scalar_field(self.vf.fields[name]) for name in self.fieldnames]
        u = self.vg.concatenate(u_list, 0)
        u = self.vg.bc.trim_boundary_nodes(u)

        step = self._build_step(time_increment, problem_kwargs, jit)

        # Output every ``n_out`` steps. Clamp to at least 1 so that
        # ``i % n_out`` cannot divide by zero when max_iters < frames, and
        # guard ``frames <= 0`` for the integer division itself.
        n_out = max(1, max_iters // max(frames, 1))
        frame = 0
        slice_idx = self.vf.Nz // 2

        start = timer()
        for i in range(max_iters):
            time = i * time_increment
            if i % n_out == 0:
                self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)
                frame += 1

            u = step(u, time)

        end = timer()
        time = max_iters * time_increment
        self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)

        if verbose:
            self.profiler.print_memory_stats(start, end, max_iters)

    def _build_step(self, time_increment, problem_kwargs, jit):
        """Return the (optionally JIT-compiled) ``step(u, t)`` and set ``self.problem``."""
        if self.step_fn is not None:
            step = self.step_fn
            self.problem = None
        else:
            if self.problem_cls is None or self.timestepper_cls is None:
                raise ValueError("Either provide step_fn or both problem_cls and timestepper_cls")
            self.problem = self.problem_cls(self.vg, **problem_kwargs)
            timestepper = self.timestepper_cls(self.problem, time_increment)
            step = timestepper.step

        # Make use of just-in-time compilation
        if jit and self.backend == 'jax':
            import jax
            step = jax.jit(step)
        elif jit and self.backend == 'torch':
            import torch
            step = torch.compile(step)
        return step

    def _handle_outputs(self, u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap):
        """Store results and optionally plot or write them to disk.

        The state is written back to ``vf.fields``; with a problem instance
        the boundary padding is applied before ghost nodes are trimmed.
        NaN values terminate the process via ``sys.exit(1)``.
        """
        if getattr(self, 'problem', None) is not None:
            u_out = self.vg.bc.trim_ghost_nodes(self.problem.pad_bc(u))
        else:
            u_out = u

        for i, name in enumerate(self.fieldnames):
            self.vf.fields[name] = self.vg.export_scalar_field_to_numpy(u_out[i:i+1])

        if verbose:
            self.profiler.update_memory_stats()

        if self.vg.lib.isnan(u_out).any():
            print(f"NaN detected in frame {frame} at time {time}. Aborting simulation.")
            sys.exit(1)

        if vtk_out:
            # ``problem_cls`` may be None when a custom ``step_fn`` is used;
            # fall back to the step function's name (or 'custom' for lambdas).
            if self.problem_cls is not None:
                prefix = self.problem_cls.__name__
            else:
                step_name = getattr(self.step_fn, '__name__', '')
                prefix = step_name if step_name.isidentifier() else 'custom'
            filename = prefix + "_" + self.fieldnames[0] + f"_{frame:03d}.vtk"
            self.vf.export_to_vtk(filename=filename, field_names=self.fieldnames)
        if verbose == 'plot':
            clear_output(wait=True)
            self.vf.plot_slice(self.fieldnames[0], slice_idx, time=time, colormap=colormap, value_bounds=plot_bounds)
|