evoxels 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,119 @@
1
+ import warnings
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+ from .problem_definition import ODE, SemiLinearODE
6
+
7
# Backend-agnostic alias for a simulation state (e.g. torch.Tensor or jax.Array).
State = Any


class TimeStepper(ABC):
    """Common interface of all single-step time integration schemes.

    A concrete scheme advances a state by one timestep via :meth:`step`
    and reports its temporal order of accuracy via :attr:`order`.
    """

    @property
    @abstractmethod
    def order(self) -> int:
        """Temporal order of accuracy of the scheme."""
        ...

    @abstractmethod
    def step(self, u: State, t: float) -> State:
        """Advance ``u`` by a single timestep, from ``t`` to ``t + dt``.

        Args:
            u: Current state.
            t: Current time.

        Returns:
            The updated state at ``t + dt``.
        """
        ...
30
+
31
+
32
@dataclass
class ForwardEuler(TimeStepper):
    """Explicit (forward) Euler scheme, first order in time.

    Each step computes ``u + dt * problem.rhs(u, t)``.

    Attributes:
        problem: ODE whose ``rhs(u, t)`` provides the time derivative.
        dt: Fixed timestep.
    """
    problem: ODE
    dt: float

    @property
    def order(self) -> int:
        return 1

    def step(self, u: State, t: float) -> State:
        increment = self.dt * self.problem.rhs(u, t)
        return u + increment
44
+
45
+
46
@dataclass
class PseudoSpectralIMEX(TimeStepper):
    """First-order IMEX Fourier pseudo-spectral scheme.

    Also known as the semi-implicit Fourier spectral method; see
    [Zhu and Chen 1999, doi:10.1103/PhysRevE.60.3564] for more details.
    The linear part (``problem.fourier_symbol``) is treated implicitly in
    Fourier space, the full right-hand side explicitly.

    Attributes:
        problem: Semi-linear ODE providing ``rhs``, ``fourier_symbol``,
            ``bc_type`` and the voxel grid ``vg``.
        dt: Fixed timestep.

    Raises:
        ValueError: On construction, if ``problem.bc_type`` is not one of
            'periodic', 'dirichlet' or 'neumann'.
    """
    problem: SemiLinearODE
    dt: float

    def __post_init__(self):
        # Pre-bake the implicit linear prefactor dt / (1 - dt * fourier_symbol)
        self._fft_prefac = self.dt / (1 - self.dt * self.problem.fourier_symbol)

        # Boundary condition type -> padding routine on problem.vg.bc that
        # turns the field into an FFT-compatible (periodic) signal.
        pad_methods = {
            'periodic': 'pad_fft_periodic',
            'dirichlet': 'pad_fft_dirichlet_periodic',
            'neumann': 'pad_fft_zero_flux_periodic',
        }
        bc_type = self.problem.bc_type
        if bc_type not in pad_methods:
            # Fail fast: previously self.pad stayed unset and step() crashed
            # later with an AttributeError.
            raise ValueError(
                f"Unsupported bc_type {bc_type!r} for PseudoSpectralIMEX; "
                f"expected one of {sorted(pad_methods)}."
            )
        self.pad = getattr(self.problem.vg.bc, pad_methods[bc_type])

    @property
    def order(self) -> int:
        return 1

    def step(self, u: State, t: float) -> State:
        vg = self.problem.vg
        # Explicit rhs, padded according to the boundary condition
        dc = self.pad(self.problem.rhs(u, t))
        dc_fft = self._fft_prefac * vg.rfftn(dc, dc.shape)
        # Crop the padded result back to the size of u along axis 1.
        # NOTE(review): assumes padding only extends axis 1 — confirm.
        update = vg.irfftn(dc_fft, dc.shape)[:, :u.shape[1]]
        return u + update
