evoxels 0.1.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evoxels/inversion.py CHANGED
@@ -72,7 +72,7 @@ class InversionModel:
72
72
  solver = PseudoSpectralIMEX_dfx(problem.fourier_symbol)
73
73
 
74
74
  solution = dfx.diffeqsolve(
75
- dfx.ODETerm(lambda t, y, args: problem.rhs(y, t)),
75
+ dfx.ODETerm(lambda t, y, args: problem.rhs(t, y)),
76
76
  solver,
77
77
  t0=saveat.subs.ts[0],
78
78
  t1=saveat.subs.ts[-1],
@@ -23,7 +23,7 @@ def run_allen_cahn_solver(
23
23
  plot_bounds = None,
24
24
  ):
25
25
  """
26
- Runs the Cahn-Hilliard solver with a predefined problem and timestepper.
26
+ Solves time-dependent Allen-Cahn problem with ForwardEuler timestepper.
27
27
  """
28
28
  solver = TimeDependentSolver(
29
29
  voxelfields,
@@ -20,7 +20,7 @@ def run_cahn_hilliard_solver(
20
20
  plot_bounds = None,
21
21
  ):
22
22
  """
23
- Runs the Cahn-Hilliard solver with a predefined problem and timestepper.
23
+ Solves time-dependent Cahn-Hilliard problem with PseudoSpectralIMEX timestepper.
24
24
  """
25
25
  solver = TimeDependentSolver(
26
26
  voxelfields,
@@ -13,17 +13,17 @@ _i_ = slice(1, -1) # inner elements [1:-1]
13
13
  class ODE(ABC):
14
14
  @property
15
15
  @abstractmethod
16
- def order(self):
16
+ def order(self) -> int:
17
17
  """Spatial order of convergence for numerical right-hand side."""
18
18
  pass
19
19
 
20
20
  @abstractmethod
21
- def rhs_analytic(self, u, t):
21
+ def rhs_analytic(self, t, u):
22
22
  """Sympy expression of the problem right-hand side.
23
23
 
24
24
  Args:
25
- u : Sympy function of current state.
26
25
  t (float): Current time.
26
+ u : Sympy function of current state.
27
27
 
28
28
  Returns:
29
29
  Sympy function of problem right-hand side.
@@ -31,12 +31,12 @@ class ODE(ABC):
31
31
  pass
32
32
 
33
33
  @abstractmethod
34
- def rhs(self, u, t):
34
+ def rhs(self, t, u):
35
35
  """Numerical right-hand side of the ODE system.
36
36
 
37
37
  Args:
38
- u (array): Current state.
39
38
  t (float): Current time.
39
+ u (array): Current state.
40
40
 
41
41
  Returns:
42
42
  Same type as ``u`` containing the time derivative.
@@ -120,9 +120,11 @@ class ReactionDiffusion(SemiLinearODE):
120
120
  bc_fun = self.vg.bc.pad_zero_flux_periodic
121
121
  self.pad_boundary = lambda field, bc0, bc1: bc_fun(field)
122
122
  k_squared = self.vg.fft_k_squared_nonperiodic()
123
+ else:
124
+ raise ValueError(f"Unsupported BC type: {self.BC_type}")
123
125
 
124
126
  self._fourier_symbol = -self.D * self.A * k_squared
125
-
127
+
126
128
  @property
127
129
  def order(self):
128
130
  return 2
@@ -130,14 +132,14 @@ class ReactionDiffusion(SemiLinearODE):
130
132
  @property
131
133
  def fourier_symbol(self):
132
134
  return self._fourier_symbol
133
-
134
- def _eval_f(self, c, t, lib):
135
+
136
+ def _eval_f(self, t, c, lib):
135
137
  """Evaluate source/forcing term using ``self.f``."""
136
138
  try:
137
- return self.f(c, t, lib)
139
+ return self.f(t, c, lib)
138
140
  except TypeError:
139
- return self.f(c, t)
140
-
141
+ return self.f(t, c)
142
+
141
143
  @property
142
144
  def bc_type(self):
143
145
  return self.BC_type
@@ -145,12 +147,12 @@ class ReactionDiffusion(SemiLinearODE):
145
147
  def pad_bc(self, u):
146
148
  return self.pad_boundary(u, self.bcs[0], self.bcs[1])
147
149
 
148
- def rhs_analytic(self, u, t):
149
- return self.D*spv.laplacian(u) + self._eval_f(u, t, sp)
150
-
151
- def rhs(self, u, t):
150
+ def rhs_analytic(self, t, u):
151
+ return self.D*spv.laplacian(u) + self._eval_f(t, u, sp)
152
+
153
+ def rhs(self, t, u):
152
154
  laplace = self.vg.laplace(self.pad_bc(u))
153
- update = self.D * laplace + self._eval_f(u, t, self.vg.lib)
155
+ update = self.D * laplace + self._eval_f(t, u, self.vg.lib)
154
156
  return update
155
157
 
156
158
  @dataclass
@@ -185,15 +187,15 @@ class ReactionDiffusionSBM(ReactionDiffusion, SmoothedBoundaryODE):
185
187
  def pad_bc(self, u):
186
188
  return self.pad_boundary(u, self.bcs[0], self.bcs[1])
187
189
 
188
- def rhs_analytic(self, mask, u, t):
189
- grad = spv.gradient(u)
190
- norm_grad = sp.sqrt(grad.dot(grad))
190
+ def rhs_analytic(self, t, u, mask):
191
+ grad_m = spv.gradient(mask)
192
+ norm_grad_m = sp.sqrt(grad_m.dot(grad_m))
191
193
 
192
- divergence = spv.divergence(self.D*(grad - u/mask*spv.gradient(mask)))
193
- du = divergence + norm_grad*self.bc_flux + mask*self._eval_f(u/mask, t, sp)
194
+ divergence = spv.divergence(self.D*(spv.gradient(u) - u/mask*grad_m))
195
+ du = divergence + norm_grad_m*self.bc_flux + mask*self._eval_f(t, u/mask, sp)
194
196
  return du
195
197
 
196
- def rhs(self, u, t):
198
+ def rhs(self, t, u):
197
199
  z = self.pad_bc(u)
198
200
  divergence = self.vg.grad_x_face(self.vg.grad_x_face(z) -\
199
201
  self.vg.to_x_face(z/self.mask) * self.vg.grad_x_face(self.mask)
@@ -207,7 +209,7 @@ class ReactionDiffusionSBM(ReactionDiffusion, SmoothedBoundaryODE):
207
209
 
208
210
  update = self.D * divergence + \
209
211
  self.norm*self.bc_flux + \
210
- self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(u/self.mask[:,1:-1,1:-1,1:-1], t, self.vg.lib)
212
+ self.mask[:,1:-1,1:-1,1:-1]*self._eval_f(t, u/self.mask[:,1:-1,1:-1,1:-1], self.vg.lib)
211
213
  return update
212
214
 
213
215
 
@@ -249,13 +251,13 @@ class PeriodicCahnHilliard(SemiLinearODE):
249
251
  except TypeError:
250
252
  return self.mu_hom(c)
251
253
 
252
- def rhs_analytic(self, c, t):
254
+ def rhs_analytic(self, t, c):
253
255
  mu = self._eval_mu(c, sp) - 2*self.eps*spv.laplacian(c)
254
256
  fluxes = self.D*c*(1-c)*spv.gradient(mu)
255
257
  rhs = spv.divergence(fluxes)
256
258
  return rhs
257
259
 
258
- def rhs(self, c, t):
260
+ def rhs(self, t, c):
259
261
  r"""Evaluate :math:`\partial c / \partial t` for the CH equation.
260
262
 
261
263
  Numerical computation of
@@ -341,7 +343,7 @@ class AllenCahnEquation(SemiLinearODE):
341
343
  except TypeError:
342
344
  return self.potential(phi)
343
345
 
344
- def rhs_analytic(self, phi, t):
346
+ def rhs_analytic(self, t, phi):
345
347
  grad = spv.gradient(phi)
346
348
  laplace = spv.laplacian(phi)
347
349
  norm_grad = sp.sqrt(grad.dot(grad))
@@ -354,7 +356,7 @@ class AllenCahnEquation(SemiLinearODE):
354
356
  + 3/self.eps * phi * (1-phi) * self.force
355
357
  return self.M * df_dphi
356
358
 
357
- def rhs(self, phi, t):
359
+ def rhs(self, t, phi):
358
360
  r"""Two-phase Allen-Cahn equation
359
361
 
360
362
  Microstructural evolution of the order parameter ``\phi``
@@ -422,13 +424,13 @@ class CoupledReactionDiffusion(SemiLinearODE):
422
424
  except TypeError:
423
425
  return self.interaction(u)
424
426
 
425
- def rhs_analytic(self, u, t):
427
+ def rhs_analytic(self, t, u):
426
428
  interaction = self._eval_interaction(u, sp)
427
429
  dc_A = self.D_A*spv.laplacian(u[0]) - interaction + self.feed * (1-u[0])
428
430
  dc_B = self.D_B*spv.laplacian(u[1]) + interaction - self.kill * u[1]
429
431
  return (dc_A, dc_B)
430
432
 
431
- def rhs(self, u, t):
433
+ def rhs(self, t, u):
432
434
  r"""Two-component reaction-diffusion system
433
435
 
434
436
  Use batch channels for multiple species:
evoxels/profiler.py CHANGED
@@ -3,12 +3,20 @@ import psutil
3
3
  import os
4
4
  import subprocess
5
5
  import tracemalloc
6
+ import shutil
6
7
  from abc import ABC, abstractmethod
7
8
 
8
9
  class MemoryProfiler(ABC):
9
10
  """Base interface for tracking host and device memory usage."""
11
def __init__(self):
    """Start with zeroed peak counters and GPU polling switched off."""
    # Peak usage in MB, raised by update_memory_stats() (psutil RSS and
    # nvidia-smi readings are both converted/reported in MB).
    self.max_used_cpu = self.max_used_gpu = 0.0
    # Concrete profilers enable this when they run on a GPU device.
    self.track_gpu = False
15
+
10
16
  def get_cuda_memory_from_nvidia_smi(self):
11
17
  """Return currently used CUDA memory in megabytes."""
18
+ if shutil.which("nvidia-smi") is None:
19
+ return None
12
20
  try:
13
21
  output = subprocess.check_output(
14
22
  ['nvidia-smi', '--query-gpu=memory.used',
@@ -23,8 +31,10 @@ class MemoryProfiler(ABC):
23
31
  process = psutil.Process(os.getpid())
24
32
  used_cpu = process.memory_info().rss / 1024**2
25
33
  self.max_used_cpu = np.max((self.max_used_cpu, used_cpu))
26
- used = self.get_cuda_memory_from_nvidia_smi()
27
- self.max_used_gpu = np.max((self.max_used_gpu, used))
34
+ if self.track_gpu:
35
+ used = self.get_cuda_memory_from_nvidia_smi()
36
+ if used is not None:
37
+ self.max_used_gpu = np.max((self.max_used_gpu, used))
28
38
 
29
39
  @abstractmethod
30
40
  def print_memory_stats(self, start: float, end: float, iters: int):
@@ -35,13 +45,14 @@ class TorchMemoryProfiler(MemoryProfiler):
35
45
  def __init__(self, device):
36
46
  """Initialize the profiler for a given torch device."""
37
47
  import torch
48
+ super().__init__()
38
49
  self.torch = torch
39
50
  self.device = device
51
+ self.track_gpu = (device.type == 'cuda')
52
+
40
53
  tracemalloc.start()
41
- if device.type == 'cuda':
54
+ if self.track_gpu:
42
55
  torch.cuda.reset_peak_memory_stats(device=device)
43
- self.max_used_gpu = 0
44
- self.max_used_cpu = 0
45
56
 
46
57
  def print_memory_stats(self, start, end, iters):
47
58
  """Print usage statistics for the Torch backend."""
@@ -60,7 +71,10 @@ class TorchMemoryProfiler(MemoryProfiler):
60
71
  elif self.device.type == 'cuda':
61
72
  self.update_memory_stats()
62
73
  used = self.get_cuda_memory_from_nvidia_smi()
63
- print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
74
+ if used is None:
75
+ print("GPU-RAM (nvidia-smi) unavailable.")
76
+ else:
77
+ print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
64
78
  print(f"GPU-RAM (torch) current: "
65
79
  f"{self.torch.cuda.memory_allocated(self.device) / 1024**2:.2f} MB "
66
80
  f"({self.torch.cuda.max_memory_allocated(self.device) / 1024**2:.2f} MB max, "
@@ -70,9 +84,9 @@ class JAXMemoryProfiler(MemoryProfiler):
70
84
  def __init__(self):
71
85
  """Initialize the profiler for JAX."""
72
86
  import jax
87
+ super().__init__()
73
88
  self.jax = jax
74
- self.max_used_gpu = 0
75
- self.max_used_cpu = 0
89
+ self.track_gpu = any(d.platform == "gpu" for d in jax.devices())
76
90
  tracemalloc.start()
77
91
 
78
92
  def print_memory_stats(self, start, end, iters):
@@ -88,7 +102,10 @@ class JAXMemoryProfiler(MemoryProfiler):
88
102
  current = process.memory_info().rss / 1024**2
89
103
  print(f"CPU-RAM (psutil) current: {current:.2f} MB ({self.max_used_cpu:.2f} MB max)")
90
104
 
91
- if self.jax.default_backend() == 'gpu':
105
+ if self.track_gpu:
92
106
  self.update_memory_stats()
93
107
  used = self.get_cuda_memory_from_nvidia_smi()
94
- print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
108
+ if used is None:
109
+ print("GPU-RAM (nvidia-smi) unavailable.")
110
+ else:
111
+ print(f"GPU-RAM (nvidia-smi) current: {used} MB ({self.max_used_gpu} MB max)")
evoxels/solvers.py CHANGED
@@ -1,13 +1,14 @@
1
1
  from IPython.display import clear_output
2
2
  from dataclasses import dataclass
3
3
  from typing import Callable, Any, Type
4
+ from abc import ABC, abstractmethod
4
5
  from timeit import default_timer as timer
5
6
  import sys
6
7
  from .problem_definition import ODE
7
8
  from .timesteppers import TimeStepper
8
9
 
9
10
  @dataclass
10
- class TimeDependentSolver:
11
+ class BaseSolver(ABC):
11
12
  """Generic wrapper for solving one or more fields with a time stepper."""
12
13
  vf: Any # VoxelFields object
13
14
  fieldnames: str | list[str]
@@ -33,46 +34,24 @@ class TimeDependentSolver:
33
34
  self.profiler = JAXMemoryProfiler()
34
35
  else:
35
36
  raise ValueError(f"Unsupported backend: {self.backend}")
36
-
37
- def solve(
38
- self,
39
- time_increment=0.1,
40
- frames=10,
41
- max_iters=100,
42
- problem_kwargs=None,
43
- jit=True,
44
- verbose=True,
45
- vtk_out=False,
46
- plot_bounds=None,
47
- colormap='viridis'
48
- ):
49
- """Run the time integration loop.
50
-
51
- Args:
52
- time_increment (float): Size of a single time step.
53
- frames (int): Number of output frames (for plotting, vtk, checks).
54
- max_iters (int): Number of time steps to compute.
55
- problem_kwargs (dict | None): Problem-specific input arguments.
56
- jit (bool): Create just-in-time compiled kernel if ``True``
57
- verbose (bool | str): If ``True`` prints memory stats, ``'plot'``
58
- updates an interactive plot.
59
- vtk_out (bool): Write VTK files for each frame if ``True``.
60
- plot_bounds (tuple | None): Optional value range for plots.
61
- """
62
-
63
- problem_kwargs = problem_kwargs or {}
37
+
64
38
  if isinstance(self.fieldnames, str):
65
39
  self.fieldnames = [self.fieldnames]
66
40
  else:
67
41
  self.fieldnames = list(self.fieldnames)
68
42
 
43
+ def _init_fields(self):
44
+ """Initialize fields in the voxel grid."""
69
45
  u_list = [self.vg.init_scalar_field(self.vf.fields[name]) for name in self.fieldnames]
70
46
  u = self.vg.concatenate(u_list, 0)
71
47
  u = self.vg.bc.trim_boundary_nodes(u)
72
-
48
+ return u
49
+
50
+ def _init_stepper(self, time_increment, problem_kwargs, jit):
51
+ problem_kwargs = problem_kwargs or {}
73
52
  if self.step_fn is not None:
74
- step = self.step_fn
75
53
  self.problem = None
54
+ step = self.step_fn
76
55
  else:
77
56
  if self.problem_cls is None or self.timestepper_cls is None:
78
57
  raise ValueError("Either provide step_fn or both problem_cls and timestepper_cls")
@@ -88,23 +67,47 @@ class TimeDependentSolver:
88
67
  import torch
89
68
  step = torch.compile(step)
90
69
 
91
- n_out = max_iters // frames
92
- frame = 0
93
- slice_idx = self.vf.Nz // 2
70
+ return step
71
+
72
@abstractmethod
def _run_loop(self, u, step, time_increment, frames, max_iters,
              vtk_out, verbose, plot_bounds, colormap):
    """Advance the state ``u`` with ``step`` and return the final state.

    ``solve`` times this call and stores the returned state back into the
    voxel fields; concrete solvers decide how iteration proceeds and when
    outputs are produced.
    """
    message = "Subclasses must implement _run_loop method."
    raise NotImplementedError(message)
94
77
 
95
- start = timer()
96
- for i in range(max_iters):
97
- time = i * time_increment
98
- if i % n_out == 0:
99
- self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)
100
- frame += 1
78
+ def solve(
79
+ self,
80
+ time_increment=0.1,
81
+ frames=10,
82
+ max_iters=100,
83
+ problem_kwargs=None,
84
+ jit=True,
85
+ verbose=True,
86
+ vtk_out=False,
87
+ plot_bounds=None,
88
+ colormap='viridis'
89
+ ):
90
+ """Run the time integration loop.
101
91
 
102
- u = step(u, time)
92
+ Args:
93
+ time_increment (float): Size of a single time step.
94
+ frames (int): Number of output frames (for plotting, vtk, checks).
95
+ max_iters (int): Number of time steps to compute.
96
+ problem_kwargs (dict | None): Problem-specific input arguments.
97
+ jit (bool): Create just-in-time compiled kernel if ``True``
98
+ verbose (bool | str): If ``True`` prints memory stats, ``'plot'``
99
+ updates an interactive plot.
100
+ vtk_out (bool): Write VTK files for each frame if ``True``.
101
+ plot_bounds (tuple | None): Optional value range for plots.
102
+ """
103
+ u = self._init_fields()
104
+ step = self._init_stepper(time_increment, problem_kwargs, jit)
103
105
 
106
+ start = timer()
107
+ u = self._run_loop(u, step, time_increment, frames, max_iters,
108
+ vtk_out, verbose, plot_bounds, colormap)
104
109
  end = timer()
105
- time = max_iters * time_increment
106
- self._handle_outputs(u, frame, time, slice_idx, vtk_out, verbose, plot_bounds, colormap)
107
-
110
+ self.computation_time = end - start
108
111
  if verbose:
109
112
  self.profiler.print_memory_stats(start, end, max_iters)
110
113
 
@@ -113,10 +116,10 @@ class TimeDependentSolver:
113
116
  if getattr(self, 'problem', None) is not None:
114
117
  u_out = self.vg.bc.trim_ghost_nodes(self.problem.pad_bc(u))
115
118
  else:
116
- u_out = u
119
+ u_out = self.vg.bc.trim_ghost_nodes(self.vg.pad_zeros(u))
117
120
 
118
121
  for i, name in enumerate(self.fieldnames):
119
- self.vf.fields[name] = self.vg.export_scalar_field_to_numpy(u_out[i:i+1])
122
+ self.vf.set_field(name, self.vg.export_scalar_field_to_numpy(u_out[i:i+1]))
120
123
 
121
124
  if verbose:
122
125
  self.profiler.update_memory_stats()
@@ -129,6 +132,70 @@ class TimeDependentSolver:
129
132
  filename = self.problem_cls.__name__ + "_" +\
130
133
  self.fieldnames[0] + f"_{frame:03d}.vtk"
131
134
  self.vf.export_to_vtk(filename=filename, field_names=self.fieldnames)
135
+
132
136
  if verbose == 'plot':
133
137
  clear_output(wait=True)
134
138
  self.vf.plot_slice(self.fieldnames[0], slice_idx, time=time, colormap=colormap, value_bounds=plot_bounds)
139
+
140
@dataclass
class TimeDependentSolver(BaseSolver):
    """Solver for time-dependent problems.

    Advances the state with the configured step function for a fixed number
    of iterations, producing outputs at regular intervals and once more after
    the final step.
    """

    def _run_loop(self, u, step, time_increment, frames, max_iters,
                  vtk_out, verbose, plot_bounds, colormap):
        """Integrate ``u`` for ``max_iters`` steps of size ``time_increment``.

        Outputs (via ``_handle_outputs``) are produced every
        ``max_iters // frames`` iterations (at least every iteration) and
        once more at ``t = max_iters * time_increment``.

        Returns:
            The state after ``max_iters`` steps.
        """
        # Guard the output interval: with frames > max_iters (or frames <= 0)
        # the integer division would be 0 and ``i % n_out`` would raise
        # ZeroDivisionError.
        n_out = max(1, max_iters // max(1, frames))
        frame = 0
        slice_idx = self.vf.Nz // 2

        for i in range(max_iters):
            time = i * time_increment
            if i % n_out == 0:
                self._handle_outputs(u, frame, time, slice_idx, vtk_out,
                                     verbose, plot_bounds, colormap)
                frame += 1

            u = step(time, u)

        # Final output shows the state after the last step.
        time = max_iters * time_increment
        self._handle_outputs(u, frame, time, slice_idx, vtk_out,
                             verbose, plot_bounds, colormap)
        return u
161
+
162
@dataclass
class SteadyStatePseudoTimeSolver(BaseSolver):
    """Solver for steady-state problems via pseudo-time stepping.

    The state is advanced with the step function until the per-step change
    of every field falls below ``conv_crit`` (checked every ``check_freq``
    iterations) or ``max_iters`` steps have been taken.
    """
    # Threshold on the per-field RMS change per step.
    conv_crit: float = 1e-6
    # Number of iterations between convergence checks.
    check_freq: int = 10

    def _run_loop(self, u, step, time_increment, frames, max_iters,
                  vtk_out, verbose, plot_bounds, colormap):
        """Iterate ``u`` towards a steady state.

        ``frames`` is accepted for interface compatibility but unused: one
        output is produced when iteration stops. Afterwards ``self.converged``
        tells whether the criterion was met and ``self.iter`` holds the number
        of steps taken.

        Returns:
            The last computed state.
        """
        slice_idx = self.vf.Nz // 2
        self.converged = False
        self.iter = 0
        check_freq = max(1, int(self.check_freq))

        while not self.converged and self.iter < max_iters:
            time = self.iter * time_increment
            # Evaluate the step once and reuse it for both the update and the
            # convergence measure.
            u_new = step(time, u)
            diff = u - u_new
            u = u_new
            # Count the step; without this the loop ignores max_iters and
            # never terminates unless it converges.
            self.iter += 1

            if self.iter % check_freq == 0:
                self.converged = self.check_convergence(diff, verbose)

        # Pseudo-time reached after self.iter steps (also defined when
        # max_iters <= 0 and the loop body never ran).
        time = self.iter * time_increment
        self._handle_outputs(u, 0, time, slice_idx, vtk_out,
                             verbose, plot_bounds, colormap)
        return u

    def check_convergence(self, diff, verbose):
        """Return ``True`` if every field's per-step change is below ``conv_crit``.

        For field ``i`` the Frobenius norm of ``diff[i]`` is divided by
        ``sqrt(Nx * Ny * Nz)``, i.e. a root-mean-square change per voxel.
        All fields are evaluated (and reported when ``verbose``) even after
        one fails the criterion.

        Args:
            diff: State before minus state after one step, one field per
                entry along the first axis.
            verbose: Print the per-field change and a convergence message.
        """
        # Plain Python sqrt: torch.sqrt rejects Python ints.
        norm_factor = (self.vf.Nx * self.vf.Ny * self.vf.Nz) ** 0.5
        converged = True
        for i, name in enumerate(self.fieldnames):
            # float() works for 0-d torch tensors and jax arrays alike.
            rel_change = float(self.vg.lib.linalg.norm(diff[i])) / norm_factor
            if rel_change > self.conv_crit:
                converged = False
            if verbose:
                print(f"Iter {self.iter}: Field '{name}' relative change: {rel_change:.2e}")

        if converged and verbose:
            print(f"Converged after {self.iter} iterations.")

        return converged
evoxels/timesteppers.py CHANGED
@@ -16,13 +16,13 @@ class TimeStepper(ABC):
16
16
  pass
17
17
 
18
18
  @abstractmethod
19
- def step(self, u: State, t: float) -> State:
19
+ def step(self, t: float, u: State) -> State:
20
20
  """
21
21
  Take one timestep from t to (t+dt).
22
22
 
23
23
  Args:
24
- u : Current state
25
24
  t : Current time
25
+ u : Current state
26
26
  Returns:
27
27
  Updated state at t + dt.
28
28
  """
@@ -39,8 +39,8 @@ class ForwardEuler(TimeStepper):
39
39
  def order(self) -> int:
40
40
  return 1
41
41
 
42
- def step(self, u: State, t: float) -> State:
43
- return u + self.dt * self.problem.rhs(u, t)
42
+ def step(self, t: float, u: State) -> State:
43
+ return u + self.dt * self.problem.rhs(t, u)
44
44
 
45
45
 
46
46
  @dataclass
@@ -68,8 +68,8 @@ class PseudoSpectralIMEX(TimeStepper):
68
68
  def order(self) -> int:
69
69
  return 1
70
70
 
71
- def step(self, u: State, t: float) -> State:
72
- dc = self.pad(self.problem.rhs(u, t))
71
+ def step(self, t: float, u: State) -> State:
72
+ dc = self.pad(self.problem.rhs(t, u))
73
73
  dc_fft = self._fft_prefac * self.problem.vg.rfftn(dc, dc.shape)
74
74
  update = self.problem.vg.irfftn(dc_fft, dc.shape)[:,:u.shape[1]]
75
75
  return u + update
evoxels/utils.py CHANGED
@@ -3,15 +3,21 @@ import sympy as sp
3
3
  import sympy.vector as spv
4
4
  import evoxels as evo
5
5
  from evoxels.problem_definition import SmoothedBoundaryODE
6
+ from evoxels.solvers import TimeDependentSolver
7
+ import contextlib
8
+ import io
9
+ import matplotlib.pyplot as plt
10
+ from mpl_toolkits.mplot3d import Axes3D # noqa: F401 (needed for 3D projection)
11
+ from matplotlib.patches import Patch
6
12
 
7
13
  ### Generalized test case
8
14
  def rhs_convergence_test(
9
- ODE_class, # an ODE class with callable rhs(field, t)->torch.Tensor (shape [x,y,z])
10
- problem_kwargs, # problem parameters to instantiate ODE
11
- test_function, # exact init_fun(x,y,z)->np.ndarray
12
- mask_function=None,
13
- convention="cell_center",
14
- dtype="float32",
15
+ ODE_class,
16
+ problem_kwargs,
17
+ test_function,
18
+ mask_function = None,
19
+ convention = "cell_center",
20
+ dtype = "float32",
15
21
  powers = np.array([3,4,5,6,7]),
16
22
  backend = "torch"
17
23
  ):
@@ -22,7 +28,7 @@ def rhs_convergence_test(
22
28
  slope arrays have one entry for each provided function.
23
29
 
24
30
  Args:
25
- ODE_class: an ODE class with callable rhs(field, t).
31
+ ODE_class: ODE class with callable rhs(t, u).
26
32
  problem_kwargs: problem-specific parameters to instantiate ODE.
27
33
  test_function: single sympy expression or a list of expressions.
28
34
  mask_function: static mask for smoothed boundary method.
@@ -38,7 +44,9 @@ def rhs_convergence_test(
38
44
  "is not a SmoothedBoundaryODE."
39
45
  )
40
46
  CS = spv.CoordSys3D('CS')
47
+
41
48
  # Prepare lambdified mask if needed
49
+ # Assumed to be static i.e. no function of t
42
50
  mask = (
43
51
  sp.lambdify((CS.x, CS.y, CS.z), mask_function, "numpy")
44
52
  if mask_function is not None
@@ -67,50 +75,51 @@ def rhs_convergence_test(
67
75
  elif convention == 'staggered_x':
68
76
  vf = evo.VoxelFields((2**p + 1, 2**p, 2**p), (1, 1, 1), convention=convention)
69
77
  vf.precision = dtype
70
- grid = vf.meshgrid()
78
+ dx[i] = vf.spacing[0]
79
+
71
80
  if backend == 'torch':
72
81
  vg = evo.voxelgrid.VoxelGridTorch(vf.grid_info(), precision=vf.precision, device='cpu')
73
82
  elif backend == 'jax':
74
83
  vg = evo.voxelgrid.VoxelGridJax(vf.grid_info(), precision=vf.precision)
75
-
84
+
85
+ # Init mask if smoothed boundary ODE
86
+ numpy_grid = vf.meshgrid()
87
+ if mask is not None:
88
+ problem_kwargs["mask"] = mask(*numpy_grid)
89
+
76
90
  # Initialise fields
77
91
  u_list = []
78
92
  for func in test_functions:
79
93
  init_fun = sp.lambdify((CS.x, CS.y, CS.z), func, "numpy")
80
- init_data = init_fun(*grid)
94
+ init_data = init_fun(*numpy_grid)
81
95
  u_list.append(vg.init_scalar_field(init_data))
82
96
 
83
97
  u = vg.concatenate(u_list, 0)
84
98
  u = vg.bc.trim_boundary_nodes(u)
85
99
 
86
- # Init mask if smoothed boundary ODE
87
- if mask is not None:
88
- problem_kwargs["mask"] = mask(*grid)
89
-
90
100
  ODE = ODE_class(vg, **problem_kwargs)
91
- rhs_numeric = ODE.rhs(u, 0)
101
+ rhs_numeric = ODE.rhs(0, u)
92
102
 
93
103
  if n_funcs > 1 and mask is not None:
94
- rhs_analytic = ODE.rhs_analytic(mask_function, test_functions, 0)
104
+ rhs_analytic = ODE.rhs_analytic(0, test_functions, mask_function)
95
105
  elif n_funcs > 1 and mask is None:
96
- rhs_analytic = ODE.rhs_analytic(test_functions, 0)
106
+ rhs_analytic = ODE.rhs_analytic(0, test_functions)
97
107
  elif n_funcs == 1 and mask is not None:
98
- rhs_analytic = [ODE.rhs_analytic(mask_function, test_functions[0], 0)]
108
+ rhs_analytic = [ODE.rhs_analytic(0, test_functions[0], mask_function)]
99
109
  else:
100
- rhs_analytic = [ODE.rhs_analytic(test_functions[0], 0)]
110
+ rhs_analytic = [ODE.rhs_analytic(0, test_functions[0])]
101
111
 
102
112
  # Compute solutions
103
113
  for j, func in enumerate(test_functions):
104
114
  comp = vg.export_scalar_field_to_numpy(rhs_numeric[j:j+1])
105
115
  exact_fun = sp.lambdify((CS.x, CS.y, CS.z), rhs_analytic[j], "numpy")
106
- exact = exact_fun(*grid)
116
+ exact = exact_fun(*numpy_grid)
107
117
  if convention == "staggered_x":
108
118
  exact = exact[1:-1, :, :]
109
119
 
110
120
  # Error norm
111
121
  diff = comp - exact
112
122
  errors[j, i] = np.linalg.norm(diff) / np.linalg.norm(exact)
113
- dx[i] = vf.spacing[0]
114
123
 
115
124
  # Fit slope after loop
116
125
  slopes = np.array(
@@ -121,4 +130,316 @@ def rhs_convergence_test(
121
130
  order = ODE.order
122
131
 
123
132
  return dx, errors if errors.shape[0] > 1 else errors[0], slopes, order
124
-
133
+
134
+
135
def mms_convergence_test(
    ODE_class,
    problem_kwargs,
    test_function,
    mask_function=None,
    timestepper_cls=None,
    convention="cell_center",
    dtype="float32",
    mode='temporal',
    g_powers=np.array([3, 4, 5, 6, 7]),
    t_powers=np.array([3, 4, 5, 6, 7]),
    t_final=1,
    backend="jax",
    device='cpu'
):
    """Evaluate temporal and spatial order of an ODE solution using the
    method of manufactured solutions (MMS).

    The right-hand side of ``ODE_class`` is augmented with a forcing term so
    that ``test_function`` is the exact solution. The problem is integrated
    from ``t = 0`` to ``t = t_final`` for every combination of grid
    refinement (``g_powers``) and time refinement (``t_powers``). The
    relative L2 error against the exact solution at ``t_final`` is recorded.

    ``test_function`` can be a single sympy expression or a list of
    expressions representing multiple variables, written in the coordinates
    of ``CoordSys3D('CS')`` and the symbol ``t``. The returned error and
    slope arrays have one entry for each provided function.

    Args:
        ODE_class: ODE class with callable ``rhs(t, u)``.
        problem_kwargs: problem-specific parameters to instantiate the ODE.
            The caller's dict is not modified (a copy is used).
        test_function: single sympy expression or a list of expressions.
        mask_function: static mask for smoothed boundary method.
        timestepper_cls: timestepper class with callable ``step(t, u)``,
            instantiated as ``timestepper_cls(ODE, dt)``. Required.
        convention: grid convention, ``cell_center`` or ``staggered_x``.
        dtype: float precision (``float32`` or ``float64``).
        mode: ``temporal`` builds the forcing from the numerical rhs at the
            exact solution; ``spatial`` builds it from the analytical rhs.
        g_powers: refine grid in powers of two (i.e. ``Nx = 2**p``).
        t_powers: refine time increment in powers of two
            (i.e. ``dt = t_final / 2**q``).
        t_final: End time for evaluation. Should be order of L^2/D.
        backend: use ``torch`` or ``jax`` for testing.
        device: use ``cpu`` or ``cuda`` for testing in torch.

    Returns:
        dict with keys ``dt``, ``dx``, ``error`` (shape
        ``(n_funcs, len(t_powers), len(g_powers))``, first axis dropped for a
        single function; NaN where the solver aborted), ``t_slopes`` (fit over
        ``dt`` on the first grid of ``g_powers``), ``g_slopes`` (fit over
        ``dx`` with the last entry of ``t_powers``), ``n_funcs``, ``t_order``
        and ``g_order``.

    Raises:
        TypeError: mask given for an ODE that is not a SmoothedBoundaryODE.
        ValueError: missing ``timestepper_cls``, unknown ``mode``,
            ``convention`` or ``backend``, or empty refinement lists.
    """
    # Verify mask_function only used with SmoothedBoundaryODE
    if mask_function is not None and not issubclass(ODE_class, SmoothedBoundaryODE):
        raise TypeError(
            f"Mask function provided but {ODE_class.__name__} "
            "is not a SmoothedBoundaryODE."
        )
    if timestepper_cls is None:
        raise ValueError("timestepper_cls must be provided.")
    if mode not in ('temporal', 'spatial'):
        raise ValueError("Mode must be 'temporal' or 'spatial'.")
    if convention not in ('cell_center', 'staggered_x'):
        raise ValueError("Chosen convention must be cell_center or staggered_x.")
    if backend not in ('torch', 'jax'):
        raise ValueError(f"Unsupported backend: {backend}")
    if len(g_powers) == 0 or len(t_powers) == 0:
        raise ValueError("g_powers and t_powers must be non-empty.")

    # Work on a copy: the mask is inserted below and must not leak into the
    # caller's dict.
    problem_kwargs = dict(problem_kwargs or {})

    CS = spv.CoordSys3D('CS')
    t = sp.symbols('t', real=True)

    # Prepare lambdified mask if needed
    # Assumed to be static i.e. no function of t
    mask = (
        sp.lambdify((CS.x, CS.y, CS.z), mask_function, "numpy")
        if mask_function is not None
        else None
    )

    if isinstance(test_function, (list, tuple)):
        test_functions = list(test_function)
    else:
        test_functions = [test_function]
    n_funcs = len(test_functions)

    # Multiply test functions with mask for SBM testing
    if mask is not None:
        test_functions = [func * mask_function for func in test_functions]

    # Exact solution and its time derivative on the backend grid. Both modes
    # need u_t_list (the spatial forcing uses it too), so build them always.
    u_list = [sp.lambdify((t, CS.x, CS.y, CS.z), sp.N(func), backend)
              for func in test_functions]
    u_t_list = [sp.lambdify((t, CS.x, CS.y, CS.z), sp.N(sp.diff(func, t)), backend)
                for func in test_functions]

    dx = np.zeros(len(g_powers))
    dt = np.zeros(len(t_powers))
    errors = np.zeros((n_funcs, len(t_powers), len(g_powers)))

    for i, p in enumerate(g_powers):
        if convention == 'cell_center':
            shape = (2**p, 2**p, 2**p)
        else:  # staggered_x
            shape = (2**p + 1, 2**p, 2**p)
        vf = evo.VoxelFields(shape, (1, 1, 1), convention=convention)
        vf.precision = dtype
        dx[i] = vf.spacing[0]

        if backend == 'torch':
            vg = evo.voxelgrid.VoxelGridTorch(vf.grid_info(), precision=vf.precision, device=device)
        else:  # jax
            vg = evo.voxelgrid.VoxelGridJax(vf.grid_info(), precision=vf.precision)

        # Init mask if smoothed boundary ODE
        numpy_grid = vf.meshgrid()
        if mask is not None:
            problem_kwargs["mask"] = mask(*numpy_grid)

        ODE = ODE_class(vg, **problem_kwargs)
        rhs_orig = ODE.rhs
        grid = vg.meshgrid()
        if convention == 'staggered_x':
            grid = (grid[0][1:-1, :, :], grid[1][1:-1, :, :], grid[2][1:-1, :, :])

        # Construct new rhs including forcing term from MMS
        if mode == 'temporal':
            def mms_rhs(t, u):
                """Manufactured solution rhs
                with numerical evaluation of rhs in forcing, i.e.
                forcing = du/dt_exact(t,grid) - rhs_num(t, u_exact(t,grid))
                """
                rhs = rhs_orig(t, u)
                t_ = vg.to_backend(t)
                u_ex_list = []
                for j in range(n_funcs):
                    u_ex_list.append(vg.expand_dim(u_list[j](t_, *grid), 0))
                    rhs = vg.set(rhs, j, rhs[j] + u_t_list[j](t_, *grid))
                u_ex = vg.concatenate(u_ex_list, 0)
                rhs = rhs - rhs_orig(t, u_ex)
                return rhs

        else:  # spatial
            # rhs_analytic signature: (t, u) or (t, u, mask) for SBM problems;
            # a single function yields one expression, several yield a tuple.
            args = (t, test_functions if n_funcs > 1 else test_functions[0])
            if mask is not None:
                args += (mask_function,)
            rhs_func = ODE.rhs_analytic(*args)
            if n_funcs == 1:
                rhs_func = [rhs_func]
            rhs_analytic = [sp.lambdify((t, CS.x, CS.y, CS.z), sp.N(expr), backend)
                            for expr in rhs_func]

            def mms_rhs(t, u):
                """Manufactured solution rhs
                with analytical evaluation of rhs in forcing, i.e.
                forcing = du/dt_exact(t,grid) - rhs_exact(t, grid)
                """
                rhs = rhs_orig(t, u)
                t_ = vg.to_backend(t)
                for j in range(n_funcs):
                    rhs = vg.set(rhs, j, rhs[j] - rhs_analytic[j](t_, *grid))
                    rhs = vg.set(rhs, j, rhs[j] + u_t_list[j](t_, *grid))
                return rhs

        # Over-write original rhs with constructed mms_rhs
        ODE.rhs = mms_rhs

        # Loop over time refinements
        for k, q in enumerate(t_powers):
            # Initialise fields with the exact solution at t=0 and t=t_final
            field_names = []
            for j, func in enumerate(test_functions):
                fun = sp.lambdify((t, CS.x, CS.y, CS.z), func, "numpy")
                init_data = fun(0, *numpy_grid)
                final_data = fun(t_final, *numpy_grid)
                vf.add_field(f'u{j}', init_data)
                vf.add_field(f'u{j}_final', final_data)
                field_names.append(f'u{j}')

            # Init time increment and step function
            dt[k] = t_final / 2**q
            timestepper = timestepper_cls(ODE, dt[k])
            step = timestepper.step

            # Init solver
            solver = TimeDependentSolver(
                vf, field_names,
                backend, device=device,
                step_fn=step
            )

            # Wrap solve to capture NaN exit (the solver raises SystemExit)
            nan_hit = False
            buf = io.StringIO()
            with contextlib.redirect_stdout(buf):
                try:
                    solver.solve(dt[k], 8, int(2**q), problem_kwargs, verbose=False)
                except SystemExit:
                    nan_hit = True

            if nan_hit:
                errors[:, k, i] = np.nan
                continue

            # Compute relative L2 error
            for j in range(n_funcs):
                exact = vf.fields[f'u{j}_final']
                diff = vf.fields[f'u{j}'] - exact
                if convention == 'staggered_x':
                    errors[j, k, i] = np.linalg.norm(diff[1:-1, :, :]) / np.linalg.norm(exact[1:-1, :, :])
                else:
                    errors[j, k, i] = np.linalg.norm(diff) / np.linalg.norm(exact)

    # Fit log-log slope, ignoring non-finite (aborted) entries
    def calc_slope(x, y):
        finite = np.isfinite(y)
        if finite.sum() < 2:
            return np.nan
        return np.polyfit(np.log(x[finite]), np.log(y[finite]), 1)[0]

    t_slopes = np.array([calc_slope(dt, err[:, 0]) for err in errors])
    g_slopes = np.array([calc_slope(dx, err[-1, :]) for err in errors])

    results = {
        'dt': dt,
        'dx': dx,
        'error': errors if n_funcs > 1 else errors[0],
        't_slopes': t_slopes if n_funcs > 1 else t_slopes[0],
        'g_slopes': g_slopes if n_funcs > 1 else g_slopes[0],
        'n_funcs': n_funcs,
        't_order': timestepper.order,
        'g_order': ODE.order,
    }
    return results
354
+
355
+
356
def _prepare_error_surface(i, s, log_axes):
    """Validate one `series` entry and return its (optionally log-scaled) plot data.

    Args:
        i (int): Index of the entry in `series` (used in error messages/default name).
        s (dict): Entry with keys 'dt', 'dx', 'error' and optional 'name', 'n_funcs'.
        log_axes (tuple(bool, bool, bool)): (log_x, log_y, log_z) flags.

    Returns:
        tuple: (Xp, Yp, Zp, name, n_funcs) where Xp, Yp have shape (Nx, Ny),
        Zp has shape (n_surfaces, Nx, Ny) and n_funcs is the number of
        surfaces to draw from Zp.

    Raises:
        ValueError: If the entry is malformed or shapes are inconsistent.
    """
    if not isinstance(s, dict) or not all(k in s for k in ('dt', 'dx', 'error')):
        raise ValueError(f"Item {i} must be a dict with keys 'dt', 'dx', 'error' (and optional 'name').")

    log_x, log_y, log_z = log_axes
    x_in = np.asarray(s['dt'])
    y_in = np.asarray(s['dx'])
    Z = np.asarray(s['error'])
    # A single error surface (Nx, Ny) is promoted to a stack of one surface.
    Z = np.expand_dims(Z, axis=0) if Z.ndim == 2 else Z
    name = s.get('name', f'[{i}]')

    # Only 1D dt/dx axes are supported; they are expanded to a 2D grid here.
    if x_in.ndim != 1 or y_in.ndim != 1:
        raise ValueError(f"Item {i}: dt and dx must both be 1D grids.")
    X, Y = np.meshgrid(x_in, y_in, indexing='ij')  # (Nx, Ny)

    if Z.shape[1:] != X.shape:
        raise ValueError(f"Item {i}: z.shape {Z.shape} must match x/y grid shape {X.shape}.")

    # 'n_funcs' is optional: default to every stacked surface in 'error'.
    n_funcs = s.get('n_funcs', Z.shape[0])
    if not 0 <= n_funcs <= Z.shape[0]:
        raise ValueError(
            f"Item {i}: 'n_funcs' ({n_funcs}) must be between 0 and the number "
            f"of error surfaces ({Z.shape[0]})."
        )

    # Apply log scaling
    Xp = np.log10(X) if log_x else X
    Yp = np.log10(Y) if log_y else Y
    if log_z:
        # Nonpositive errors have no logarithm; mask them so they are not drawn.
        Zp = np.log10(np.where(Z > 0, Z, np.nan))
    else:
        Zp = Z
    return Xp, Yp, Zp, name, n_funcs


def plot_error_surface(series, log_axes=(True, True, True), z_max=0, title=None, alpha=0.4):
    """
    Plot one or more 3D surfaces z(x, y) with semi-transparent tiles and solid mesh lines.

    Parameters
    ----------
    series : list or tuple of dict
        Each dict must have:
          - 'dt': 1D array-like of dt-values (length Nx)
          - 'dx': 1D array-like of dx-values (length Ny)
          - 'error': array-like of shape (Nx, Ny) or (n_surfaces, Nx, Ny)
        and may have:
          - 'name': label for the legend (default: '[i]')
          - 'n_funcs': number of leading surfaces of 'error' to draw
            (default: all of them)
    log_axes : tuple(bool, bool, bool)
        (log_x, log_y, log_z): apply log10 to respective axis data when True.
        For Z, nonpositive values are masked to NaN before log10.
    z_max : float or None
        Upper limit of the z axis (in log10 units when log_z is True).
        If None, the upper limit is left to matplotlib's autoscaling.
    title : str or None
        Plot title (default: 'Error Surfaces').
    alpha : float
        Face transparency for surfaces.

    Raises
    ------
    ValueError
        If `series` is empty/not a sequence, or any entry is malformed. All
        entries are validated before a figure is created.
    """
    if not isinstance(series, (list, tuple)) or len(series) == 0:
        raise ValueError("`series` must be a non-empty tuple/list of dictionaries.")

    log_x, log_y, log_z = log_axes

    # Validate and transform every entry before opening a figure, so that a
    # malformed entry cannot leave a half-drawn figure behind.
    prepared = [_prepare_error_surface(i, s, log_axes) for i, s in enumerate(series)]

    # Distinct colors
    base_colors = ['tab:red', 'tab:blue', 'tab:green', 'tab:gray',
                   'tab:purple', 'tab:brown', 'tab:pink', 'tab:orange',
                   'tab:olive', 'tab:cyan']

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')
    legend_patches = []

    count = 0
    for Xp, Yp, Zp, name, n_funcs in prepared:
        for j in range(n_funcs):
            color = base_colors[count % len(base_colors)]
            ax.plot_surface(
                Xp, Yp, Zp[j],
                color=color,      # uniform color per surface
                alpha=alpha,      # semi-transparent tiles
                edgecolor=color,  # solid mesh lines
                linewidth=0.6,
                antialiased=True,
                shade=False
            )
            # The first surface keeps the plain name; later ones get a _u{j} suffix.
            label = f"{name}_u{j}" if j > 0 else name
            legend_patches.append(Patch(facecolor=color, edgecolor=color, alpha=alpha, label=label))
            count += 1

    # Axis labels reflect log choice
    ax.set_xlabel('log10(dt)' if log_x else 'dt')
    ax.set_ylabel('log10(dx)' if log_y else 'dx')
    ax.text2D(0.0, 0.8, 'log10(error)' if log_z else 'error',
              transform=ax.transAxes, va="top", ha="left")
    if z_max is not None:
        ax.set_zlim(top=z_max)
    ax.set_title(title or 'Error Surfaces')
    ax.view_init(elev=25., azim=-145, roll=0)

    ax.legend(handles=legend_patches, loc='best')
    fig.tight_layout()
    plt.show()
evoxels/voxelfields.py CHANGED
@@ -120,30 +120,40 @@ class VoxelFields:
120
120
  grid = Grid(self.shape, self.origin, self.spacing, self.convention)
121
121
  return grid
122
122
 
123
def set_field(self, name: str, array: np.ndarray):
    """
    Set the values of a field in the voxel grid.

    The field is created if ``name`` is not present yet and overwritten
    otherwise. The array is stored as given (no copy is made).

    Args:
        name (str): Name of the field.
        array (numpy.ndarray): 3D array whose shape must equal the voxel grid shape.

    Raises:
        TypeError: If the provided array is not a numpy array.
        ValueError: If the provided array does not match the voxel grid dimensions.
    """
    # Check the type first so that a non-array never reaches the .shape access.
    if not isinstance(array, np.ndarray):
        raise TypeError("The provided array must be a numpy array.")
    if array.shape != self.shape:
        raise ValueError(
            f"The provided array must have the shape {self.shape}."
        )
    self.fields[name] = array

def add_field(self, name: str, array=None):
    """
    Adds a field to the voxel grid.

    An existing field with the same name is overwritten.

    Args:
        name (str): Name of the field.
        array (numpy.ndarray, optional): 3D array to initialize the field.
            If None, initializes with zeros of the voxel grid shape.

    Raises:
        TypeError: If ``array`` is given but is not a numpy array.
        ValueError: If ``array`` does not match the voxel grid dimensions.
    """
    if array is None:
        array = np.zeros(self.shape)
    self.set_field(name, array)
147
157
 
148
158
  def set_voxel_sphere(self, name: str, center, radius, label: int | float = 1):
149
159
  """Create a voxelized representation of a sphere in 3D
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evoxels
3
- Version: 0.1.1
3
+ Version: 1.0.0
4
4
  Summary: Differentiable physics framework for voxel-based microstructure simulations
5
5
  Author-email: Simon Daubner <s.daubner@imperial.ac.uk>
6
6
  License: MIT
@@ -40,6 +40,7 @@ Provides-Extra: notebooks
40
40
  Requires-Dist: ipywidgets; extra == "notebooks"
41
41
  Requires-Dist: ipympl; extra == "notebooks"
42
42
  Requires-Dist: notebook; extra == "notebooks"
43
+ Requires-Dist: taufactor; extra == "notebooks"
43
44
  Dynamic: license-file
44
45
 
45
46
  [![Python package](https://github.com/daubners/evoxels/actions/workflows/python-package.yml/badge.svg?branch=main)](https://github.com/daubners/evoxels/actions/workflows/python-package.yml)
@@ -81,7 +82,7 @@ TL;DR
81
82
  ```bash
82
83
  conda create --name voxenv python=3.12
83
84
  conda activate voxenv
84
- pip install evoxels[torch,jax,dev,notebooks]
85
+ pip install "evoxels[torch,jax,dev,notebooks]"
85
86
  pip install --upgrade "jax[cuda12]"
86
87
  ```
87
88
 
@@ -103,7 +104,7 @@ Navigate to the evoxels folder, then
103
104
  ```
104
105
  pip install -e .[torch] # install with torch backend
105
106
  pip install -e .[jax] # install with jax backend
106
- pip install -e .[dev, notebooks] # install testing and notebooks
107
+ pip install -e .[dev,notebooks] # install testing and notebooks
107
108
  ```
108
109
  Note that the default `[jax]` installation is only CPU compatible. To install the corresponding CUDA libraries check your CUDA version with
109
110
  ```bash
@@ -115,7 +116,7 @@ pip install -U "jax[cuda12]"
115
116
  ```
116
117
  To install both backends within one environment it is important to install torch first and then upgrade the `jax` installation e.g.
117
118
  ```bash
118
- pip install evoxels[torch, jax, dev, notebooks]
119
+ pip install "evoxels[torch,jax,dev,notebooks]"
119
120
  pip install --upgrade "jax[cuda12]"
120
121
  ```
121
122
  To work with the example notebooks install Jupyter and all notebook related dependencies via
@@ -0,0 +1,20 @@
1
+ evoxels/__init__.py,sha256=LLogj7BjzLNi9voNIQkmDYmj_dT82Aw8ph9RAlaPLuU,376
2
+ evoxels/boundary_conditions.py,sha256=IZbPFvPVknWtTyXdDYEQiWCmw1RhwAlCfx0GBpGt7wg,5129
3
+ evoxels/fd_stencils.py,sha256=kpFszQqjSPuWzRPyscR13uaDM9GWiBvmBVopBRosohs,4581
4
+ evoxels/function_approximators.py,sha256=_WwsypBWSigMlyvT_Qgc8rHN7SsyvyFL9WgB1QqPdHY,2545
5
+ evoxels/inversion.py,sha256=9EQcr2-Xe5zAEhbqd-ZbOPRqqfRtrHE8pKW2TFPXdC8,8694
6
+ evoxels/problem_definition.py,sha256=2i0_qwgxX_6Fi0qZR4gB7u_p5OU4IKuTjSYVFVX20MA,14641
7
+ evoxels/profiler.py,sha256=3-kIku5ebOEawKL73LIKZ6HnM0odqhGRrLf54fbY7SY,4544
8
+ evoxels/solvers.py,sha256=s3s_tjvO8YTnYFc_u-2WY6Jyoz47O8tWWkci4k9rLLg,7886
9
+ evoxels/timesteppers.py,sha256=VUvPrnhsHW3H7FwfrGpwvXl0GAWpThf-ULH-tIT44_8,3509
10
+ evoxels/utils.py,sha256=YrUz4JGORrhogjYykojEvUVSfuAkJqfgw9pA3DI1eCQ,17306
11
+ evoxels/voxelfields.py,sha256=UYY2mg5xONkiKwWBji-XIMu_t3isg3ey9J8XKmBYb5A,13573
12
+ evoxels/voxelgrid.py,sha256=r5yoo2J6ogHEOzFbhjikbi4wXQfurmZB4EWf6AEHIK0,9549
13
+ evoxels/precompiled_solvers/__init__.py,sha256=V8oekjyg13ziQ1Gdf0pHHYubZcB6Oec5a8RG0d09Y1c,50
14
+ evoxels/precompiled_solvers/allen_cahn.py,sha256=h44HKG_NNalJBN6_uyBciAthAjUaoQ4rw_JsV1rQCeo,1372
15
+ evoxels/precompiled_solvers/cahn_hilliard.py,sha256=yC-iGEvD8Efdmp6YEUfysl2L1jbfjU62-TR1T05S68E,1166
16
+ evoxels-1.0.0.dist-info/licenses/LICENSE,sha256=2ScNJCT83dGOKEIpmeO0sq7sf6-Rru3SsnDHJ2UFdcg,1065
17
+ evoxels-1.0.0.dist-info/METADATA,sha256=TSASC3rlti5rGIZQYB-kb3N6b-RcoTlW6YQJpJunwMk,7665
18
+ evoxels-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
19
+ evoxels-1.0.0.dist-info/top_level.txt,sha256=g6OihMiKjYgojKrMM8ckpmFVh-ExPN8f4MZPWscCbqo,8
20
+ evoxels-1.0.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,20 +0,0 @@
1
- evoxels/__init__.py,sha256=LLogj7BjzLNi9voNIQkmDYmj_dT82Aw8ph9RAlaPLuU,376
2
- evoxels/boundary_conditions.py,sha256=IZbPFvPVknWtTyXdDYEQiWCmw1RhwAlCfx0GBpGt7wg,5129
3
- evoxels/fd_stencils.py,sha256=kpFszQqjSPuWzRPyscR13uaDM9GWiBvmBVopBRosohs,4581
4
- evoxels/function_approximators.py,sha256=_WwsypBWSigMlyvT_Qgc8rHN7SsyvyFL9WgB1QqPdHY,2545
5
- evoxels/inversion.py,sha256=wKRaJ14eLGBWIeJbAk1pSFP6kFmKDeM9OBvSHndUAz4,8694
6
- evoxels/problem_definition.py,sha256=nfr0M1yH0v16btJoE9_0xlukxjyHx3BmnM1pWyXIsew,14559
7
- evoxels/profiler.py,sha256=2nGEbQBTaP_FUdA7OojmxQe2npLJwgDC-RVF4DNJ51A,3995
8
- evoxels/solvers.py,sha256=tA-3bqHApktZnENMNWZIz-bsgX2QKRzW8FzKrwsvQXE,5172
9
- evoxels/timesteppers.py,sha256=YbaWViKklhZeYKjl4SetWUALS4q6Fm7zCYy6cBesGgw,3509
10
- evoxels/utils.py,sha256=wp-tL2SkFIjSM3B1IwbMGLcVGPO6T8wr4bGitgn7z1k,4686
11
- evoxels/voxelfields.py,sha256=e6DEqv1C7MavOqFx9RhEMD-SktVEO1JDLRv43WNiCTU,13300
12
- evoxels/voxelgrid.py,sha256=r5yoo2J6ogHEOzFbhjikbi4wXQfurmZB4EWf6AEHIK0,9549
13
- evoxels/precompiled_solvers/__init__.py,sha256=V8oekjyg13ziQ1Gdf0pHHYubZcB6Oec5a8RG0d09Y1c,50
14
- evoxels/precompiled_solvers/allen_cahn.py,sha256=ZutF6L3LYqyCC3tIEQwHmQchZSbhZFmcLy-UNaN4yH0,1373
15
- evoxels/precompiled_solvers/cahn_hilliard.py,sha256=Bn2Tbv-kbLwinvVkCr07CT8OrR9im0kvKk0o9ySghFI,1158
16
- evoxels-0.1.1.dist-info/licenses/LICENSE,sha256=2ScNJCT83dGOKEIpmeO0sq7sf6-Rru3SsnDHJ2UFdcg,1065
17
- evoxels-0.1.1.dist-info/METADATA,sha256=S5YiDQxd-JWhMDI1Jbwyx4gWkWFXJnk4fiB2khxmiaw,7618
18
- evoxels-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
19
- evoxels-0.1.1.dist-info/top_level.txt,sha256=g6OihMiKjYgojKrMM8ckpmFVh-ExPN8f4MZPWscCbqo,8
20
- evoxels-0.1.1.dist-info/RECORD,,