evoxels 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evoxels/__init__.py +13 -0
- evoxels/boundary_conditions.py +138 -0
- evoxels/fd_stencils.py +103 -0
- evoxels/function_approximators.py +97 -0
- evoxels/inversion.py +233 -0
- evoxels/precompiled_solvers/__init__.py +1 -0
- evoxels/precompiled_solvers/allen_cahn.py +50 -0
- evoxels/precompiled_solvers/cahn_hilliard.py +42 -0
- evoxels/problem_definition.py +450 -0
- evoxels/profiler.py +94 -0
- evoxels/solvers.py +134 -0
- evoxels/timesteppers.py +119 -0
- evoxels/utils.py +124 -0
- evoxels/voxelfields.py +318 -0
- evoxels/voxelgrid.py +278 -0
- evoxels-0.1.0.dist-info/METADATA +171 -0
- evoxels-0.1.0.dist-info/RECORD +20 -0
- evoxels-0.1.0.dist-info/WHEEL +5 -0
- evoxels-0.1.0.dist-info/licenses/LICENSE +21 -0
- evoxels-0.1.0.dist-info/top_level.txt +1 -0
evoxels/timesteppers.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
from .problem_definition import ODE, SemiLinearODE
|
|
6
|
+
|
|
7
|
+
State = Any # e.g. torch.Tensor or jax.Array
|
|
8
|
+
|
|
9
|
+
class TimeStepper(ABC):
    """Common interface shared by all single-step time integration schemes.

    Concrete schemes report their temporal order of accuracy through the
    ``order`` property and advance a state by one step via ``step``.
    """

    @property
    @abstractmethod
    def order(self) -> int:
        """Formal order of accuracy of the scheme in time."""
        ...

    @abstractmethod
    def step(self, u: State, t: float) -> State:
        """Advance the state by a single time step, from ``t`` to ``t + dt``.

        Args:
            u: State at the current time.
            t: Current simulation time.

        Returns:
            The state one step later, i.e. at ``t + dt``.
        """
        ...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class ForwardEuler(TimeStepper):
    """Explicit, first-order forward Euler scheme.

    One step computes ``u_next = u + dt * problem.rhs(u, t)``.

    Attributes:
        problem: Problem whose ``rhs(u, t)`` is evaluated explicitly.
        dt: Fixed time step size.
    """

    problem: ODE
    dt: float

    @property
    def order(self) -> int:
        return 1

    def step(self, u: State, t: float) -> State:
        increment = self.dt * self.problem.rhs(u, t)
        return u + increment
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class PseudoSpectralIMEX(TimeStepper):
    """First-order IMEX Fourier pseudo-spectral scheme.

    Also known as the semi-implicit Fourier spectral method; see
    [Zhu and Chen 1999, doi:10.1103/PhysRevE.60.3564] for more details.

    Each step pads the right-hand side according to the boundary
    condition, transforms it to Fourier space, scales it by
    ``dt / (1 - dt * fourier_symbol)`` and transforms back.

    Attributes:
        problem: Semi-linear problem providing ``rhs``, ``fourier_symbol``,
            ``bc_type`` and the voxel grid ``vg``.
        dt: Fixed time step size.

    Raises:
        ValueError: On construction, if ``problem.bc_type`` is not one of
            'periodic', 'dirichlet' or 'neumann'.
    """
    problem: SemiLinearODE
    dt: float

    def __post_init__(self):
        # Pre-bake the linear prefactor in Fourier space; it only depends on
        # dt and the operator symbol, so it is constant across steps.
        self._fft_prefac = self.dt / (1 - self.dt*self.problem.fourier_symbol)

        # Select the padding routine matching the boundary condition type.
        bc_type = self.problem.bc_type
        if bc_type == 'periodic':
            self.pad = self.problem.vg.bc.pad_fft_periodic
        elif bc_type == 'dirichlet':
            self.pad = self.problem.vg.bc.pad_fft_dirichlet_periodic
        elif bc_type == 'neumann':
            self.pad = self.problem.vg.bc.pad_fft_zero_flux_periodic
        else:
            # Fail early: otherwise ``self.pad`` stays undefined and ``step``
            # would crash later with an unhelpful AttributeError.
            raise ValueError(
                f"Unsupported bc_type '{bc_type}' for PseudoSpectralIMEX; "
                "expected 'periodic', 'dirichlet' or 'neumann'."
            )

    @property
    def order(self) -> int:
        return 1

    def step(self, u: State, t: float) -> State:
        """Advance ``u`` from ``t`` to ``t + dt`` with the IMEX update."""
        dc = self.pad(self.problem.rhs(u, t))
        vg = self.problem.vg
        dc_fft = self._fft_prefac * vg.rfftn(dc, dc.shape)
        # Invert on the padded shape, then crop axis 1 back to the size of
        # ``u`` (assumes padding only extends that axis — TODO confirm).
        update = vg.irfftn(dc_fft, dc.shape)[:, :u.shape[1]]
        return u + update
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
try:
    import jax.numpy as jnp
    import diffrax as dfx