76
+
77
+
78
try:
    import jax.numpy as jnp
    import diffrax as dfx

    class PseudoSpectralIMEX_dfx(dfx.AbstractSolver):
        """Diffrax solver performing the first-order pseudo-spectral IMEX step.

        Re-implementation of ``PseudoSpectralIMEX`` as a diffrax solver, used by
        the inversion models based on jax and diffrax. The FFT runs over all
        axes of the state and no padding is applied (presumably periodic
        boundaries are assumed — confirm).
        """
        fourier_symbol: float
        term_structure = dfx.ODETerm
        interpolation_cls = dfx.LocalLinearInterpolation

        def order(self, terms):
            return 1

        def init(self, terms, t0, t1, y0, args):
            # Single-step scheme: no solver state is carried between steps
            return None

        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
            dt = t1 - t0
            rate = terms.vf(t0, y0, args)
            # Implicit treatment of the linear part in Fourier space
            rate_hat = jnp.fft.rfftn(rate)
            rate_hat = rate_hat * (dt / (1.0 - self.fourier_symbol * dt))
            y1 = y0 + jnp.fft.irfftn(rate_hat, rate.shape)
            # Difference to a plain explicit Euler step is returned as error estimate
            y_euler = y0 + dt * rate
            return (
                y1,
                y1 - y_euler,
                {"y0": y0, "y1": y1},
                None,
                dfx.RESULTS.successful,
            )

        def func(self, terms, t0, y0, args):
            return terms.vf(t0, y0, args)

except ImportError:
    PseudoSpectralIMEX_dfx = None
    warnings.warn("Diffrax not found. 'PseudoSpectralIMEX_dfx' will not be available.")
evoxels/utils.py ADDED
@@ -0,0 +1,124 @@
1
+ import numpy as np
2
+ import sympy as sp
3
+ import sympy.vector as spv
4
+ import evoxels as evo
5
+ from evoxels.problem_definition import SmoothedBoundaryODE
6
+
7
### Generalized test case
def _make_voxel_fields(convention, n):
    """Build a unit-cube VoxelFields with ``n`` cells per direction for ``convention``."""
    if convention == 'cell_center':
        return evo.VoxelFields((n, n, n), (1, 1, 1), convention=convention)
    if convention == 'staggered_x':
        # staggered_x stores n + 1 nodes along x (spacing = 1 / n)
        return evo.VoxelFields((n + 1, n, n), (1, 1, 1), convention=convention)
    raise ValueError(
        f"Unsupported convention {convention!r}; "
        "expected 'cell_center' or 'staggered_x'."
    )


def _make_voxel_grid(vf, backend):
    """Create the backend-specific voxel grid ('torch' on CPU, or 'jax') for ``vf``."""
    if backend == 'torch':
        return evo.voxelgrid.VoxelGridTorch(vf.grid_info(), precision=vf.precision, device='cpu')
    if backend == 'jax':
        return evo.voxelgrid.VoxelGridJax(vf.grid_info(), precision=vf.precision)
    raise ValueError(f"Unsupported backend {backend!r}; expected 'torch' or 'jax'.")


def _evaluate_on_grid(func, grid):
    """Evaluate a lambdified expression on ``grid``, expanding constants to full grid shape."""
    values = func(*grid)
    if np.ndim(values) == 0:
        # lambdify of a constant expression returns a scalar, not an array
        values = np.full(grid[0].shape, values, dtype=grid[0].dtype)
    return values


