evoxels 0.1.0__tar.gz

evoxels-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 daubners
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
evoxels-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,171 @@
1
+ Metadata-Version: 2.4
2
+ Name: evoxels
3
+ Version: 0.1.0
4
+ Summary: Voxel-based structure simulation solvers
5
+ Author-email: Simon Daubner <s.daubner@imperial.ac.uk>
6
+ License: MIT
7
+ Classifier: Development Status :: 4 - Beta
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: Topic :: Scientific/Engineering :: Physics
10
+ Classifier: Topic :: Scientific/Engineering :: Chemistry
11
+ Classifier: Topic :: Scientific/Engineering :: Image Processing
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Natural Language :: English
14
+ Classifier: Environment :: GPU
15
+ Classifier: Environment :: GPU :: NVIDIA CUDA
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.9
18
+ Classifier: Programming Language :: Python :: 3.10
19
+ Classifier: Programming Language :: Python :: 3.11
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Requires-Python: >=3.9
22
+ Description-Content-Type: text/markdown
23
+ License-File: LICENSE
24
+ Requires-Dist: numpy>=1.21
25
+ Requires-Dist: matplotlib>=3.4
26
+ Requires-Dist: pyvista>=0.38
27
+ Requires-Dist: psutil>=5.9.0
28
+ Requires-Dist: ipython>=7.0.0
29
+ Requires-Dist: sympy>=1.10
30
+ Provides-Extra: torch
31
+ Requires-Dist: torch>=2.1; extra == "torch"
32
+ Provides-Extra: jax
33
+ Requires-Dist: jax>=0.4.14; extra == "jax"
34
+ Requires-Dist: jaxlib>=0.4.14; extra == "jax"
35
+ Requires-Dist: diffrax>=0.6.2; extra == "jax"
36
+ Provides-Extra: dev
37
+ Requires-Dist: pytest; extra == "dev"
38
+ Requires-Dist: ruff; extra == "dev"
39
+ Provides-Extra: notebooks
40
+ Requires-Dist: ipywidgets; extra == "notebooks"
41
+ Requires-Dist: ipympl; extra == "notebooks"
42
+ Requires-Dist: notebook; extra == "notebooks"
43
+ Dynamic: license-file
44
+
45
+ [![Python package](https://github.com/daubners/evoxels/actions/workflows/python-package.yml/badge.svg?branch=main)](https://github.com/daubners/evoxels/actions/workflows/python-package.yml)
46
+
47
+ # evoxels
48
+ A differentiable physics framework for voxel-based microstructure simulations
49
+
50
+ For more detailed information about the code [read the docs](https://evoxels.readthedocs.io).
51
+
52
+ <p align="center">
53
+ <img src="evoxels.png" width="90%"></img>
54
+ </p>
55
+
56
+ ```
57
+ In a world of cubes and blocks,
58
+ Where reality takes voxel knocks,
59
+ Every shape and form we see,
60
+ Is a pixelated mystery.
61
+
62
+ Mountains rise in jagged peaks,
63
+ Rivers flow in blocky streaks.
64
+ So embrace the charm of this edgy place,
65
+ Where every voxel finds its space
66
+ ```
67
+
68
+ ## Description
69
+ **evoxels are not static — they evolve, adapt, and reveal.**
70
+ Whether you're modeling phase transitions, predicting effective properties, or coupling imaging and simulation — evoxels is the GPU-native, differentiable core that keeps pace with your science.
71
+
72
+ Materials science inherently spans disciplines: experimentalists use advanced microscopy to uncover micro- and nanoscale structure, while theorists and computational scientists develop models that link processing, structure, and properties. Bridging these domains is essential for inverse material design where you start from desired performance and work backwards to optimal microstructures and manufacturing routes. Integrating high-resolution imaging with predictive simulations and data‐driven optimization accelerates discovery and deepens understanding of process–structure–property relationships.
73
+
74
+ From a high-level perspective, evoxels is organized around two core abstractions: ``VoxelFields`` and ``VoxelGrid``. VoxelFields provides a uniform, NumPy-based container for any number of 3D fields on the same regular grid, maximizing interoperability with image I/O libraries (e.g. tifffile, h5py, napari, scikit-image) and visualization tools (PyVista, VTK). VoxelGrid couples these fields to either a PyTorch or JAX backend, offering pre-defined boundary conditions, finite difference stencils and FFT libraries.
75
+
76
+ The evoxels package enables large-scale forward and inverse simulations on uniform voxel grids, ensuring direct compatibility with microscopy data and harnessing GPU-optimized FFT and tensor operations.
77
+ This design supports forward modeling of transport and phase evolution phenomena, as well as backpropagation-based inverse problems such as parameter estimation and neural surrogate training - tasks which are still difficult to achieve with traditional FEM-based solvers.
78
+ This differentiable‐physics foundation makes it easy to embed voxel‐based solvers as neural‐network layers, train generative models for optimal microstructures, or jointly optimize processing and properties via gradient descent. By keeping each simulation step fast and fully backpropagatable, evoxels enables data‐driven materials discovery and high‐dimensional design‐space exploration.
79
+
80
+ ## Installation
81
+
82
+ TL;DR
83
+ ```bash
84
+ conda create --name voxenv python=3.12
85
+ conda activate voxenv
86
+ pip install evoxels[torch,jax,dev,notebooks]
87
+ pip install --upgrade "jax[cuda12]"
88
+ ```
89
+
90
+ The package is available on PyPI but can also be installed by cloning the repository
91
+ ```
92
+ git clone git@github.com:daubners/evoxels.git
93
+ ```
94
+
95
+ and then locally installing in editable mode.
96
+ It is recommended to install the package inside a Python virtual environment so
97
+ that the dependencies do not interfere with your system packages. Create and
98
+ activate a virtual environment e.g. using miniconda
99
+
100
+ ```bash
101
+ conda create --name myenv python=3.12
102
+ conda activate myenv
103
+ ```
104
+ Navigate to the evoxels folder, then
105
+ ```
106
+ pip install -e .[torch] # install with torch backend
107
+ pip install -e .[jax] # install with jax backend
108
+ pip install -e .[dev,notebooks] # install testing and notebooks
109
+ ```
110
+ Note that the default `[jax]` installation is only CPU compatible. To install the corresponding CUDA libraries check your CUDA version with
111
+ ```bash
112
+ nvidia-smi
113
+ ```
114
+ then install the CUDA-enabled JAX backend via (in this case for CUDA version 12)
115
+ ```bash
116
+ pip install -U "jax[cuda12]"
117
+ ```
118
+ To install both backends within one environment it is important to install torch first and then upgrade the `jax` installation e.g.
119
+ ```bash
120
+ pip install evoxels[torch,jax,dev,notebooks]
121
+ pip install --upgrade "jax[cuda12]"
122
+ ```
123
+ To work with the example notebooks install Jupyter and all notebook related dependencies via
124
+ ```
125
+ pip install -e .[notebooks]
126
+ ```
127
+ Launch the notebooks with
128
+ ```
129
+ jupyter notebook
130
+ ```
131
+ If you are using VSCode open the Command Palette and select
132
+ "Jupyter: Create New Blank Notebook" or open an existing notebook file.
133
+
134
+
135
+ ## Usage
136
+
137
+ Example of creating a voxel field object and running a Cahn-Hilliard simulation based on a semi-implicit FFT approach
138
+
139
+ ```
140
+ import evoxels as evo
141
+ import numpy as np
142
+
143
+ nx, ny, nz = [100, 100, 100]
144
+
145
+ vf = evo.VoxelFields((nx, ny, nz), (nx,ny,nz))
146
+ noise = 0.5 + 0.1*np.random.rand(nx, ny, nz)
147
+ vf.add_field("c", noise)
148
+
149
+ dt = 0.1
150
+ final_time = 100
151
+ steps = int(final_time/dt)
152
+
153
+ evo.run_cahn_hilliard_solver(
154
+ vf, 'c', 'torch', jit=True, device='cuda',
155
+ time_increment=dt, frames=10, max_iters=steps,
156
+ verbose='plot', vtk_out=False, plot_bounds=(0,1)
157
+ )
158
+ ```
159
+ While the simulation runs, the "c" field is overwritten at each frame, so ``vf.fields["c"]`` holds the last frame once the run has finished. This design was chosen for large data sets, since it keeps RAM requirements low.
160
+ For visual inspection of your simulation results, you can plot individual slices (e.g. slice=10) for a given direction (e.g. x)
161
+ ```
162
+ vf.plot_slice("c", 10, direction='x', colormap='viridis')
163
+ ```
164
+ or use the following code for interactive plotting with a slider to go through the volume
165
+ ```
166
+ %matplotlib widget
167
+ vf.plot_field_interactive("c", direction='x', colormap='turbo')
168
+ ```
169
+
170
+ ## License
171
+ This code has been published under the MIT licence.
@@ -0,0 +1,127 @@
1
+ [![Python package](https://github.com/daubners/evoxels/actions/workflows/python-package.yml/badge.svg?branch=main)](https://github.com/daubners/evoxels/actions/workflows/python-package.yml)
2
+
3
+ # evoxels
4
+ A differentiable physics framework for voxel-based microstructure simulations
5
+
6
+ For more detailed information about the code [read the docs](https://evoxels.readthedocs.io).
7
+
8
+ <p align="center">
9
+ <img src="evoxels.png" width="90%"></img>
10
+ </p>
11
+
12
+ ```
13
+ In a world of cubes and blocks,
14
+ Where reality takes voxel knocks,
15
+ Every shape and form we see,
16
+ Is a pixelated mystery.
17
+
18
+ Mountains rise in jagged peaks,
19
+ Rivers flow in blocky streaks.
20
+ So embrace the charm of this edgy place,
21
+ Where every voxel finds its space
22
+ ```
23
+
24
+ ## Description
25
+ **evoxels are not static — they evolve, adapt, and reveal.**
26
+ Whether you're modeling phase transitions, predicting effective properties, or coupling imaging and simulation — evoxels is the GPU-native, differentiable core that keeps pace with your science.
27
+
28
+ Materials science inherently spans disciplines: experimentalists use advanced microscopy to uncover micro- and nanoscale structure, while theorists and computational scientists develop models that link processing, structure, and properties. Bridging these domains is essential for inverse material design where you start from desired performance and work backwards to optimal microstructures and manufacturing routes. Integrating high-resolution imaging with predictive simulations and data‐driven optimization accelerates discovery and deepens understanding of process–structure–property relationships.
29
+
30
+ From a high-level perspective, evoxels is organized around two core abstractions: ``VoxelFields`` and ``VoxelGrid``. VoxelFields provides a uniform, NumPy-based container for any number of 3D fields on the same regular grid, maximizing interoperability with image I/O libraries (e.g. tifffile, h5py, napari, scikit-image) and visualization tools (PyVista, VTK). VoxelGrid couples these fields to either a PyTorch or JAX backend, offering pre-defined boundary conditions, finite difference stencils and FFT libraries.
31
+
32
+ The evoxels package enables large-scale forward and inverse simulations on uniform voxel grids, ensuring direct compatibility with microscopy data and harnessing GPU-optimized FFT and tensor operations.
33
+ This design supports forward modeling of transport and phase evolution phenomena, as well as backpropagation-based inverse problems such as parameter estimation and neural surrogate training - tasks which are still difficult to achieve with traditional FEM-based solvers.
34
+ This differentiable‐physics foundation makes it easy to embed voxel‐based solvers as neural‐network layers, train generative models for optimal microstructures, or jointly optimize processing and properties via gradient descent. By keeping each simulation step fast and fully backpropagatable, evoxels enables data‐driven materials discovery and high‐dimensional design‐space exploration.
35
+
36
+ ## Installation
37
+
38
+ TL;DR
39
+ ```bash
40
+ conda create --name voxenv python=3.12
41
+ conda activate voxenv
42
+ pip install evoxels[torch,jax,dev,notebooks]
43
+ pip install --upgrade "jax[cuda12]"
44
+ ```
45
+
46
+ The package is available on PyPI but can also be installed by cloning the repository
47
+ ```
48
+ git clone git@github.com:daubners/evoxels.git
49
+ ```
50
+
51
+ and then locally installing in editable mode.
52
+ It is recommended to install the package inside a Python virtual environment so
53
+ that the dependencies do not interfere with your system packages. Create and
54
+ activate a virtual environment e.g. using miniconda
55
+
56
+ ```bash
57
+ conda create --name myenv python=3.12
58
+ conda activate myenv
59
+ ```
60
+ Navigate to the evoxels folder, then
61
+ ```
62
+ pip install -e .[torch] # install with torch backend
63
+ pip install -e .[jax] # install with jax backend
64
+ pip install -e .[dev,notebooks] # install testing and notebooks
65
+ ```
66
+ Note that the default `[jax]` installation is only CPU compatible. To install the corresponding CUDA libraries check your CUDA version with
67
+ ```bash
68
+ nvidia-smi
69
+ ```
70
+ then install the CUDA-enabled JAX backend via (in this case for CUDA version 12)
71
+ ```bash
72
+ pip install -U "jax[cuda12]"
73
+ ```
74
+ To install both backends within one environment it is important to install torch first and then upgrade the `jax` installation e.g.
75
+ ```bash
76
+ pip install evoxels[torch,jax,dev,notebooks]
77
+ pip install --upgrade "jax[cuda12]"
78
+ ```
79
+ To work with the example notebooks install Jupyter and all notebook related dependencies via
80
+ ```
81
+ pip install -e .[notebooks]
82
+ ```
83
+ Launch the notebooks with
84
+ ```
85
+ jupyter notebook
86
+ ```
87
+ If you are using VSCode open the Command Palette and select
88
+ "Jupyter: Create New Blank Notebook" or open an existing notebook file.
89
+
90
+
91
+ ## Usage
92
+
93
+ Example of creating a voxel field object and running a Cahn-Hilliard simulation based on a semi-implicit FFT approach
94
+
95
+ ```
96
+ import evoxels as evo
97
+ import numpy as np
98
+
99
+ nx, ny, nz = [100, 100, 100]
100
+
101
+ vf = evo.VoxelFields((nx, ny, nz), (nx,ny,nz))
102
+ noise = 0.5 + 0.1*np.random.rand(nx, ny, nz)
103
+ vf.add_field("c", noise)
104
+
105
+ dt = 0.1
106
+ final_time = 100
107
+ steps = int(final_time/dt)
108
+
109
+ evo.run_cahn_hilliard_solver(
110
+ vf, 'c', 'torch', jit=True, device='cuda',
111
+ time_increment=dt, frames=10, max_iters=steps,
112
+ verbose='plot', vtk_out=False, plot_bounds=(0,1)
113
+ )
114
+ ```
115
+ While the simulation runs, the "c" field is overwritten at each frame, so ``vf.fields["c"]`` holds the last frame once the run has finished. This design was chosen for large data sets, since it keeps RAM requirements low.
116
+ For visual inspection of your simulation results, you can plot individual slices (e.g. slice=10) for a given direction (e.g. x)
117
+ ```
118
+ vf.plot_slice("c", 10, direction='x', colormap='viridis')
119
+ ```
120
+ or use the following code for interactive plotting with a slider to go through the volume
121
+ ```
122
+ %matplotlib widget
123
+ vf.plot_field_interactive("c", direction='x', colormap='turbo')
124
+ ```
125
+
126
+ ## License
127
+ This code has been published under the MIT licence.
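The README's usage example refers to a "semi-implicit FFT approach" for Cahn-Hilliard. As a rough illustration of that idea only, here is a minimal NumPy sketch in 2D. It is not the evoxels solver: the function name `cahn_hilliard_step`, the double-well derivative `c**3 - c`, the constant mobility and the parameter names are all assumptions made for this example. The stiff fourth-order term is treated implicitly in Fourier space, the nonlinear term explicitly.

```python
import numpy as np

def cahn_hilliard_step(c, dt, dx, eps=1.0, mobility=1.0):
    """One semi-implicit Fourier step (hypothetical helper, periodic 2D grid).

    Assumed model: dc/dt = M * lap(c**3 - c - eps**2 * lap(c)).
    """
    kx = 2.0 * np.pi * np.fft.fftfreq(c.shape[0], d=dx)
    ky = 2.0 * np.pi * np.fft.fftfreq(c.shape[1], d=dx)
    kxx, kyy = np.meshgrid(kx, ky, indexing="ij")
    k2 = kxx**2 + kyy**2

    c_hat = np.fft.fft2(c)
    # Nonlinear part of the chemical potential, evaluated at the old time level
    mu_hat = np.fft.fft2(c**3 - c)
    # The k^4 gradient-energy term is moved to the left-hand side (implicit)
    c_hat_new = (c_hat - dt * mobility * k2 * mu_hat) / (
        1.0 + dt * mobility * eps**2 * k2**2
    )
    return np.fft.ifft2(c_hat_new).real

rng = np.random.default_rng(0)
c0 = 0.1 * (rng.random((32, 32)) - 0.5)  # small noise around c = 0
c = c0.copy()
for _ in range(50):
    c = cahn_hilliard_step(c, dt=0.1, dx=1.0)
```

Because the update leaves the zero-wavenumber mode untouched, the mean of `c` is conserved up to round-off, which is a convenient sanity check for any scheme of this kind.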
@@ -0,0 +1,13 @@
+ """Public API for the evoxels package."""
+
+ from .voxelfields import VoxelFields
+ from .precompiled_solvers.cahn_hilliard import (run_cahn_hilliard_solver)
+ from .precompiled_solvers.allen_cahn import (run_allen_cahn_solver)
+ from .inversion import InversionModel
+
+ __all__ = [
+     "VoxelFields",
+     "run_cahn_hilliard_solver",
+     "run_allen_cahn_solver",
+     "InversionModel"
+ ]
@@ -0,0 +1,138 @@
+ # Shorthands in slicing logic
+ __ = slice(None)    # all elements [:]
+ _i_ = slice(1, -1)  # inner elements [1:-1]
+
+ class CellCenteredBCs:
+     def __init__(self, vg):
+         self.vg = vg
+
+     def pad_periodic(self, field):
+         """
+         Periodic boundary conditions in all directions.
+         Consistent with cell centered grid.
+         """
+         return self.vg.pad_periodic(field)
+
+     def pad_dirichlet_periodic(self, field, bc0=0, bc1=0):
+         """
+         Homogeneous Dirichlet boundary conditions in x-direction,
+         periodic in y- and z-direction. Consistent with cell centered grid,
+         but loss of 2nd order convergence.
+         """
+         padded = self.vg.pad_periodic(field)
+         padded = self.vg.set(padded, (__, 0,__,__), 2.0*bc0 - padded[:, 1,:,:])
+         padded = self.vg.set(padded, (__,-1,__,__), 2.0*bc1 - padded[:,-2,:,:])
+         return padded
+
+     def pad_zero_flux_periodic(self, field):
+         padded = self.vg.pad_periodic(field)
+         padded = self.vg.set(padded, (__, 0,__,__), padded[:, 1,:,:])
+         padded = self.vg.set(padded, (__,-1,__,__), padded[:,-2,:,:])
+         return padded
+
+     def pad_zero_flux(self, field):
+         padded = self.vg.pad_zeros(field)
+         padded = self.vg.set(padded, (__, 0,__,__), padded[:, 1,:,:])
+         padded = self.vg.set(padded, (__,-1,__,__), padded[:,-2,:,:])
+         padded = self.vg.set(padded, (__,__, 0,__), padded[:,:, 1,:])
+         padded = self.vg.set(padded, (__,__,-1,__), padded[:,:,-2,:])
+         padded = self.vg.set(padded, (__,__,__, 0), padded[:,:,:, 1])
+         padded = self.vg.set(padded, (__,__,__,-1), padded[:,:,:,-2])
+         return padded
+
+     def pad_fft_periodic(self, field):
+         """Periodic field needs no fft padding."""
+         return field
+
+     def pad_fft_dirichlet_periodic(self, field):
+         """Pad with inverse of flipped field in x direction."""
+         return self.vg.concatenate((field, -self.vg.lib.flip(field, [0])), 1)
+
+     def pad_fft_zero_flux_periodic(self, field):
+         """Pad with flipped field in x direction."""
+         return self.vg.concatenate((field, self.vg.lib.flip(field, [0])), 1)
+
+     def trim_boundary_nodes(self, field):
+         return field
+
+     def trim_ghost_nodes(self, field):
+         if field[0,_i_,_i_,_i_].shape == self.vg.shape:
+             return field[:,_i_,_i_,_i_]
+         else:
+             raise ValueError(
+                 f"The provided field has the wrong shape {self.vg.shape}."
+             )
+
+
+ class StaggeredXBCs:
+     def __init__(self, vg):
+         self.vg = vg
+
+     def pad_periodic_BC_staggered_x(self, field):
+         """
+         If field is fully periodic it should be in
+         cell center convention!
+         """
+         raise NotImplementedError
+
+     def pad_dirichlet_periodic(self, field, bc0=0, bc1=0):
+         """
+         Homogeneous Dirichlet boundary conditions in x-direction,
+         periodic in y- and z-direction. Consistent with staggered_x grid,
+         maintains 2nd order convergence.
+         """
+         padded = self.vg.pad_periodic(field)
+         padded = self.vg.set(padded, (__, 0,__,__), bc0)
+         padded = self.vg.set(padded, (__,-1,__,__), bc1)
+         return padded
+
+     def pad_zero_flux_periodic(self, field):
+         """
+         The following comes from an interpolation polynomial p with
+         p'(0) = 0, p(dx) = f(dx,...), p(2*dx) = f(2*dx,...)
+         and then use p(0) for the ghost cell.
+         This should be of sufficient order if f'(0) = 0, and even better if
+         also f'''(0) = 0 (as it holds for cos(k*pi*x) )
+         """
+         padded = self.vg.pad_periodic(field)
+         fac1 = 4/3
+         fac2 = 1/3
+         padded = self.vg.set(padded, (__, 0,__,__), fac1*padded[:, 1,:,:] - fac2*padded[:, 2,:,:])
+         padded = self.vg.set(padded, (__,-1,__,__), fac1*padded[:,-2,:,:] - fac2*padded[:,-3,:,:])
+         return padded
+
+     def pad_zero_flux(self, field):
+         raise NotImplementedError
+
+     def pad_fft_periodic(self, field):
+         """
+         If field is fully periodic it should be in
+         cell center convention!
+         """
+         raise NotImplementedError
+
+     def pad_fft_dirichlet_periodic(self, field):
+         """Pad with inverse of flipped field in x direction."""
+         bc = self.vg.lib.zeros_like(field[:,0:1])
+         return self.vg.concatenate((field, bc, -self.vg.lib.flip(field, [0]), bc), 1)
+
+     def pad_fft_zero_flux_periodic(self, field):
+         """Pad with flipped field in x direction."""
+         raise NotImplementedError
+
+     def trim_boundary_nodes(self, field):
+         """Trim boundary nodes of ``field`` for staggered grids."""
+         if field.shape[1] == self.vg.shape[0]:
+             return field[:,_i_,:,:]
+         else:
+             raise ValueError(
+                 f"The provided field must have the shape {self.vg.shape}."
+             )
+
+     def trim_ghost_nodes(self, field):
+         if field[0,:,_i_,_i_].shape == self.vg.shape:
+             return field[:,:,_i_,_i_]
+         else:
+             raise ValueError(
+                 f"The provided field has the wrong shape {self.vg.shape}."
+             )
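The ghost-value formulas in `CellCenteredBCs` and `StaggeredXBCs` can be checked in isolation. The following is a minimal 1D NumPy sketch, not package code: the helper names are hypothetical and `np.pad` stands in for the backend's padding routines. It shows that the cell-centered Dirichlet ghost value makes the face average equal the boundary value, and that the staggered zero-flux factors 4/3 and 1/3 give p(0) of the quadratic p(x) = a + b*x^2 (which has p'(0) = 0) passing through the first two interior nodes.

```python
import numpy as np

def pad_dirichlet_cell_centered(u, bc0=0.0, bc1=0.0):
    """Ghost cells chosen so that the average across each boundary face is bc."""
    padded = np.pad(u, 1)
    padded[0] = 2.0 * bc0 - padded[1]
    padded[-1] = 2.0 * bc1 - padded[-2]
    return padded

def pad_zero_flux_cell_centered(u):
    """Ghost cells mirror the first interior cell, so the face gradient is zero."""
    padded = np.pad(u, 1)
    padded[0] = padded[1]
    padded[-1] = padded[-2]
    return padded

def pad_zero_flux_staggered(u):
    """Boundary value from the quadratic through two interior nodes with p'(0) = 0."""
    padded = np.pad(u, 1)
    padded[0] = 4 / 3 * padded[1] - 1 / 3 * padded[2]
    padded[-1] = 4 / 3 * padded[-2] - 1 / 3 * padded[-3]
    return padded

# Samples of 2 + 3*x**2 at x = 0.5, 1.0, 1.5 (zero slope at x = 0)
u = np.array([2.75, 5.0, 8.75])
d = pad_dirichlet_cell_centered(u, bc0=0.5, bc1=3.0)
n = pad_zero_flux_cell_centered(u)
s = pad_zero_flux_staggered(u)
```

For the quadratic sample data, `s[0]` recovers the value of the polynomial at the boundary, x = 0, exactly.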
@@ -0,0 +1,103 @@
+ # Shorthands in slicing logic
+ __ = slice(None)    # all elements [:]
+ _i_ = slice(1, -1)  # inner elements [1:-1]
+
+ CENTER = (__, _i_, _i_, _i_)
+ LEFT = (__, slice(None,-2), _i_, _i_)
+ RIGHT = (__, slice(2, None), _i_, _i_)
+ BOTTOM = (__, _i_, slice(None,-2), _i_)
+ TOP = (__, _i_, slice(2, None), _i_)
+ BACK = (__, _i_, _i_, slice(None,-2))
+ FRONT = (__, _i_, _i_, slice(2, None))
+
+ class FDStencils:
+     """Class wrapper for finite difference stencils
+
+     Is inherited by the VoxelGrid to apply stencils to
+     backend arrays.
+     """
+
+     def to_x_face(self, field):
+         """Interpolate to face position staggered in x"""
+         return 0.5 * (field[:,1:,:,:] + field[:,:-1,:,:])
+
+     def to_y_face(self, field):
+         """Interpolate to face position staggered in y"""
+         return 0.5 * (field[:,:,1:,:] + field[:,:,:-1,:])
+
+     def to_z_face(self, field):
+         """Interpolate to face position staggered in z"""
+         return 0.5 * (field[:,:,:,1:] + field[:,:,:,:-1])
+
+     def grad_x_face(self, field):
+         """Gradient at face position staggered in x"""
+         return (field[:,1:,:,:] - field[:,:-1,:,:]) * self.div_dx[0]
+
+     def grad_y_face(self, field):
+         """Gradient at face position staggered in y"""
+         return (field[:,:,1:,:] - field[:,:,:-1,:]) * self.div_dx[1]
+
+     def grad_z_face(self, field):
+         """Gradient at face position staggered in z"""
+         return (field[:,:,:,1:] - field[:,:,:,:-1]) * self.div_dx[2]
+
+     def grad_x_center(self, field):
+         """Gradient in x at cell center"""
+         return 0.5 * (field[RIGHT] - field[LEFT]) * self.div_dx[0]
+
+     def grad_y_center(self, field):
+         """Gradient in y at cell center"""
+         return 0.5 * (field[TOP] - field[BOTTOM]) * self.div_dx[1]
+
+     def grad_z_center(self, field):
+         """Gradient in z at cell center"""
+         return 0.5 * (field[FRONT] - field[BACK]) * self.div_dx[2]
+
+     def gradient_norm_squared(self, field):
+         """Gradient norm squared at cell centers"""
+         return self.grad_x_center(field)**2 +\
+                self.grad_y_center(field)**2 + \
+                self.grad_z_center(field)**2
+
+     def laplace(self, field):
+         r"""Calculate laplace based on compact 2nd order stencil.
+
+         Laplace given as $\nabla\cdot(\nabla u)$ which in 3D is given by
+         $\partial^2 u/\partial x^2 + \partial^2 u/\partial y^2 + \partial^2 u/\partial z^2$
+         Returned field has same shape as the input field (padded with zeros)
+         """
+         # Manual indexing is ~10x faster than conv3d with laplace kernel in torch
+         laplace = \
+             (field[RIGHT] + field[LEFT]) * self.div_dx2[0] + \
+             (field[TOP] + field[BOTTOM]) * self.div_dx2[1] + \
+             (field[FRONT] + field[BACK]) * self.div_dx2[2] - \
+             2 * field[CENTER] * self.lib.sum(self.div_dx2)
+         return laplace
+
+     def normal_laplace(self, field):
+         r"""Calculate the normal component of the laplacian
+
+         which is identical to the full laplacian minus curvature.
+         It is defined as $\partial^2_n u = \nabla\cdot(\nabla u\cdot n)\cdot n$
+         where $n$ denotes the surface normal.
+         In the context of phase-field models $n$ is defined as
+         $\frac{\nabla u}{|\nabla u|}$.
+         The calculation is based on a compact 2nd order stencil.
+         """
+         n_laplace =\
+             self.grad_x_center(field)**2 * (field[RIGHT] - 2*field[CENTER] + field[LEFT]) * self.div_dx2[0] +\
+             self.grad_y_center(field)**2 * (field[TOP] - 2*field[CENTER] + field[BOTTOM]) * self.div_dx2[1]+\
+             self.grad_z_center(field)**2 * (field[FRONT] - 2*field[CENTER] + field[BACK]) * self.div_dx2[2]+\
+             0.5 * self.grad_x_center(field) * self.grad_y_center(field) *\
+             (field[:,2:,2:,1:-1] + field[:,:-2,:-2,1:-1] -\
+              field[:,:-2,2:,1:-1] - field[:,2:,:-2,1:-1]) * self.div_dx[0] * self.div_dx[1] +\
+             0.5 *self.grad_x_center(field) * self.grad_z_center(field) *\
+             (field[:,2:,1:-1,2:] + field[:,:-2,1:-1,:-2] -\
+              field[:,:-2,1:-1,2:] - field[:,2:,1:-1,:-2]) * self.div_dx[0] * self.div_dx[2] +\
+             0.5 * self.grad_y_center(field) * self.grad_z_center(field) *\
+             (field[:,1:-1,2:,2:] + field[:,1:-1,:-2,:-2] -\
+              field[:,1:-1,:-2,2:] - field[:,1:-1,2:,:-2]) * self.div_dx[1] * self.div_dx[2]
+         norm2 = self.gradient_norm_squared(field)
+         bulk = self.lib.where(norm2 <= 1e-7)
+         norm2 = self.set(norm2, bulk, 1.0)
+         return n_laplace/norm2
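To see what the compact stencil in `FDStencils.laplace` computes, here is a standalone NumPy sketch of the same 7-point formula applied to a periodically padded field with a leading batch axis. The function name `laplace_7pt`, the `spacing` argument and the use of `np.pad(..., mode="wrap")` are assumptions for this illustration; the package itself reads `div_dx2` from the grid object. The result is compared against the analytic Laplacian of a sine wave.

```python
import numpy as np

def laplace_7pt(padded, spacing):
    """7-point Laplacian of a field padded by one ghost layer per side.

    `padded` has shape (batch, nx + 2, ny + 2, nz + 2); the result has
    shape (batch, nx, ny, nz).
    """
    dx, dy, dz = spacing
    center = padded[:, 1:-1, 1:-1, 1:-1]
    return (
        (padded[:, 2:, 1:-1, 1:-1] + padded[:, :-2, 1:-1, 1:-1] - 2 * center) / dx**2
        + (padded[:, 1:-1, 2:, 1:-1] + padded[:, 1:-1, :-2, 1:-1] - 2 * center) / dy**2
        + (padded[:, 1:-1, 1:-1, 2:] + padded[:, 1:-1, 1:-1, :-2] - 2 * center) / dz**2
    )

n = 64
h = 1.0 / n
x = (np.arange(n) + 0.5) * h  # cell centers on the unit interval
u = np.sin(2 * np.pi * x)[None, :, None, None] * np.ones((1, n, 4, 4))

# Periodic ghost layer on the three spatial axes, none on the batch axis
padded = np.pad(u, ((0, 0), (1, 1), (1, 1), (1, 1)), mode="wrap")
lap = laplace_7pt(padded, (h, h, h))

exact = -(2 * np.pi) ** 2 * u
max_err = np.abs(lap - exact).max()
```

The remaining error shrinks with the square of the spacing, as expected for a second-order stencil.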
@@ -0,0 +1,97 @@
+ import numpy as np
+
+ try:
+     import jax
+     import jax.numpy as jnp
+     _HAS_JAX = True
+ except ImportError:
+     _HAS_JAX = False
+     class DummyJax:
+         @staticmethod
+         def jit(f):
+             return f
+     class DummyJnp:
+         @staticmethod
+         def ones_like(x):
+             return np.ones_like(x)
+         @staticmethod
+         def exp(x):
+             return np.exp(x)
+
+     jax = DummyJax()
+     jnp = DummyJnp()
+
+ import dataclasses
+
+ @dataclasses.dataclass
+ class DiffusionLegendrePolynomials:
+     max_degree: int
+
+     def __post_init__(self):
+         self.leg_poly = ExpLegendrePolynomials(self.max_degree)
+
+     def __call__(self, params, inputs):
+         return self.leg_poly(params, 2.0 * inputs - 1.0)
+
+
+ @dataclasses.dataclass
+ class ChemicalPotentialLegendrePolynomials:
+     max_degree: int
+
+     def __post_init__(self):
+         self.leg_poly = LegendrePolynomialRecurrence(self.max_degree)
+
+     def __call__(self, params, inputs):
+         return self.leg_poly(params, 2.0 * inputs - 1.0)
+
+
+ @dataclasses.dataclass
+ class ExpLegendrePolynomials:
+     max_degree: int
+
+     def __post_init__(self):
+         leg_poly = LegendrePolynomialRecurrence(self.max_degree)
+         self.func = jax.jit(lambda p, x: jnp.exp(leg_poly(p, x)))
+
+     def __call__(self, params, inputs):
+         return self.func(params, inputs)
+
+ # TODO: This can be made more efficient
+ @dataclasses.dataclass
+ class LegendrePolynomialRecurrence:
+     max_degree: int
+
+     def __post_init__(self):
+         # Create a JIT-compiled function that computes the Legendre polynomial sum
+         def compute_polynomial_sum(params, x):
+             result = params[0] * self.T0(x)
+             for i in range(1, self.max_degree + 1):
+                 result += params[i] * self._compute_legendre(i, x)
+             return result
+
+         self.func = jax.jit(compute_polynomial_sum)
+
+     def __call__(self, params, inputs):
+         return self.func(params, inputs)
+
+     def T0(self, x):
+         return 1.0 * jnp.ones_like(x)
+
+     def _compute_legendre(self, n, x):
+         """Compute the nth Legendre polynomial using the three-term recurrence relation."""
+         if n == 0:
+             return self.T0(x)
+         elif n == 1:
+             return x
+
+         # Initialize P₀ and P₁
+         p_prev = self.T0(x)  # P₀
+         p_curr = x  # P₁
+
+         # Compute Pₙ using the recurrence relation
+         for i in range(1, n):
+             p_next = ((2 * i + 1) * x * p_curr - i * p_prev) / (i + 1)
+             p_prev = p_curr
+             p_curr = p_next
+
+         return p_curr
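`LegendrePolynomialRecurrence` restarts the three-term recurrence for every degree, which is what the TODO comment hints at. The sketch below shows one possible single-pass variant in plain NumPy; the function name `legendre_series` is hypothetical and this is not the package's implementation. It carries P_{i-1} and P_i forward while accumulating the weighted sum, and is checked against `numpy.polynomial.legendre.legval`. The wrapper classes above first map their inputs with `2.0 * inputs - 1.0`, which takes [0, 1] onto [-1, 1].

```python
import numpy as np
from numpy.polynomial import legendre

def legendre_series(params, x):
    """Evaluate sum_i params[i] * P_i(x) with a single pass of the recurrence
    (i + 1) * P_{i+1} = (2i + 1) * x * P_i - i * P_{i-1}."""
    x = np.asarray(x, dtype=float)
    p_prev = np.ones_like(x)          # P_0
    result = params[0] * p_prev
    if len(params) == 1:
        return result
    p_curr = x                        # P_1
    result = result + params[1] * p_curr
    for i in range(1, len(params) - 1):
        p_next = ((2 * i + 1) * x * p_curr - i * p_prev) / (i + 1)
        result = result + params[i + 1] * p_next
        p_prev, p_curr = p_curr, p_next
    return result

c = np.linspace(0.0, 1.0, 11)        # e.g. a concentration in [0, 1]
x = 2.0 * c - 1.0                    # mapped onto [-1, 1]
params = [0.5, -1.0, 2.0, 0.25]
values = legendre_series(params, x)
reference = legendre.legval(x, params)
```

The cost is linear in the maximum degree instead of quadratic, at the price of keeping the loop and the accumulation in one function.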