evoxels 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,450 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Callable
4
+ import sympy as sp
5
+ import sympy.vector as spv
6
+ import warnings
7
+ from .voxelgrid import VoxelGrid
8
+
9
# Slice shorthands used inside indexing tuples such as ``(__, 0, _i_, _i_)``.
__ = slice(None, None)  # keep every element, equivalent to [:]
_i_ = slice(1, -1)  # drop the first and last element, equivalent to [1:-1]
12
+
13
class ODE(ABC):
    """Abstract interface every problem definition implements.

    A problem supplies a numerical right-hand side acting on backend
    arrays, a symbolic (sympy) counterpart of it, and the boundary
    handling needed to pad a state before applying stencils.
    """

    @property
    @abstractmethod
    def order(self):
        """Expected spatial convergence order of the numerical :meth:`rhs`."""

    @abstractmethod
    def rhs_analytic(self, u, t):
        """Build the right-hand side as a sympy expression.

        Args:
            u : Sympy function describing the current state.
            t (float): Current time.

        Returns:
            Sympy function of the problem right-hand side.
        """

    @abstractmethod
    def rhs(self, u, t):
        """Evaluate the numerical right-hand side of the ODE system.

        Args:
            u (array): Current state.
            t (float): Current time.

        Returns:
            An object of the same type as ``u`` holding the time derivative.
        """

    @property
    @abstractmethod
    def bc_type(self) -> str:
        """Boundary condition name, for instance 'periodic', 'dirichlet' or 'neumann'."""

    @abstractmethod
    def pad_bc(self, u):
        """Pad a field and impose the boundary conditions on it.

        Having this as a separate method lets boundary conditions be applied
        to ``u`` both inside and outside of the right-hand-side function.

        Args:
            u : field

        Returns:
            The field padded with boundary values.
        """
66
+
67
+
68
class SemiLinearODE(ODE):
    """ODE whose highest-order spatial operator has a known Fourier symbol.

    Exposing the symbol lets pseudo-spectral timesteppers use it.
    """

    @property
    @abstractmethod
    def fourier_symbol(self):
        r"""Symbol of the highest order spatial operator.

        The symbol of an operator is its representation in the
        Fourier (spectral) domain. For instance the:

        - Laplacian operator $\nabla^2$ has a symbol $-k^2$,
        - diffusion operator $D\nabla^2$ corresponds to $-k^2D$

        The symbol is required for pseudo-spectral timesteppers.
        """
        # Raw docstring: in a plain string "\nabla" would start with a
        # newline escape and render as "abla".
82
+
83
+
84
class SmoothedBoundaryODE(ODE):
    """ODE whose domain is described by a fixed smoothed-boundary mask."""

    @property
    @abstractmethod
    def mask(self) -> Any | float:
        """Field with the same shape as the state that is never evolved."""