def rhs_convergence_test(
    ODE_class,  # an ODE class with callable rhs(field, t)->torch.Tensor (shape [x,y,z])
    problem_kwargs,  # problem parameters to instantiate ODE
    test_function,  # exact init_fun(x,y,z)->np.ndarray
    mask_function=None,
    convention="cell_center",
    dtype="float32",
    powers=np.array([3, 4, 5, 6, 7]),
    backend="torch",
):
    """Evaluate the spatial order of accuracy of an ODE right-hand side.

    For each refinement ``Nx = 2**p`` the test functions are sampled on a
    unit-cube grid, the numerical ``rhs`` is evaluated at ``t = 0`` and
    compared against ``rhs_analytic`` of the ODE instance. The observed
    order is the slope of ``log(error)`` over ``log(dx)``.

    ``test_function`` can be a single sympy expression or a list of
    expressions representing multiple variables. The returned error and
    slope arrays have one entry for each provided function.

    Args:
        ODE_class: an ODE class with callable ``rhs(field, t)``; instantiated
            as ``ODE_class(voxel_grid, **problem_kwargs)``.
        problem_kwargs: problem-specific parameters to instantiate ODE. The
            dict is not modified; ``None`` means no extra parameters.
        test_function: single sympy expression or a list of expressions in
            the coordinates of ``sympy.vector.CoordSys3D('CS')``.
        mask_function: static mask for smoothed boundary method; only allowed
            when ``ODE_class`` subclasses ``SmoothedBoundaryODE``.
        convention: grid convention, ``'cell_center'`` or ``'staggered_x'``.
        dtype: float precision (``float32`` or ``float64``).
        powers: refine grid in powers of two (i.e. ``Nx = 2**p``).
        backend: use ``torch`` or ``jax`` for testing.

    Returns:
        Tuple ``(dx, errors, slopes, order)``: grid spacing per refinement;
        relative L2 errors (1D for a single function, otherwise shape
        ``(n_funcs, len(powers))``); fitted log-log slopes (a scalar for a
        single function); and the ``order`` attribute of the last ODE instance.

    Raises:
        TypeError: if ``mask_function`` is given for a non-SmoothedBoundaryODE.
        ValueError: if ``powers`` is empty or ``convention``/``backend`` is
            not supported.
    """
    # Verify mask_function only used with SmoothedBoundaryODE
    if mask_function is not None and not issubclass(ODE_class, SmoothedBoundaryODE):
        raise TypeError(
            f"Mask function provided but {ODE_class.__name__} "
            "is not a SmoothedBoundaryODE."
        )
    if len(powers) == 0:
        raise ValueError("powers must contain at least one refinement level.")

    CS = spv.CoordSys3D('CS')
    coords = (CS.x, CS.y, CS.z)
    mask = sp.lambdify(coords, mask_function, "numpy") if mask_function is not None else None

    if isinstance(test_function, (list, tuple)):
        test_functions = list(test_function)
    else:
        test_functions = [test_function]
    n_funcs = len(test_functions)

    # Multiply test functions with mask for SBM testing
    if mask is not None:
        test_functions = [func * mask_function for func in test_functions]

    # Initial-condition expressions do not depend on resolution: lambdify once
    init_funs = [sp.lambdify(coords, func, "numpy") for func in test_functions]
    # For SBM problems rhs_analytic takes the mask as extra leading argument
    analytic_prefix = (mask_function,) if mask is not None else ()

    dx = np.zeros(len(powers))
    errors = np.zeros((n_funcs, len(powers)))
    ode = None

    for i, p in enumerate(powers):
        vf = _make_voxel_fields(convention, 2**p)
        vf.precision = dtype
        grid = vf.meshgrid()
        vg = _make_voxel_grid(vf, backend)
        dx[i] = vf.spacing[0]

        # Initialise fields
        u_list = [vg.init_scalar_field(_evaluate_on_grid(f, grid)) for f in init_funs]
        u = vg.concatenate(u_list, 0)
        u = vg.bc.trim_boundary_nodes(u)

        # Work on a copy so the caller's problem_kwargs is never mutated
        kwargs = dict(problem_kwargs or {})
        if mask is not None:
            kwargs["mask"] = _evaluate_on_grid(mask, grid)

        ode = ODE_class(vg, **kwargs)
        rhs_numeric = ode.rhs(u, 0)

        if n_funcs > 1:
            rhs_analytic = ode.rhs_analytic(*analytic_prefix, test_functions, 0)
        else:
            rhs_analytic = [ode.rhs_analytic(*analytic_prefix, test_functions[0], 0)]

        # Relative L2 error of the numeric rhs for each variable
        for j in range(n_funcs):
            comp = vg.export_scalar_field_to_numpy(rhs_numeric[j:j + 1])
            exact_fun = sp.lambdify(coords, rhs_analytic[j], "numpy")
            exact = _evaluate_on_grid(exact_fun, grid)
            if convention == "staggered_x":
                # NOTE(review): presumably matches the x-boundary nodes removed
                # by trim_boundary_nodes from the numeric field — confirm
                exact = exact[1:-1, :, :]
            errors[j, i] = np.linalg.norm(comp - exact) / np.linalg.norm(exact)

    # Slope of log(error) over log(dx) is the observed order of convergence
    slopes = np.array(
        [np.polyfit(np.log(dx), np.log(err), 1)[0] for err in errors]
    )
    if slopes.size == 1:
        slopes = slopes[0]

    return dx, errors if n_funcs > 1 else errors[0], slopes, ode.order