except ImportError:
    # Optional dependency: the rest of the module works without JAX/diffrax.
    PseudoSpectralIMEX_dfx = None
    warnings.warn("Diffrax not found. 'PseudoSpectralIMEX_dfx' will not be available.")
else:
    class PseudoSpectralIMEX_dfx(dfx.AbstractSolver):
        """Re-implementation of ``PseudoSpectralIMEX`` as a diffrax solver.

        This is used for the inversion models based on jax and diffrax.
        Unlike the torch/jax ``PseudoSpectralIMEX`` stepper, no boundary
        padding is applied: ``jnp.fft.rfftn`` is taken over all axes of the
        raw right-hand side.

        Attributes:
            fourier_symbol: Symbol of the implicitly treated linear operator
                (annotated as float; presumably may also be an array that
                broadcasts against the rfftn spectrum — TODO confirm).
        """
        fourier_symbol: float
        term_structure = dfx.ODETerm
        interpolation_cls = dfx.LocalLinearInterpolation

        def order(self, terms):
            return 1

        def init(self, terms, t0, t1, y0, args):
            # Single-step scheme: no solver state is carried between steps.
            return None

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            del solver_state, made_jump
            dt = t1 - t0
            f0 = terms.vf(t0, y0, args)
            # Plain explicit Euler step, only used for the error estimate.
            euler_y1 = y0 + dt * f0
            dc_fft = jnp.fft.rfftn(f0)
            dc_fft *= dt / (1.0 - self.fourier_symbol * dt)
            update = jnp.fft.irfftn(dc_fft, f0.shape)
            y1 = y0 + update

            y_error = y1 - euler_y1
            dense_info = dict(y0=y0, y1=y1)

            solver_state = None
            result = dfx.RESULTS.successful
            return y1, y_error, dense_info, solver_state, result

        def func(self, terms, t0, y0, args):
            return terms.vf(t0, y0, args)