90
+
91
+
92
@dataclass
class ReactionDiffusion(SemiLinearODE):
    """Scalar reaction-diffusion problem ``du/dt = D * laplace(u) + f(u, t)``.

    Attributes:
        vg: Voxel grid supplying the array backend (``vg.lib``), boundary
            padding (``vg.bc``) and the discrete operators.
        D: Diffusion coefficient.
        BC_type: ``'periodic'``, ``'dirichlet'`` or ``'neumann'``.
        bcs: Boundary values ``(bc0, bc1)``; only the Dirichlet padding
            uses them.
        f: Source term called as ``f(c, t, lib)`` or ``f(c, t)``;
            ``None`` means no source.
        A: Factor scaling the Fourier symbol ``-D * k^2``.
    """
    vg: VoxelGrid
    D: float
    BC_type: str
    bcs: tuple = (0, 0)
    f: Callable | None = None
    A: float = 0.25
    _fourier_symbol: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Select the boundary padding and precompute the Fourier symbol.

        Raises:
            ValueError: If ``BC_type`` is not a supported boundary type.
        """
        if self.f is None:
            self.f = lambda c=None, t=None, lib=None: 0

        if self.BC_type == 'periodic':
            bc_fun = self.vg.bc.pad_periodic
            # Periodic padding takes no boundary values; bc0/bc1 are ignored.
            self.pad_boundary = lambda u, bc0, bc1: bc_fun(u)
            k_squared = self.vg.rfft_k_squared()
        elif self.BC_type == 'dirichlet':
            self.pad_boundary = self.vg.bc.pad_dirichlet_periodic
            if self.vg.convention == 'cell_center':
                warnings.warn(
                    "Applying Dirichlet BCs on a cell_center grid "
                    "reduces the spatial order of convergence to 0.5!"
                )
            k_squared = self.vg.fft_k_squared_nonperiodic()
        elif self.BC_type == 'neumann':
            bc_fun = self.vg.bc.pad_zero_flux_periodic
            # Zero-flux padding takes no boundary values; bc0/bc1 are ignored.
            self.pad_boundary = lambda u, bc0, bc1: bc_fun(u)
            k_squared = self.vg.fft_k_squared_nonperiodic()
        else:
            # Fail early with a clear message instead of an unbound k_squared.
            raise ValueError(
                f"Unsupported BC_type {self.BC_type!r}; expected "
                "'periodic', 'dirichlet' or 'neumann'."
            )

        self._fourier_symbol = -self.D * self.A * k_squared

    @property
    def order(self):
        """Spatial order of convergence of :meth:`rhs`."""
        return 2

    @property
    def fourier_symbol(self):
        """Precomputed symbol ``-D * A * k^2`` of the diffusion operator."""
        return self._fourier_symbol

    def _eval_f(self, c, t, lib):
        """Evaluate the source term ``self.f``.

        ``f`` is first called as ``f(c, t, lib)``. If that raises
        ``TypeError``, it is called as ``f(c, t)``.
        NOTE: a TypeError raised inside ``f`` itself also triggers the retry.
        """
        try:
            return self.f(c, t, lib)
        except TypeError:
            return self.f(c, t)

    @property
    def bc_type(self):
        """Boundary condition type given at construction."""
        return self.BC_type

    def pad_bc(self, u):
        """Pad ``u`` using the boundary padding chosen in ``__post_init__``."""
        return self.pad_boundary(u, self.bcs[0], self.bcs[1])

    def rhs_analytic(self, u, t):
        """Sympy right-hand side ``D * laplacian(u) + f(u, t)``."""
        return self.D*spv.laplacian(u) + self._eval_f(u, t, sp)

    def rhs(self, u, t):
        """Numerical right-hand side ``D * laplace(u) + f(u, t)``.

        Args:
            u: Unpadded state; it is padded via :meth:`pad_bc` first.
            t (float): Current time.

        Returns:
            Backend array containing ``du/dt``.
        """
        laplace = self.vg.laplace(self.pad_bc(u))
        update = self.D * laplace + self._eval_f(u, t, self.vg.lib)
        return update
155
+
156
@dataclass
class ReactionDiffusionSBM(ReactionDiffusion, SmoothedBoundaryODE):
    """Reaction-diffusion problem using the smoothed boundary method (SBM).

    The state ``u`` is treated as a masked quantity: the source term is
    evaluated at ``u / mask`` and multiplied by ``mask``. The right-hand
    side is

        du/dt = div(D * (grad(u) - u/mask * grad(mask)))
                + |grad(mask)| * bc_flux + mask * f(u/mask, t)

    Attributes:
        mask: Domain mask field; ``None`` uses a mask of ones.
        bc_flux: Flux weighted by ``|grad(mask)|`` in the right-hand side.
    """
    mask: Any | None = None
    bc_flux: Callable | float = 0.0

    def __post_init__(self):
        """Prepare the padded mask, its gradient norm and masked boundary values."""
        super().__post_init__()
        if self.mask is None:
            self.mask = self.vg.lib.ones(self.vg.shape)
            self.mask = self.vg.init_scalar_field(self.mask)
            self.mask = self.vg.pad_periodic(
                self.vg.bc.trim_boundary_nodes(self.mask))
            # NOTE(review): a uniform mask has zero gradient, yet norm is 1.0
            # here, so a nonzero bc_flux acts as a uniform source. Confirm
            # that this is intended.
            self.norm = 1.0
        else:
            self.mask = self.vg.init_scalar_field(self.mask)
            # Remember the mask at index 0 / -1 along axis 1; for
            # non-periodic BCs these layers are restored after padding.
            mask_0 = self.mask[:, 0, :, :]
            mask_1 = self.mask[:, -1, :, :]
            self.mask = self.vg.pad_periodic(
                self.vg.bc.trim_boundary_nodes(self.mask))
            if self.BC_type != 'periodic':
                self.mask = self.vg.set(self.mask, (__, 0, _i_, _i_), mask_0)
                self.mask = self.vg.set(self.mask, (__, -1, _i_, _i_), mask_1)

            self.norm = self.vg.lib.sqrt(self.vg.gradient_norm_squared(self.mask))
            # Keep the mask away from zero because rhs divides by it.
            self.mask = self.vg.lib.clip(self.mask, 1e-4, 1)

        # Boundary values are scaled by the mask on the two axis-1 end layers.
        self.bcs = (self.bcs[0] * self.mask[:, 0, :, :],
                    self.bcs[1] * self.mask[:, -1, :, :])

    def rhs_analytic(self, mask, u, t):
        """Sympy right-hand side of the SBM reaction-diffusion equation.

        Args:
            mask: Sympy expression of the smoothed-boundary mask.
            u: Sympy expression of the masked state.
            t: Current time.

        Returns:
            Sympy expression of ``du/dt`` consistent with :meth:`rhs`.
        """
        grad_mask = spv.gradient(mask)
        # The boundary flux is weighted by |grad(mask)| (not |grad(u)|), the
        # same quantity the numerical rhs uses through ``self.norm``.
        norm_grad_mask = sp.sqrt(grad_mask.dot(grad_mask))

        divergence = spv.divergence(self.D*(spv.gradient(u) - u/mask*grad_mask))
        du = divergence + norm_grad_mask*self.bc_flux + mask*self._eval_f(u/mask, t, sp)
        return du

    def rhs(self, u, t):
        """Numerical SBM right-hand side.

        Args:
            u: Unpadded masked state; it is padded via :meth:`pad_bc`.
            t (float): Current time.

        Returns:
            Backend array containing ``du/dt``.
        """
        vg = self.vg
        mask = self.mask
        z = self.pad_bc(u)
        # z/mask is shared by all three face-flux terms; compute it once.
        z_over_mask = z / mask

        divergence = vg.grad_x_face(
            vg.grad_x_face(z) - vg.to_x_face(z_over_mask) * vg.grad_x_face(mask)
        )[:, :, 1:-1, 1:-1]
        divergence += vg.grad_y_face(
            vg.grad_y_face(z) - vg.to_y_face(z_over_mask) * vg.grad_y_face(mask)
        )[:, 1:-1, :, 1:-1]
        divergence += vg.grad_z_face(
            vg.grad_z_face(z) - vg.to_z_face(z_over_mask) * vg.grad_z_face(mask)
        )[:, 1:-1, 1:-1, :]

        mask_inner = mask[:, 1:-1, 1:-1, 1:-1]
        update = self.D * divergence + \
            self.norm*self.bc_flux + \
            mask_inner*self._eval_f(u/mask_inner, t, vg.lib)
        return update
212
+
213
+
214
@dataclass
class PeriodicCahnHilliard(SemiLinearODE):
    """Cahn-Hilliard equation on a periodic domain.

    Attributes:
        vg: Voxel grid supplying the backend and discrete operators.
        eps: Gradient-energy parameter; the chemical potential contains
            ``-2 * eps * laplace(c)``.
        D: Mobility prefactor; the mobility is ``D * c * (1 - c)``.
        mu_hom: Homogeneous chemical potential called as
            ``mu_hom(c, lib)`` or ``mu_hom(c)``.
        A: Factor scaling the Fourier symbol.
    """
    vg: VoxelGrid
    eps: float = 3.0
    D: float = 1.0
    mu_hom: Callable | None = None
    A: float = 0.25
    _fourier_symbol: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Precompute the Fourier symbol and set the default ``mu_hom``."""
        k_squared = self.vg.rfft_k_squared()
        self._fourier_symbol = -2 * self.eps * self.D * self.A * k_squared**2
        if self.mu_hom is None:
            self.mu_hom = lambda c, lib=None: 18 / self.eps * c * (1 - c) * (1 - 2 * c)

    @property
    def order(self):
        """Spatial order of convergence of :meth:`rhs`."""
        return 2

    @property
    def fourier_symbol(self):
        """Precomputed symbol ``-2 * eps * D * A * k^4``."""
        return self._fourier_symbol

    @property
    def bc_type(self):
        """Always ``'periodic'``."""
        return 'periodic'

    def pad_bc(self, u):
        """Pad ``u`` with periodic ghost nodes."""
        return self.vg.bc.pad_periodic(u)

    def _eval_mu(self, c, lib):
        """Evaluate the homogeneous chemical potential ``self.mu_hom``.

        It is first called as ``mu_hom(c, lib)``, then as ``mu_hom(c)`` if
        that raises ``TypeError``.
        """
        try:
            return self.mu_hom(c, lib)
        except TypeError:
            return self.mu_hom(c)

    def rhs_analytic(self, c, t):
        """Sympy right-hand side ``div(D c (1-c) grad(mu))``."""
        mu = self._eval_mu(c, sp) - 2*self.eps*spv.laplacian(c)
        fluxes = self.D*c*(1-c)*spv.gradient(mu)
        rhs = spv.divergence(fluxes)
        return rhs

    def rhs(self, c, t):
        r"""Evaluate :math:`\partial c / \partial t` for the CH equation.

        Numerical computation of

        .. math::
            \frac{\partial c}{\partial t}
            = \nabla \cdot \bigl( D \, c (1 - c) \, \nabla \mu \bigr),
            \quad
            \mu = \mu_\mathrm{hom}(c) - 2 \varepsilon \, \nabla^2 c

        where :math:`D\,c(1-c)` is the concentration-dependent mobility,
        :math:`\mu` the chemical potential and :math:`\varepsilon`
        (``eps``) controls the gradient energy.

        Args:
            c (array-like): Concentration field.
            t (float): Current time.

        Returns:
            Backend array of the same shape as ``c`` containing ``dc/dt``.
        """
        vg = self.vg
        c = vg.lib.clip(c, 0, 1)
        c_BC = self.pad_bc(c)
        laplace = vg.laplace(c_BC)
        mu = self._eval_mu(c, vg.lib) - 2*self.eps*laplace
        mu = self.pad_bc(mu)

        # Interpolate c to the faces once per axis and reuse it in c*(1-c).
        c_x = vg.to_x_face(c_BC)
        divergence = vg.grad_x_face(
            c_x * (1 - c_x) * vg.grad_x_face(mu)
        )[:, :, 1:-1, 1:-1]

        c_y = vg.to_y_face(c_BC)
        divergence += vg.grad_y_face(
            c_y * (1 - c_y) * vg.grad_y_face(mu)
        )[:, 1:-1, :, 1:-1]

        c_z = vg.to_z_face(c_BC)
        divergence += vg.grad_z_face(
            c_z * (1 - c_z) * vg.grad_z_face(mu)
        )[:, 1:-1, 1:-1, :]

        return self.D * divergence