124
+
evoxels/voxelfields.py ADDED
@@ -0,0 +1,318 @@
1
+ # In a world of cubes and blocks,
2
+ # Where reality takes voxel knocks,
3
+ # Every shape and form we see,
4
+ # Is a pixelated mystery.
5
+
6
+ # Mountains rise in jagged peaks,
7
+ # Rivers flow in blocky streaks.
8
+ # So embrace the charm of this edgy place,
9
+ # Where every voxel finds its space
10
+
11
+ # In einer Welt aus Würfeln und Blöcken,
12
+ # in der die Realität in Voxelform erscheint,
13
+ # ist jede Form, die wir sehen,
14
+ # ein verpixeltes Rätsel.
15
+
16
+ # Berge erheben sich in gezackten Gipfeln,
17
+ # Flüsse fließen in blockförmigen Adern.
18
+ # Also lass dich vom Charme dieses kantigen Ortes verzaubern,
19
+ # wo jedes Voxel seinen Platz findet.
20
+
21
+ import matplotlib.pyplot as plt
22
+ from matplotlib.widgets import Slider
23
+ import numpy as np
24
+ from typing import Tuple
25
+ import warnings
26
+ from .voxelgrid import Grid
27
+
28
+ class VoxelFields:
29
+ """Manage 3D voxel grids for simulation, visualization, and I/O.
30
+
31
+ This class provides a uniform, cell‐centered or staggered‐x voxel grid,
32
+ handles spacing and origin, and stores any number of named 3D fields.
33
+
34
+ Args:
35
+ shape (tuple[int, int, int]): Number of voxels ``(Nx, Ny, Nz)``.
36
+ domain_size (tuple[float, float, float], optional):
37
+ Physical dimensions (Lx, Ly, Lz). Defaults to (1, 1, 1).
38
+ convention (str, optional):
39
+ Grid convention, either 'cell_center' or 'staggered_x'.
40
+ Defaults to 'cell_center'.
41
+
42
+ Raises:
43
+ ValueError: If `domain_size` is not length 3 or contains non-numeric values.
44
+ ValueError: If `convention` is not one of 'cell_center' or 'staggered_x'.
45
+ Warning: If the spacing ratio max/min > 10, a warning is issued.
46
+
47
+ Attributes:
48
+ shape (tuple[int, int, int]): Number of voxels ``(Nx, Ny, Nz)``.
49
+ domain_size (tuple[float, float, float]): Physical domain lengths.
50
+ spacing (tuple[float, float, float]): Grid spacing (dx, dy, dz).
51
+ origin (tuple[float, float, float]):
52
+ Coordinates of the (0, 0, 0) corner for cell-centered or staggered grids.
53
+ convention (str): Either 'cell_center' or 'staggered_x'.
54
+ precision (type): NumPy floating-point type for grid coordinates.
55
+ grid (tuple[np.ndarray, np.ndarray, np.ndarray] or None):
56
+ Meshgrid arrays (x, y, z) once created by `add_grid()`, else None.
57
+ fields (dict[str, np.ndarray]): Mapping field names to 3D arrays.
58
+
59
+ Example:
60
+ >>> vf = VoxelFields((100, 100, 100), domain_size=(1, 1, 1))
61
+ >>> vf.add_field('temperature', np_array)
62
+ >>> x, y, z = vf.plot_slice('temperature', 10)
63
+ """
64
+
65
+ def __init__(self, shape: Tuple[int, int, int], domain_size=(1, 1, 1), convention='cell_center'):
66
+ """Create a voxel grid with ``shape`` cells."""
67
+ if not (
68
+ isinstance(shape, (list, tuple))
69
+ and len(shape) == 3
70
+ and all(isinstance(n, (int, np.integer)) for n in shape)
71
+ ):
72
+ raise ValueError("shape must be a tuple of three integers")
73
+ self.shape = tuple(int(n) for n in shape)
74
+ num_x, num_y, num_z = self.shape
75
+ self.precision = 'float32' # float64
76
+ self.convention = convention
77
+
78
+ if not isinstance(domain_size, (list, tuple)) or len(domain_size) != 3:
79
+ raise ValueError("domain_size must be a list or tuple with three elements (dx, dy, dz)")
80
+ if not all(isinstance(x, (int, float)) for x in domain_size):
81
+ raise ValueError("All elements in domain_size must be integers or floats")
82
+ self.domain_size = domain_size
83
+
84
+ if convention == 'cell_center':
85
+ self.spacing = (domain_size[0]/num_x, domain_size[1]/num_y, domain_size[2]/num_z)
86
+ self.origin = (self.spacing[0]/2, self.spacing[1]/2, self.spacing[2]/2)
87
+ elif convention == 'staggered_x':
88
+ self.spacing = (domain_size[0]/(num_x-1), domain_size[1]/num_y, domain_size[2]/num_z)
89
+ self.origin = (0, self.spacing[1]/2, self.spacing[2]/2)
90
+ else:
91
+ raise ValueError("Chosen convention must be cell_center or staggered_x.")
92
+
93
+ if (np.max(self.spacing)/np.min(self.spacing) > 10):
94
+ warnings.warn("Simulations become very questionable for largely different spacings e.g. dz >> dx.")
95
+ self.grid = None
96
+ self.fields = {}
97
+
98
+ @property
99
+ def Nx(self) -> int: # backward compatibility
100
+ return self.shape[0]
101
+
102
+ @property
103
+ def Ny(self) -> int:
104
+ return self.shape[1]
105
+
106
+ @property
107
+ def Nz(self) -> int:
108
+ return self.shape[2]
109
+
110
+ def __str__(self):
111
+ """Return a human readable description of the voxel grid."""
112
+ return (
113
+ f"Domain with size {self.domain_size} and "
114
+ f"{self.shape} grid points on "
115
+ f"{self.convention} position."
116
+ )
117
+
118
+ def grid_info(self):
119
+ """Return a :class:`Grid` dataclass describing this domain."""
120
+ grid = Grid(self.shape, self.origin, self.spacing, self.convention)
121
+ return grid
122
+
123
+ def add_field(self, name: str, array=None):
124
+ """
125
+ Adds a field to the voxel grid.
126
+
127
+ Args:
128
+ name (str): Name of the field.
129
+ array (numpy.ndarray, optional): 3D array to initialize the field. If None, initializes with zeros.
130
+
131
+ Raises:
132
+ ValueError: If the provided array does not match the voxel grid dimensions.
133
+ TypeError: If the provided array is not a numpy array.
134
+ """
135
+ if array is not None:
136
+ if isinstance(array, np.ndarray):
137
+ if array.shape == self.shape:
138
+ self.fields[name] = array
139
+ else:
140
+ raise ValueError(
141
+ f"The provided array must have the shape {self.shape}."
142
+ )
143
+ else:
144
+ raise TypeError("The provided array must be a numpy array.")
145
+ else:
146
+ self.fields[name] = np.zeros(self.shape)
147
+
148
+ def set_voxel_sphere(self, name: str, center, radius, label: int | float = 1):
149
+ """Create a voxelized representation of a sphere in 3D
150
+
151
+ Fill voxels within given ``radius`` around the given ``center``
152
+ with value provided by ``label``.
153
+ """
154
+ x, y, z = np.ogrid[:self.Nx, :self.Ny, :self.Nz]
155
+ distance_squared = (x * self.spacing[0] + self.origin[0] - center[0])**2 +\
156
+ (y * self.spacing[1] + self.origin[1] - center[1])**2 +\
157
+ (z * self.spacing[2] + self.origin[2] - center[2])**2
158
+ mask = distance_squared <= radius**2
159
+ self.fields[name][mask] = label
160
+
161
+ def average(self, name: str):
162
+ """Return the average value of a stored field."""
163
+ if self.convention == 'cell_center':
164
+ average = np.mean(self.fields[name])
165
+ elif self.convention == 'staggered_x':
166
+ # Count first and last slice as half cells
167
+ average = np.sum(self.fields[name][1:-1,:,:]) \
168
+ + 0.5*np.sum(self.fields[name][ 0,:,:]) \
169
+ + 0.5*np.sum(self.fields[name][-1,:,:])
170
+ average /= ((self.Nx - 1) * self.Ny * self.Nz)
171
+ return average
172
+
173
+ def axes(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
174
+ """ Returns the 1D coordinate arrays along each axis. """
175
+ return tuple(
176
+ np.arange(0, n, dtype=self.precision) * self.spacing[i] + self.origin[i]
177
+ for i, n in enumerate(self.shape)
178
+ )
179
+
180
+ def meshgrid(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
181
+ """ Returns full 3D mesh grids for each axis. """
182
+ ax = self.axes()
183
+ # indexing='ij' makes Ax[i,j,k] = x-coordinate at (i,j,k), etc.
184
+ return tuple(np.meshgrid(*ax, indexing='ij'))
185
+
186
+ def export_to_vtk(self, filename="output.vtk", field_names=None):
187
+ """
188
+ Exports fields to a VTK file for visualization (e.g. VisIt or ParaView).
189
+
190
+ Args:
191
+ filename (str): Name of the output VTK file.
192
+ field_names (list, optional): List of field names to export. Exports all fields if None.
193
+ """
194
+ import pyvista as pv
195
+ names = field_names if field_names else list(self.fields.keys())
196
+ grid = pv.ImageData()
197
+ grid.spacing = self.spacing
198
+ grid.dimensions = (self.Nx + 1, self.Ny + 1, self.Nz + 1)
199
+ grid.origin = (self.origin[0] - self.spacing[0]/2, \
200
+ self.origin[1] - self.spacing[1]/2, \
201
+ self.origin[2] - self.spacing[2]/2)
202
+ for name in names:
203
+ grid.cell_data[name] = self.fields[name].flatten(order="F") # Fortran order flattening
204
+ grid.save(filename)
205
+
206
+ def plot_slice(self, fieldname, slice_index, direction='z', time=None, colormap='viridis', value_bounds=None):
207
+ """
208
+ Plots a 2D slice of a field along a specified direction.
209
+
210
+ Args:
211
+ fieldname (str): Name of the field to plot.
212
+ slice_index (int): Index of the slice to plot.
213
+ direction (str): Normal direction of the slice ('x', 'y', or 'z').
214
+ dpi (int): Resolution of the plot.
215
+ colormap (str): Colormap to use for the plot.
216
+
217
+ Raises:
218
+ ValueError: If an invalid direction is provided.
219
+ """
220
+ # Colormaps
221
+ # linear: viridis, Greys
222
+ # diverging: seismic
223
+ # levels: tab20, flag
224
+ # gradual: turbo
225
+ if direction == 'x':
226
+ slice = np.s_[slice_index,:,:]
227
+ start1, start2 = self.origin[1]-self.spacing[1]/2, self.origin[2]-self.spacing[2]/2
228
+ end1, end2 = self.domain_size[1]-start1, self.domain_size[2]-start2
229
+ label1, label2 = ['Y', 'Z']
230
+ elif direction == 'y':
231
+ slice = np.s_[:,slice_index,:]
232
+ start1, start2 = self.origin[0]-self.spacing[0]/2, self.origin[2]-self.spacing[2]/2
233
+ end1, end2 = self.domain_size[0]-start1, self.domain_size[2]-start2
234
+ label1, label2 = ['X', 'Z']
235
+ elif direction == 'z':
236
+ slice = np.s_[:,:,slice_index]
237
+ start1, start2 = self.origin[0]-self.spacing[0]/2, self.origin[1]-self.spacing[1]/2
238
+ end1, end2 = self.domain_size[0]-start1, self.domain_size[1]-start2
239
+ label1, label2 = ['X', 'Y']
240
+ else:
241
+ raise ValueError("Given direction must be x, y or z")
242
+
243
+ plt.figure()
244
+ if value_bounds is not None:
245
+ im = plt.imshow(self.fields[fieldname][slice].T, cmap=colormap,\
246
+ origin='lower', extent=[start1, end1, start2, end2],\
247
+ vmin=value_bounds[0], vmax=value_bounds[1])
248
+ else:
249
+ im = plt.imshow(self.fields[fieldname][slice].T, cmap=colormap, \
250
+ origin='lower', extent=[start1, end1, start2, end2])
251
+
252
+ ratio = np.clip((end2-start2)/(end1-start1), 0, 1)
253
+ plt.colorbar(im, shrink=ratio)
254
+ plt.xlabel(label1)
255
+ plt.ylabel(label2)
256
+ if time:
257
+ plt.title(f'Slice {slice_index} of {fieldname} in {direction} at time {time}')
258
+ else:
259
+ plt.title(f'Slice {slice_index} of {fieldname} in {direction}')
260
+ plt.show()
261
+
262
+ def plot_field_interactive(self, fieldname, direction='x', colormap='viridis', value_bounds=None):
263
+ """
264
+ Creates an interactive plot for exploring slices of a 3D field.
265
+
266
+ Args:
267
+ fieldname (str): Name of the field to plot.
268
+ direction (str): Direction of slicing ('x', 'y', or 'z').
269
+ dpi (int): Resolution of the plot.
270
+ colormap (str): Colormap to use for the plot.
271
+
272
+ Raises:
273
+ ValueError: If an invalid direction is provided.
274
+ """
275
+ if direction == 'x':
276
+ axes = (0,1,2)
277
+ end1, end2 = self.domain_size[1], self.domain_size[2]
278
+ label1, label2 = ['Y', 'Z']
279
+ elif direction == 'y':
280
+ axes = (1,0,2)
281
+ end1, end2 = self.domain_size[0], self.domain_size[2]
282
+ label1, label2 = ['X', 'Z']
283
+ elif direction == 'z':
284
+ axes = (2,0,1)
285
+ end1, end2 = self.domain_size[0], self.domain_size[1]
286
+ label1, label2 = ['X', 'Y']
287
+ else:
288
+ raise ValueError("Given direction must be x, y or z")
289
+
290
+ field = np.transpose(self.fields[fieldname], axes)
291
+ fig, ax = plt.subplots()
292
+ if value_bounds is None:
293
+ value_bounds = (np.min(field), np.max(field))
294
+ im = ax.imshow(
295
+ field[0].T,
296
+ cmap=colormap,
297
+ origin="lower",
298
+ extent=[0, end1, 0, end2],
299
+ vmin=value_bounds[0],
300
+ vmax=value_bounds[1],
301
+ )
302
+ ax.set_xlabel(label1)
303
+ ax.set_ylabel(label2)
304
+ ax.set_title(f'Slice 0 in {direction}-direction of {fieldname}')
305
+ plt.colorbar(im, ax=ax)
306
+
307
+ # Add a slider for changing timeframes
308
+ position = plt.axes([0.2, 0.0, 0.6, 0.02])
309
+ ax_slider = Slider(position, 'Slice', 0, field.shape[0]-1, valinit=0, valstep=1)
310
+
311
+ def update(val):
312
+ slice_idx = int(ax_slider.val)
313
+ im.set_array(field[slice_idx].T)
314
+ ax.set_title(f'Slice {slice_idx} in ' + direction + '-direction of ' + fieldname)
315
+ fig.canvas.draw_idle()
316
+
317
+ ax_slider.on_changed(update)
318
+ return ax_slider