|
evoxels/utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import sympy as sp
|
|
3
|
+
import sympy.vector as spv
|
|
4
|
+
import evoxels as evo
|
|
5
|
+
from evoxels.problem_definition import SmoothedBoundaryODE
|
|
6
|
+
|
|
7
|
+
### Generalized test case
|
|
8
|
+
def rhs_convergence_test(
    ODE_class,  # an ODE class with callable rhs(field, t)->torch.Tensor (shape [x,y,z])
    problem_kwargs,  # problem parameters to instantiate ODE
    test_function,  # exact init_fun(x,y,z)->np.ndarray
    mask_function=None,
    convention="cell_center",
    dtype="float32",
    powers=np.array([3, 4, 5, 6, 7]),
    backend="torch",
):
    """Evaluate spatial order of an ODE right-hand side.

    The numerical ``rhs`` of ``ODE_class`` is compared with its
    ``rhs_analytic`` on a sequence of grids with ``2**p`` cells per
    direction (``p`` in ``powers``) on the unit cube. The relative L2 error
    is recorded on each grid and a slope is fitted in log-log space.

    ``test_function`` can be a single sympy expression or a list of
    expressions representing multiple variables. The returned error and
    slope arrays have one entry for each provided function.

    Args:
        ODE_class: an ODE class with callable rhs(field, t) and rhs_analytic.
        problem_kwargs: problem-specific parameters to instantiate ODE. The
            mapping is copied and never modified (``None`` means no extra
            parameters).
        test_function: single sympy expression or a list of expressions in
            the base scalars ``x, y, z`` of ``CoordSys3D('CS')``.
        mask_function: static mask for smoothed boundary method; only
            allowed for subclasses of ``SmoothedBoundaryODE``.
        convention: grid convention, ``'cell_center'`` or ``'staggered_x'``.
        dtype: float precision (``float32`` or ``float64``).
        powers: refine grid in powers of two (i.e. ``Nx = 2**p``).
        backend: use ``torch`` or ``jax`` for testing.

    Returns:
        tuple ``(dx, errors, slopes, order)``: grid spacing in x per level;
        relative errors of shape ``(n_funcs, len(powers))`` (1D for a single
        function); fitted slopes (scalar for a single function); and the
        ``order`` attribute of the last instantiated problem.

    Raises:
        TypeError: If ``mask_function`` is given for a non-SBM ODE class.
        ValueError: For an unknown ``convention``/``backend``, an empty
            list of test functions or empty ``powers``.
    """
    # Verify mask_function only used with SmoothedBoundaryODE
    if mask_function is not None and not issubclass(ODE_class, SmoothedBoundaryODE):
        raise TypeError(
            f"Mask function provided but {ODE_class.__name__} "
            "is not a SmoothedBoundaryODE."
        )
    # Validate up front; otherwise vf/vg would be unbound inside the loop.
    if convention not in ("cell_center", "staggered_x"):
        raise ValueError(
            f"Unsupported convention '{convention}'; "
            "expected 'cell_center' or 'staggered_x'."
        )
    if backend not in ("torch", "jax"):
        raise ValueError(f"Unsupported backend '{backend}'; expected 'torch' or 'jax'.")
    if len(powers) == 0:
        raise ValueError("powers must contain at least one refinement level.")

    # Private copy: the mask is injected per resolution and must not leak
    # into the caller's dictionary.
    problem_kwargs = dict(problem_kwargs or {})

    CS = spv.CoordSys3D('CS')
    coords = (CS.x, CS.y, CS.z)
    # Prepare lambdified mask if needed
    mask = sp.lambdify(coords, mask_function, "numpy") if mask_function is not None else None

    if isinstance(test_function, (list, tuple)):
        test_functions = list(test_function)
    else:
        test_functions = [test_function]
    n_funcs = len(test_functions)
    if n_funcs == 0:
        raise ValueError("At least one test function is required.")

    # Multiply test functions with mask for SBM testing
    if mask is not None:
        test_functions = [func * mask_function for func in test_functions]

    # Lambdify once; the symbolic initial conditions do not depend on the grid.
    init_funs = [sp.lambdify(coords, func, "numpy") for func in test_functions]
    # SBM problems take the symbolic mask as first argument of rhs_analytic.
    mask_args = (mask_function,) if mask is not None else ()

    dx = np.zeros(len(powers))
    errors = np.zeros((n_funcs, len(powers)))

    for i, p in enumerate(powers):
        n = 2**p
        # staggered_x stores both x faces, hence one extra point in x
        nx = n + 1 if convention == "staggered_x" else n
        vf = evo.VoxelFields((nx, n, n), (1, 1, 1), convention=convention)
        vf.precision = dtype
        grid = vf.meshgrid()
        dx[i] = vf.spacing[0]

        if backend == "torch":
            vg = evo.voxelgrid.VoxelGridTorch(vf.grid_info(), precision=vf.precision, device='cpu')
        else:
            vg = evo.voxelgrid.VoxelGridJax(vf.grid_info(), precision=vf.precision)

        # Initialise fields
        u_list = [vg.init_scalar_field(init_fun(*grid)) for init_fun in init_funs]
        u = vg.concatenate(u_list, 0)
        u = vg.bc.trim_boundary_nodes(u)

        # Init mask if smoothed boundary ODE
        if mask is not None:
            problem_kwargs["mask"] = mask(*grid)

        problem = ODE_class(vg, **problem_kwargs)
        rhs_numeric = problem.rhs(u, 0)

        if n_funcs > 1:
            rhs_analytic = problem.rhs_analytic(*mask_args, test_functions, 0)
        else:
            rhs_analytic = [problem.rhs_analytic(*mask_args, test_functions[0], 0)]

        # Compare numerical and analytical solutions per variable
        for j in range(n_funcs):
            comp = vg.export_scalar_field_to_numpy(rhs_numeric[j:j+1])
            exact = sp.lambdify(coords, rhs_analytic[j], "numpy")(*grid)
            if convention == "staggered_x":
                # Boundary nodes are trimmed from the numerical field as well
                exact = exact[1:-1, :, :]

            # Relative L2 error norm
            errors[j, i] = np.linalg.norm(comp - exact) / np.linalg.norm(exact)

    # Fit slope after loop
    slopes = np.array(
        [np.polyfit(np.log(dx), np.log(err), 1)[0] for err in errors]
    )
    if slopes.size == 1:
        slopes = slopes[0]

    return dx, (errors if n_funcs > 1 else errors[0]), slopes, problem.order