302
+
303
+
304
@dataclass
class AllenCahnEquation(SemiLinearODE):
    """Two-phase Allen-Cahn model with an adjustable curvature contribution.

    Attributes:
        vg: Voxel grid supplying the backend and discrete operators.
        eps: Diffuse interface width parameter.
        gab: Interfacial energy.
        M: Mobility.
        force: Weight of the ``3/eps * phi * (1 - phi)`` driving term.
        curvature: Weight in ``[0, 1]`` blending the full laplacian with the
            normal laplacian.
        potential: Phase-field potential term called as
            ``potential(phi, lib)`` or ``potential(phi)``.
    """
    vg: VoxelGrid
    eps: float = 2.0
    gab: float = 1.0
    M: float = 1.0
    force: float = 0.0
    curvature: float = 0.01
    potential: Callable | None = None
    _fourier_symbol: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Compute the Fourier symbol and install the default potential."""
        self._fourier_symbol = -2 * self.M * self.gab * self.vg.rfft_k_squared()
        if self.potential is not None:
            return
        self.potential = lambda u, lib=None: 18 / self.eps * u * (1-u) * (1-2*u)

    @property
    def order(self):
        """Spatial order of convergence of :meth:`rhs`."""
        return 2

    @property
    def fourier_symbol(self):
        """Precomputed symbol ``-2 * M * gab * k^2``."""
        return self._fourier_symbol

    @property
    def bc_type(self):
        """Always ``'neumann'``."""
        return 'neumann'

    def pad_bc(self, u):
        """Pad ``u`` with zero-flux ghost nodes."""
        return self.vg.bc.pad_zero_flux(u)

    def _eval_potential(self, phi, lib):
        """Call ``self.potential`` with ``lib`` if accepted, else without it."""
        try:
            return self.potential(phi, lib)
        except TypeError:
            return self.potential(phi)

    def rhs_analytic(self, phi, t):
        """Sympy counterpart of :meth:`rhs`."""
        grad_phi = spv.gradient(phi)
        lap_phi = spv.laplacian(phi)
        grad_len = sp.sqrt(grad_phi.dot(grad_phi))

        # Curvature term |grad phi| * div(grad phi / |grad phi|).
        curv = grad_len * spv.divergence(grad_phi / grad_len)
        n_laplace = lap_phi - (1-self.curvature)*curv
        interfacial = self.gab * (2*n_laplace - self._eval_potential(phi, sp)/self.eps)
        driving = 3/self.eps * phi * (1-phi) * self.force
        return self.M * (interfacial + driving)

    def rhs(self, phi, t):
        r"""Two-phase Allen-Cahn equation

        Evolves the order parameter ``\phi``, which can be read as a
        phase fraction. :math:`M` is the mobility, :math:`\epsilon` sets
        the diffuse interface width and :math:`\gamma` is the interfacial
        energy. The laplacian drives curvature minimization, whose strength
        is tuned with ``curvature=`` in the range :math:`[0,1]`.

        Args:
            phi (array-like): order parameter.
            t (float): Current time.

        Returns:
            Backend array of the same shape as ``\phi`` containing ``d\phi/dt``.
        """
        lib = self.vg.lib
        phi = lib.clip(phi, 0, 1)
        pot = self._eval_potential(phi, lib)
        padded = self.pad_bc(phi)
        full_lap = self.curvature * self.vg.laplace(padded)
        normal_lap = (1 - self.curvature) * self.vg.normal_laplace(padded)
        interfacial = self.gab * (2.0 * (full_lap + normal_lap) - pot/self.eps)
        driving = 3/self.eps * phi * (1-phi) * self.force
        return self.M * (interfacial + driving)
384
+
385
+
386
@dataclass
class CoupledReactionDiffusion(SemiLinearODE):
    """Two-species reaction-diffusion system on a periodic domain.

    Species A and B live in channels 0 and 1 of the state. With the default
    interaction ``A * B**2`` the equations have the Gray-Scott form.

    Attributes:
        vg: Voxel grid supplying the backend and discrete operators.
        D_A: Diffusivity of species A.
        D_B: Diffusivity of species B.
        feed: Feed rate in the ``feed * (1 - A)`` term.
        kill: Removal rate in the ``-kill * B`` term.
        interaction: Coupling term called as ``interaction(u, lib)`` or
            ``interaction(u)``.
    """
    vg: VoxelGrid
    D_A: float = 1.0
    D_B: float = 0.5
    feed: float = 0.055
    kill: float = 0.117
    interaction: Callable | None = None
    _fourier_symbol: Any = field(init=False, repr=False)

    def __post_init__(self):
        """Compute the shared Fourier symbol and the default interaction."""
        k_squared = self.vg.rfft_k_squared()
        # One symbol for both species, built from the larger diffusivity.
        self._fourier_symbol = -max(self.D_A, self.D_B) * k_squared
        if self.interaction is not None:
            return
        self.interaction = lambda u, lib=None: u[0] * u[1]**2

    @property
    def order(self):
        """Spatial order of convergence of :meth:`rhs`."""
        return 2

    @property
    def fourier_symbol(self):
        """Precomputed symbol ``-max(D_A, D_B) * k^2``."""
        return self._fourier_symbol

    @property
    def bc_type(self):
        """Always ``'periodic'``."""
        return 'periodic'

    def pad_bc(self, u):
        """Pad ``u`` with periodic ghost nodes."""
        return self.vg.bc.pad_periodic(u)

    def _eval_interaction(self, u, lib):
        """Call ``self.interaction`` with ``lib`` if accepted, else without it."""
        try:
            return self.interaction(u, lib)
        except TypeError:
            return self.interaction(u)

    def rhs_analytic(self, u, t):
        """Sympy right-hand sides ``(dA/dt, dB/dt)``."""
        coupling = self._eval_interaction(u, sp)
        c_A, c_B = u[0], u[1]
        dc_A = self.D_A*spv.laplacian(c_A) - coupling + self.feed * (1-c_A)
        dc_B = self.D_B*spv.laplacian(c_B) + coupling - self.kill * c_B
        return (dc_A, dc_B)

    def rhs(self, u, t):
        r"""Two-component reaction-diffusion system

        The species are stored along the batch/channel axis:
            - Species A with concentration c_A = u[0]
            - Species B with concentration c_B = u[1]

        Args:
            u (array-like): species
            t (float): Current time.

        Returns:
            Backend array of the same shape as ``u`` containing ``du/dt``.
        """
        lib = self.vg.lib
        coupling = self._eval_interaction(u, lib)
        lap = self.vg.laplace(self.pad_bc(u))
        c_A, c_B = u[0], u[1]
        dc_A = self.D_A*lap[0] - coupling + self.feed * (1-c_A)
        dc_B = self.D_B*lap[1] + coupling - self.kill * c_B
        return lib.stack((dc_A, dc_B), 0)
evoxels/profiler.py ADDED
@@ -0,0 +1,94 @@
1
+ import numpy as np
2
+ import psutil
3
+ import os
4
+ import subprocess
5
+ import tracemalloc
6
+ from abc import ABC, abstractmethod
7
+
8
class MemoryProfiler(ABC):
    """Base interface for tracking host and device memory usage.

    Subclasses must set ``max_used_cpu`` and ``max_used_gpu`` before
    :meth:`update_memory_stats` is called.
    """

    def get_cuda_memory_from_nvidia_smi(self):
        """Return currently used CUDA memory in megabytes.

        Queries ``nvidia-smi`` and reports the first GPU it lists.

        Returns:
            int | None: Used memory in MB, or ``None`` if ``nvidia-smi`` cannot
            be run or its output cannot be parsed (an error is printed).
        """
        try:
            output = subprocess.check_output(
                ['nvidia-smi', '--query-gpu=memory.used',
                 '--format=csv,nounits,noheader'],
                encoding='utf-8'
            )
            return int(output.strip().split('\n')[0])
        except (OSError, subprocess.SubprocessError, ValueError) as e:
            # Best effort: no GPU / no nvidia-smi must not abort a simulation.
            print(f"Error tracking memory with nvidia-smi: {e}")
            return None

    def _record_gpu_usage(self, used):
        """Fold a GPU reading (MB) into ``max_used_gpu``; ``None`` is ignored."""
        if used is not None:
            self.max_used_gpu = max(self.max_used_gpu, used)

    def update_memory_stats(self):
        """Update the running maxima of host RSS and GPU memory (both in MB)."""
        process = psutil.Process(os.getpid())
        used_cpu = process.memory_info().rss / 1024**2
        self.max_used_cpu = max(self.max_used_cpu, used_cpu)
        # A failed nvidia-smi query returns None; comparing None with a number
        # would raise TypeError, so such readings are skipped.
        self._record_gpu_usage(self.get_cuda_memory_from_nvidia_smi())

    @abstractmethod
    def print_memory_stats(self, start: float, end: float, iters: int):
        """Print profiling summary after a simulation run."""
        pass
33
+
34
class TorchMemoryProfiler(MemoryProfiler):
    """Memory profiler for the PyTorch backend."""

    def __init__(self, device):
        """Initialize the profiler for a given torch device.

        Args:
            device: ``torch.device``; its ``type`` selects CPU or CUDA reporting.
        """
        import torch
        self.torch = torch
        self.device = device
        # tracemalloc is only read for CPU runs; tracing adds overhead to every
        # Python allocation, so it is not started for other devices.
        if device.type == 'cpu':
            tracemalloc.start()
        if device.type == 'cuda':
            torch.cuda.reset_peak_memory_stats(device=device)
        self.max_used_gpu = 0
        self.max_used_cpu = 0

    def print_memory_stats(self, start, end, iters):
        """Print wall time and memory usage for the Torch backend.

        Args:
            start (float): Timer value at the start of the run in seconds.
            end (float): Timer value at the end of the run in seconds.
            iters (int): Number of iterations performed.
        """
        elapsed = end - start
        # A zero-iteration run reports nan instead of dividing by zero.
        per_iter = elapsed / iters if iters else float('nan')
        print(f'Wall time: {np.around(elapsed, 4)} s after {iters} iterations '
              f'({np.around(per_iter, 4)} s/iter)')

        if self.device.type == 'cpu':
            current, peak = tracemalloc.get_traced_memory()
            print(f"CPU-RAM (tracemalloc) current: {current / 1024**2:.2f} MB ({peak / 1024**2:.2f} MB max)")
            if tracemalloc.is_tracing():
                tracemalloc.stop()

            process = psutil.Process(os.getpid())
            current = process.memory_info().rss / 1024**2
            print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")

        elif self.device.type == 'cuda':
            self.update_memory_stats()
            used = self.get_cuda_memory_from_nvidia_smi()
            print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
            print(f"GPU-RAM (torch) current: "
                  f"{self.torch.cuda.memory_allocated(self.device) / 1024**2:.2f} MB "
                  f"({self.torch.cuda.max_memory_allocated(self.device) / 1024**2:.2f} MB max, "
                  f"{self.torch.cuda.max_memory_reserved(self.device) / 1024**2:.2f} MB reserved)")
68
+
69
class JAXMemoryProfiler(MemoryProfiler):
    """Memory profiler for the JAX backend."""

    def __init__(self):
        """Initialize the profiler for JAX and start ``tracemalloc``."""
        import jax
        self.jax = jax
        self.max_used_gpu = 0
        self.max_used_cpu = 0
        tracemalloc.start()

    def print_memory_stats(self, start, end, iters):
        """Print wall time and memory usage for the JAX backend.

        Args:
            start (float): Timer value at the start of the run in seconds.
            end (float): Timer value at the end of the run in seconds.
            iters (int): Number of iterations performed.
        """
        elapsed = end - start
        # A zero-iteration run reports nan instead of dividing by zero.
        per_iter = elapsed / iters if iters else float('nan')
        print(f'Wall time: {np.around(elapsed, 4)} s after {iters} iterations '
              f'({np.around(per_iter, 4)} s/iter)')

        current, peak = tracemalloc.get_traced_memory()
        print(f"CPU-RAM (tracemalloc) current: {current / 1024**2:.2f} MB ({peak / 1024**2:.2f} MB max)")
        tracemalloc.stop()

        process = psutil.Process(os.getpid())
        current = process.memory_info().rss / 1024**2
        print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")

        if self.jax.default_backend() == 'gpu':
            self.update_memory_stats()
            used = self.get_cuda_memory_from_nvidia_smi()
            print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
evoxels/solvers.py ADDED
@@ -0,0 +1,134 @@
1
+ from IPython.display import clear_output
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Any, Type
4
+ from timeit import default_timer as timer
5
+ import sys
6
+ from .problem_definition import ODE
7
+ from .timesteppers import TimeStepper
8
+
9
@dataclass
class TimeDependentSolver:
    """Generic wrapper for solving one or more fields with a time stepper.

    Attributes:
        vf: ``VoxelFields`` object holding the fields and grid information.
        fieldnames: Name, or list of names, of the fields to evolve.
        backend: Array backend, ``'torch'`` or ``'jax'``.
        problem_cls: Problem definition class instantiated as
            ``problem_cls(vg, **problem_kwargs)``.
        timestepper_cls: Time stepper class instantiated as
            ``timestepper_cls(problem, time_increment)``.
        step_fn: Custom ``step(u, t)``; when given it is used instead of
            ``problem_cls``/``timestepper_cls``.
        device: Torch device name (only used by the torch backend).
    """
    vf: Any  # VoxelFields object
    fieldnames: str | list[str]
    backend: str
    problem_cls: Type[ODE] | None = None
    timestepper_cls: Type[TimeStepper] | None = None
    step_fn: Callable | None = None
    device: str = 'cuda'

    def __post_init__(self):
        """Create the backend voxel grid and the matching memory profiler.

        Raises:
            ValueError: If ``backend`` is neither ``'torch'`` nor ``'jax'``.
        """
        if self.backend == 'torch':
            from .voxelgrid import VoxelGridTorch
            from .profiler import TorchMemoryProfiler
            grid = self.vf.grid_info()
            self.vg = VoxelGridTorch(grid, precision=self.vf.precision, device=self.device)
            self.profiler = TorchMemoryProfiler(self.vg.device)

        elif self.backend == 'jax':
            from .voxelgrid import VoxelGridJax
            from .profiler import JAXMemoryProfiler
            self.vg = VoxelGridJax(self.vf.grid_info(), precision=self.vf.precision)
            self.profiler = JAXMemoryProfiler()
        else:
            raise ValueError(f"Unsupported backend: {self.backend}")

    def solve(
        self,
        time_increment=0.1,
        frames=10,
        max_iters=100,
        problem_kwargs=None,
        jit=True,
        verbose=True,
        vtk_out=False,
        plot_bounds=None,
        colormap='viridis'
    ):
        """Run the time integration loop.

        Args:
            time_increment (float): Size of a single time step.
            frames (int): Number of output frames (for plotting, vtk, checks).
            max_iters (int): Number of time steps to compute.
            problem_kwargs (dict | None): Problem-specific input arguments.
            jit (bool): Create just-in-time compiled kernel if ``True``
            verbose (bool | str): If ``True`` prints memory stats, ``'plot'``
                updates an interactive plot.
            vtk_out (bool): Write VTK files for each frame if ``True``.
            plot_bounds (tuple | None): Optional value range for plots.
            colormap (str): Colormap used for the slice plot.

        Raises:
            ValueError: If neither ``step_fn`` nor both ``problem_cls`` and
                ``timestepper_cls`` are provided.
        """
        problem_kwargs = problem_kwargs or {}
        if isinstance(self.fieldnames, str):
            self.fieldnames = [self.fieldnames]
        else:
            self.fieldnames = list(self.fieldnames)

        u_list = [self.vg.init_scalar_field(self.vf.fields[name]) for name in self.fieldnames]
        u = self.vg.concatenate(u_list, 0)
        u = self.vg.bc.trim_boundary_nodes(u)

        if self.step_fn is not None:
            step = self.step_fn
            self.problem = None
        else:
            if self.problem_cls is None or self.timestepper_cls is None:
                raise ValueError("Either provide step_fn or both problem_cls and timestepper_cls")
            self.problem = self.problem_cls(self.vg, **problem_kwargs)
            timestepper = self.timestepper_cls(self.problem, time_increment)
            step = timestepper.step

        # Make use of just-in-time compilation
        if jit and self.backend == 'jax':
            import jax
            step = jax.jit(step)
        elif jit and self.backend == 'torch':
            import torch
            step = torch.compile(step)

        # Steps between outputs. Clamped to >= 1 so that frames > max_iters
        # (or frames <= 0) cannot lead to a division/modulo by zero.
        n_out = max(1, max_iters // max(1, frames))
        frame = 0
        slice_idx = self.vf.Nz // 2

        start = timer()
        for i in range(max_iters):
            time = i * time_increment
            if i % n_out == 0:
                self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)
                frame += 1

            u = step(u, time)

        # Work may still be queued asynchronously on the device; wait for it
        # so the measured wall time covers all iterations.
        self._synchronize(u)
        end = timer()
        time = max_iters * time_increment
        self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)

        if verbose:
            self.profiler.print_memory_stats(start, end, max_iters)

    def _synchronize(self, u):
        """Block until pending asynchronous device work producing ``u`` is done."""
        if hasattr(u, 'block_until_ready'):
            # JAX arrays are dispatched asynchronously.
            u.block_until_ready()
        elif self.backend == 'torch' and getattr(self.vg.device, 'type', None) == 'cuda':
            import torch
            torch.cuda.synchronize(self.vg.device)

    def _handle_outputs(self, u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap):
        """Store results and optionally plot or write them to disk.

        Exits the process with status 1 if the output contains NaN values.
        """
        if getattr(self, 'problem', None) is not None:
            u_out = self.vg.bc.trim_ghost_nodes(self.problem.pad_bc(u))
        else:
            u_out = u

        for i, name in enumerate(self.fieldnames):
            self.vf.fields[name] = self.vg.export_scalar_field_to_numpy(u_out[i:i+1])

        if verbose:
            self.profiler.update_memory_stats()

        if self.vg.lib.isnan(u_out).any():
            print(f"NaN detected in frame {frame} at time {time}. Aborting simulation.")
            sys.exit(1)

        if vtk_out:
            # With a custom step_fn, problem_cls may be None; use a fallback prefix.
            prefix = self.problem_cls.__name__ if self.problem_cls is not None else "custom"
            filename = f"{prefix}_{self.fieldnames[0]}_{frame:03d}.vtk"
            self.vf.export_to_vtk(filename=filename, field_names=self.fieldnames)
        if verbose == 'plot':
            clear_output(wait=True)
            self.vf.plot_slice(self.fieldnames[0], slice_idx, time=time, colormap=colormap, value_bounds=plot_bounds)