|
|
124
|
+
|
evoxels/voxelfields.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
# In a world of cubes and blocks,
|
|
2
|
+
# Where reality takes voxel knocks,
|
|
3
|
+
# Every shape and form we see,
|
|
4
|
+
# Is a pixelated mystery.
|
|
5
|
+
|
|
6
|
+
# Mountains rise in jagged peaks,
|
|
7
|
+
# Rivers flow in blocky streaks.
|
|
8
|
+
# So embrace the charm of this edgy place,
|
|
9
|
+
# Where every voxel finds its space
|
|
10
|
+
|
|
11
|
+
# In einer Welt aus Würfeln und Blöcken,
|
|
12
|
+
# in der die Realität in Voxelform erscheint,
|
|
13
|
+
# ist jede Form, die wir sehen,
|
|
14
|
+
# ein verpixeltes Rätsel.
|
|
15
|
+
|
|
16
|
+
# Berge erheben sich in gezackten Gipfeln,
|
|
17
|
+
# Flüsse fließen in blockförmigen Adern.
|
|
18
|
+
# Also lass dich vom Charme dieses kantigen Ortes verzaubern,
|
|
19
|
+
# wo jedes Voxel seinen Platz findet.
|
|
20
|
+
|
|
21
|
+
import matplotlib.pyplot as plt
|
|
22
|
+
from matplotlib.widgets import Slider
|
|
23
|
+
import numpy as np
|
|
24
|
+
from typing import Tuple
|
|
25
|
+
import warnings
|
|
26
|
+
from .voxelgrid import Grid
|
|
27
|
+
|
|
28
|
+
class VoxelFields:
|
|
29
|
+
"""Manage 3D voxel grids for simulation, visualization, and I/O.
|
|
30
|
+
|
|
31
|
+
This class provides a uniform, cell‐centered or staggered‐x voxel grid,
|
|
32
|
+
handles spacing and origin, and stores any number of named 3D fields.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
shape (tuple[int, int, int]): Number of voxels ``(Nx, Ny, Nz)``.
|
|
36
|
+
domain_size (tuple[float, float, float], optional):
|
|
37
|
+
Physical dimensions (Lx, Ly, Lz). Defaults to (1, 1, 1).
|
|
38
|
+
convention (str, optional):
|
|
39
|
+
Grid convention, either 'cell_center' or 'staggered_x'.
|
|
40
|
+
Defaults to 'cell_center'.
|
|
41
|
+
|
|
42
|
+
Raises:
|
|
43
|
+
ValueError: If `domain_size` is not length 3 or contains non-numeric values.
|
|
44
|
+
ValueError: If `convention` is not one of 'cell_center' or 'staggered_x'.
|
|
45
|
+
Warning: If the spacing ratio max/min > 10, a warning is issued.
|
|
46
|
+
|
|
47
|
+
Attributes:
|
|
48
|
+
shape (tuple[int, int, int]): Number of voxels ``(Nx, Ny, Nz)``.
|
|
49
|
+
domain_size (tuple[float, float, float]): Physical domain lengths.
|
|
50
|
+
spacing (tuple[float, float, float]): Grid spacing (dx, dy, dz).
|
|
51
|
+
origin (tuple[float, float, float]):
|
|
52
|
+
Coordinates of the (0, 0, 0) corner for cell-centered or staggered grids.
|
|
53
|
+
convention (str): Either 'cell_center' or 'staggered_x'.
|
|
54
|
+
precision (type): NumPy floating-point type for grid coordinates.
|
|
55
|
+
grid (tuple[np.ndarray, np.ndarray, np.ndarray] or None):
|
|
56
|
+
Meshgrid arrays (x, y, z) once created by `add_grid()`, else None.
|
|
57
|
+
fields (dict[str, np.ndarray]): Mapping field names to 3D arrays.
|
|
58
|
+
|
|
59
|
+
Example:
|
|
60
|
+
>>> vf = VoxelFields((100, 100, 100), domain_size=(1, 1, 1))
|
|
61
|
+
>>> vf.add_field('temperature', np_array)
|
|
62
|
+
>>> x, y, z = vf.plot_slice('temperature', 10)
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def __init__(self, shape: Tuple[int, int, int], domain_size=(1, 1, 1), convention='cell_center'):
|
|
66
|
+
"""Create a voxel grid with ``shape`` cells."""
|
|
67
|
+
if not (
|
|
68
|
+
isinstance(shape, (list, tuple))
|
|
69
|
+
and len(shape) == 3
|
|
70
|
+
and all(isinstance(n, (int, np.integer)) for n in shape)
|
|
71
|
+
):
|
|
72
|
+
raise ValueError("shape must be a tuple of three integers")
|
|
73
|
+
self.shape = tuple(int(n) for n in shape)
|
|
74
|
+
num_x, num_y, num_z = self.shape
|
|
75
|
+
self.precision = 'float32' # float64
|
|
76
|
+
self.convention = convention
|
|
77
|
+
|
|
78
|
+
if not isinstance(domain_size, (list, tuple)) or len(domain_size) != 3:
|
|
79
|
+
raise ValueError("domain_size must be a list or tuple with three elements (dx, dy, dz)")
|
|
80
|
+
if not all(isinstance(x, (int, float)) for x in domain_size):
|
|
81
|
+
raise ValueError("All elements in domain_size must be integers or floats")
|
|
82
|
+
self.domain_size = domain_size
|
|
83
|
+
|
|
84
|
+
if convention == 'cell_center':
|
|
85
|
+
self.spacing = (domain_size[0]/num_x, domain_size[1]/num_y, domain_size[2]/num_z)
|
|
86
|
+
self.origin = (self.spacing[0]/2, self.spacing[1]/2, self.spacing[2]/2)
|
|
87
|
+
elif convention == 'staggered_x':
|
|
88
|
+
self.spacing = (domain_size[0]/(num_x-1), domain_size[1]/num_y, domain_size[2]/num_z)
|
|
89
|
+
self.origin = (0, self.spacing[1]/2, self.spacing[2]/2)
|
|
90
|
+
else:
|
|
91
|
+
raise ValueError("Chosen convention must be cell_center or staggered_x.")
|
|
92
|
+
|
|
93
|
+
if (np.max(self.spacing)/np.min(self.spacing) > 10):
|
|
94
|
+
warnings.warn("Simulations become very questionable for largely different spacings e.g. dz >> dx.")
|
|
95
|
+
self.grid = None
|
|
96
|
+
self.fields = {}
|
|
97
|
+
|
|
98
|
+
@property
|
|
99
|
+
def Nx(self) -> int: # backward compatibility
|
|
100
|
+
return self.shape[0]
|
|
101
|
+
|
|
102
|
+
@property
|
|
103
|
+
def Ny(self) -> int:
|
|
104
|
+
return self.shape[1]
|
|
105
|
+
|
|
106
|
+
@property
|
|
107
|
+
def Nz(self) -> int:
|
|
108
|
+
return self.shape[2]
|
|
109
|
+
|
|
110
|
+
def __str__(self):
|
|
111
|
+
"""Return a human readable description of the voxel grid."""
|
|
112
|
+
return (
|
|
113
|
+
f"Domain with size {self.domain_size} and "
|
|
114
|
+
f"{self.shape} grid points on "
|
|
115
|
+
f"{self.convention} position."
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
def grid_info(self):
|
|
119
|
+
"""Return a :class:`Grid` dataclass describing this domain."""
|
|
120
|
+
grid = Grid(self.shape, self.origin, self.spacing, self.convention)
|
|
121
|
+
return grid
|
|
122
|
+
|
|
123
|
+
def add_field(self, name: str, array=None):
|
|
124
|
+
"""
|
|
125
|
+
Adds a field to the voxel grid.
|
|
126
|
+
|
|
127
|
+
Args:
|
|
128
|
+
name (str): Name of the field.
|
|
129
|
+
array (numpy.ndarray, optional): 3D array to initialize the field. If None, initializes with zeros.
|
|
130
|
+
|
|
131
|
+
Raises:
|
|
132
|
+
ValueError: If the provided array does not match the voxel grid dimensions.
|
|
133
|
+
TypeError: If the provided array is not a numpy array.
|
|
134
|
+
"""
|
|
135
|
+
if array is not None:
|
|
136
|
+
if isinstance(array, np.ndarray):
|
|
137
|
+
if array.shape == self.shape:
|
|
138
|
+
self.fields[name] = array
|
|
139
|
+
else:
|
|
140
|
+
raise ValueError(
|
|
141
|
+
f"The provided array must have the shape {self.shape}."
|
|
142
|
+
)
|
|
143
|
+
else:
|
|
144
|
+
raise TypeError("The provided array must be a numpy array.")
|
|
145
|
+
else:
|
|
146
|
+
self.fields[name] = np.zeros(self.shape)
|
|
147
|
+
|
|
148
|
+
def set_voxel_sphere(self, name: str, center, radius, label: int | float = 1):
|
|
149
|
+
"""Create a voxelized representation of a sphere in 3D
|
|
150
|
+
|
|
151
|
+
Fill voxels within given ``radius`` around the given ``center``
|
|
152
|
+
with value provided by ``label``.
|
|
153
|
+
"""
|
|
154
|
+
x, y, z = np.ogrid[:self.Nx, :self.Ny, :self.Nz]
|
|
155
|
+
distance_squared = (x * self.spacing[0] + self.origin[0] - center[0])**2 +\
|
|
156
|
+
(y * self.spacing[1] + self.origin[1] - center[1])**2 +\
|
|
157
|
+
(z * self.spacing[2] + self.origin[2] - center[2])**2
|
|
158
|
+
mask = distance_squared <= radius**2
|
|
159
|
+
self.fields[name][mask] = label
|
|
160
|
+
|
|
161
|
+
def average(self, name: str):
|
|
162
|
+
"""Return the average value of a stored field."""
|
|
163
|
+
if self.convention == 'cell_center':
|
|
164
|
+
average = np.mean(self.fields[name])
|
|
165
|
+
elif self.convention == 'staggered_x':
|
|
166
|
+
# Count first and last slice as half cells
|
|
167
|
+
average = np.sum(self.fields[name][1:-1,:,:]) \
|
|
168
|
+
+ 0.5*np.sum(self.fields[name][ 0,:,:]) \
|
|
169
|
+
+ 0.5*np.sum(self.fields[name][-1,:,:])
|
|
170
|
+
average /= ((self.Nx - 1) * self.Ny * self.Nz)
|
|
171
|
+
return average
|
|
172
|
+
|
|
173
|
+
def axes(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
174
|
+
""" Returns the 1D coordinate arrays along each axis. """
|
|
175
|
+
return tuple(
|
|
176
|
+
np.arange(0, n, dtype=self.precision) * self.spacing[i] + self.origin[i]
|
|
177
|
+
for i, n in enumerate(self.shape)
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
def meshgrid(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
181
|
+
""" Returns full 3D mesh grids for each axis. """
|
|
182
|
+
ax = self.axes()
|
|
183
|
+
# indexing='ij' makes Ax[i,j,k] = x-coordinate at (i,j,k), etc.
|
|
184
|
+
return tuple(np.meshgrid(*ax, indexing='ij'))
|
|
185
|
+
|
|
186
|
+
def export_to_vtk(self, filename="output.vtk", field_names=None):
|
|
187
|
+
"""
|
|
188
|
+
Exports fields to a VTK file for visualization (e.g. VisIt or ParaView).
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
filename (str): Name of the output VTK file.
|
|
192
|
+
field_names (list, optional): List of field names to export. Exports all fields if None.
|
|
193
|
+
"""
|
|
194
|
+
import pyvista as pv
|
|
195
|
+
names = field_names if field_names else list(self.fields.keys())
|
|
196
|
+
grid = pv.ImageData()
|
|
197
|
+
grid.spacing = self.spacing
|
|
198
|
+
grid.dimensions = (self.Nx + 1, self.Ny + 1, self.Nz + 1)
|
|
199
|
+
grid.origin = (self.origin[0] - self.spacing[0]/2, \
|
|
200
|
+
self.origin[1] - self.spacing[1]/2, \
|
|
201
|
+
self.origin[2] - self.spacing[2]/2)
|
|
202
|
+
for name in names:
|
|
203
|
+
grid.cell_data[name] = self.fields[name].flatten(order="F") # Fortran order flattening
|
|
204
|
+
grid.save(filename)
|
|
205
|
+
|
|
206
|
+
def plot_slice(self, fieldname, slice_index, direction='z', time=None, colormap='viridis', value_bounds=None):
|
|
207
|
+
"""
|
|
208
|
+
Plots a 2D slice of a field along a specified direction.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
fieldname (str): Name of the field to plot.
|
|
212
|
+
slice_index (int): Index of the slice to plot.
|
|
213
|
+
direction (str): Normal direction of the slice ('x', 'y', or 'z').
|
|
214
|
+
dpi (int): Resolution of the plot.
|
|
215
|
+
colormap (str): Colormap to use for the plot.
|
|
216
|
+
|
|
217
|
+
Raises:
|
|
218
|
+
ValueError: If an invalid direction is provided.
|
|
219
|
+
"""
|
|
220
|
+
# Colormaps
|
|
221
|
+
# linear: viridis, Greys
|
|
222
|
+
# diverging: seismic
|
|
223
|
+
# levels: tab20, flag
|
|
224
|
+
# gradual: turbo
|
|
225
|
+
if direction == 'x':
|
|
226
|
+
slice = np.s_[slice_index,:,:]
|
|
227
|
+
start1, start2 = self.origin[1]-self.spacing[1]/2, self.origin[2]-self.spacing[2]/2
|
|
228
|
+
end1, end2 = self.domain_size[1]-start1, self.domain_size[2]-start2
|
|
229
|
+
label1, label2 = ['Y', 'Z']
|
|
230
|
+
elif direction == 'y':
|
|
231
|
+
slice = np.s_[:,slice_index,:]
|
|
232
|
+
start1, start2 = self.origin[0]-self.spacing[0]/2, self.origin[2]-self.spacing[2]/2
|
|
233
|
+
end1, end2 = self.domain_size[0]-start1, self.domain_size[2]-start2
|
|
234
|
+
label1, label2 = ['X', 'Z']
|
|
235
|
+
elif direction == 'z':
|
|
236
|
+
slice = np.s_[:,:,slice_index]
|
|
237
|
+
start1, start2 = self.origin[0]-self.spacing[0]/2, self.origin[1]-self.spacing[1]/2
|
|
238
|
+
end1, end2 = self.domain_size[0]-start1, self.domain_size[1]-start2
|
|
239
|
+
label1, label2 = ['X', 'Y']
|
|
240
|
+
else:
|
|
241
|
+
raise ValueError("Given direction must be x, y or z")
|
|
242
|
+
|
|
243
|
+
plt.figure()
|
|
244
|
+
if value_bounds is not None:
|
|
245
|
+
im = plt.imshow(self.fields[fieldname][slice].T, cmap=colormap,\
|
|
246
|
+
origin='lower', extent=[start1, end1, start2, end2],\
|
|
247
|
+
vmin=value_bounds[0], vmax=value_bounds[1])
|
|
248
|
+
else:
|
|
249
|
+
im = plt.imshow(self.fields[fieldname][slice].T, cmap=colormap, \
|
|
250
|
+
origin='lower', extent=[start1, end1, start2, end2])
|
|
251
|
+
|
|
252
|
+
ratio = np.clip((end2-start2)/(end1-start1), 0, 1)
|
|
253
|
+
plt.colorbar(im, shrink=ratio)
|
|
254
|
+
plt.xlabel(label1)
|
|
255
|
+
plt.ylabel(label2)
|
|
256
|
+
if time:
|
|
257
|
+
plt.title(f'Slice {slice_index} of {fieldname} in {direction} at time {time}')
|
|
258
|
+
else:
|
|
259
|
+
plt.title(f'Slice {slice_index} of {fieldname} in {direction}')
|
|
260
|
+
plt.show()
|
|
261
|
+
|
|
262
|
+
def plot_field_interactive(self, fieldname, direction='x', colormap='viridis', value_bounds=None):
|
|
263
|
+
"""
|
|
264
|
+
Creates an interactive plot for exploring slices of a 3D field.
|
|
265
|
+
|
|
266
|
+
Args:
|
|
267
|
+
fieldname (str): Name of the field to plot.
|
|
268
|
+
direction (str): Direction of slicing ('x', 'y', or 'z').
|
|
269
|
+
dpi (int): Resolution of the plot.
|
|
270
|
+
colormap (str): Colormap to use for the plot.
|
|
271
|
+
|
|
272
|
+
Raises:
|
|
273
|
+
ValueError: If an invalid direction is provided.
|
|
274
|
+
"""
|
|
275
|
+
if direction == 'x':
|
|
276
|
+
axes = (0,1,2)
|
|
277
|
+
end1, end2 = self.domain_size[1], self.domain_size[2]
|
|
278
|
+
label1, label2 = ['Y', 'Z']
|
|
279
|
+
elif direction == 'y':
|
|
280
|
+
axes = (1,0,2)
|
|
281
|
+
end1, end2 = self.domain_size[0], self.domain_size[2]
|
|
282
|
+
label1, label2 = ['X', 'Z']
|
|
283
|
+
elif direction == 'z':
|
|
284
|
+
axes = (2,0,1)
|
|
285
|
+
end1, end2 = self.domain_size[0], self.domain_size[1]
|
|
286
|
+
label1, label2 = ['X', 'Y']
|
|
287
|
+
else:
|
|
288
|
+
raise ValueError("Given direction must be x, y or z")
|
|
289
|
+
|
|
290
|
+
field = np.transpose(self.fields[fieldname], axes)
|
|
291
|
+
fig, ax = plt.subplots()
|
|
292
|
+
if value_bounds is None:
|
|
293
|
+
value_bounds = (np.min(field), np.max(field))
|
|
294
|
+
im = ax.imshow(
|
|
295
|
+
field[0].T,
|
|
296
|
+
cmap=colormap,
|
|
297
|
+
origin="lower",
|
|
298
|
+
extent=[0, end1, 0, end2],
|
|
299
|
+
vmin=value_bounds[0],
|
|
300
|
+
vmax=value_bounds[1],
|
|
301
|
+
)
|
|
302
|
+
ax.set_xlabel(label1)
|
|
303
|
+
ax.set_ylabel(label2)
|
|
304
|
+
ax.set_title(f'Slice 0 in {direction}-direction of {fieldname}')
|
|
305
|
+
plt.colorbar(im, ax=ax)
|
|
306
|
+
|
|
307
|
+
# Add a slider for changing timeframes
|
|
308
|
+
position = plt.axes([0.2, 0.0, 0.6, 0.02])
|
|
309
|
+
ax_slider = Slider(position, 'Slice', 0, field.shape[0]-1, valinit=0, valstep=1)
|
|
310
|
+
|
|
311
|
+
def update(val):
|
|
312
|
+
slice_idx = int(ax_slider.val)
|
|
313
|
+
im.set_array(field[slice_idx].T)
|
|
314
|
+
ax.set_title(f'Slice {slice_idx} in ' + direction + '-direction of ' + fieldname)
|
|
315
|
+
fig.canvas.draw_idle()
|
|
316
|
+
|
|
317
|
+
ax_slider.on_changed(update)
|
|
318
|
+
return ax_slider
|