fluxfem 0.1.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. fluxfem/__init__.py +343 -0
  2. fluxfem/core/__init__.py +316 -0
  3. fluxfem/core/assembly.py +788 -0
  4. fluxfem/core/basis.py +996 -0
  5. fluxfem/core/data.py +64 -0
  6. fluxfem/core/dtypes.py +4 -0
  7. fluxfem/core/forms.py +234 -0
  8. fluxfem/core/interp.py +55 -0
  9. fluxfem/core/solver.py +113 -0
  10. fluxfem/core/space.py +419 -0
  11. fluxfem/core/weakform.py +818 -0
  12. fluxfem/helpers_num.py +11 -0
  13. fluxfem/helpers_wf.py +42 -0
  14. fluxfem/mesh/__init__.py +29 -0
  15. fluxfem/mesh/base.py +244 -0
  16. fluxfem/mesh/hex.py +327 -0
  17. fluxfem/mesh/io.py +87 -0
  18. fluxfem/mesh/predicate.py +45 -0
  19. fluxfem/mesh/surface.py +257 -0
  20. fluxfem/mesh/tet.py +246 -0
  21. fluxfem/physics/__init__.py +53 -0
  22. fluxfem/physics/diffusion.py +18 -0
  23. fluxfem/physics/elasticity/__init__.py +39 -0
  24. fluxfem/physics/elasticity/hyperelastic.py +99 -0
  25. fluxfem/physics/elasticity/linear.py +58 -0
  26. fluxfem/physics/elasticity/materials.py +32 -0
  27. fluxfem/physics/elasticity/stress.py +46 -0
  28. fluxfem/physics/operators.py +109 -0
  29. fluxfem/physics/postprocess.py +113 -0
  30. fluxfem/solver/__init__.py +47 -0
  31. fluxfem/solver/bc.py +439 -0
  32. fluxfem/solver/cg.py +326 -0
  33. fluxfem/solver/dirichlet.py +126 -0
  34. fluxfem/solver/history.py +31 -0
  35. fluxfem/solver/newton.py +400 -0
  36. fluxfem/solver/result.py +62 -0
  37. fluxfem/solver/solve_runner.py +534 -0
  38. fluxfem/solver/solver.py +148 -0
  39. fluxfem/solver/sparse.py +188 -0
  40. fluxfem/tools/__init__.py +7 -0
  41. fluxfem/tools/jit.py +51 -0
  42. fluxfem/tools/timer.py +659 -0
  43. fluxfem/tools/visualizer.py +101 -0
  44. fluxfem-0.1.1a0.dist-info/METADATA +111 -0
  45. fluxfem-0.1.1a0.dist-info/RECORD +47 -0
  46. fluxfem-0.1.1a0.dist-info/WHEEL +4 -0
  47. fluxfem-0.1.1a0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,18 @@
1
+ import jax.numpy as jnp
2
+
3
+ from ..core.forms import FormContext
4
+
5
+
6
def diffusion_form(ctx: FormContext, kappa: float) -> jnp.ndarray:
    """
    Per-quadrature integrand of the scalar diffusion (Laplacian) bilinear form.

    For every quadrature point q, test function a and trial function b the
    result holds ``kappa * (grad N_a · grad N_b)``.

    Parameters
    ----------
    ctx : FormContext
        Provides ``ctx.test.gradN`` and ``ctx.trial.gradN``, contracted over
        their last (spatial) axis.
    kappa : float
        Diffusion coefficient that scales the whole integrand.

    Returns
    -------
    jnp.ndarray
        ``(n_q, n_test, n_trial)`` integrand.
    """
    # (q, a, d) x (q, b, d) -> (q, a, b): dot product of the two gradients.
    gradient_products = jnp.einsum("qad,qbd->qab", ctx.test.gradN, ctx.trial.gradN)
    return kappa * gradient_products
16
+
17
+
18
# Public API of this module.
__all__ = ["diffusion_form"]
@@ -0,0 +1,39 @@
1
+ """Elasticity-related helpers (linear models, materials, forms)."""
2
+
3
+ from .materials import lame_parameters, isotropic_3d_D
4
+ from .linear import (
5
+ linear_elasticity_form,
6
+ vector_body_force_form,
7
+ constant_body_force_vector_form,
8
+ assemble_constant_body_force,
9
+ )
10
+ from .hyperelastic import (
11
+ right_cauchy_green,
12
+ green_lagrange_strain,
13
+ deformation_gradient,
14
+ pk2_neo_hookean,
15
+ neo_hookean_residual_form,
16
+ make_elastic_point_data,
17
+ write_elastic_vtu,
18
+ )
19
+ from .stress import principal_stresses, principal_sum, max_shear_stress, von_mises_stress
20
+
21
# Public API of the elasticity package: every name below is re-exported from
# the .materials, .linear, .hyperelastic and .stress submodules imported above.
__all__ = [
    "lame_parameters",
    "isotropic_3d_D",
    "linear_elasticity_form",
    "vector_body_force_form",
    "constant_body_force_vector_form",
    "assemble_constant_body_force",
    "right_cauchy_green",
    "green_lagrange_strain",
    "deformation_gradient",
    "pk2_neo_hookean",
    "neo_hookean_residual_form",
    "make_elastic_point_data",
    "write_elastic_vtu",
    "principal_stresses",
    "principal_sum",
    "max_shear_stress",
    "von_mises_stress",
]
@@ -0,0 +1,99 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+
5
+ from ...core.forms import FormContext
6
+ from ...core.basis import build_B_matrices_finite
7
+ from ..postprocess import make_point_data_displacement, write_point_data_vtu
8
+
9
+
10
def right_cauchy_green(F: jnp.ndarray) -> jnp.ndarray:
    """
    Right Cauchy-Green deformation tensor ``C = F^T F``.

    Parameters
    ----------
    F : jnp.ndarray
        Deformation gradient(s) of shape ``(..., d, d)``; any leading axes are
        treated as batch dimensions.

    Returns
    -------
    jnp.ndarray
        ``C`` with the same shape as ``F``, where
        ``C[..., i, j] = sum_k F[..., k, i] * F[..., k, j]``.
    """
    # Contract over the FIRST (row) index of F. The previous subscripts
    # "...ik,...jk->...ij" contracted the second index, which yields the LEFT
    # Cauchy-Green tensor b = F F^T; its inverse differs from C^{-1} for any
    # non-symmetric F and corrupted the PK2 stress built from it.
    return jnp.einsum("...ki,...kj->...ij", F, F)
13
+
14
+
15
def green_lagrange_strain(F: jnp.ndarray) -> jnp.ndarray:
    """
    Green-Lagrange strain ``E = (C - I) / 2``.

    ``C`` is obtained from ``right_cauchy_green(F)``; the identity has the size
    of F's last axis and F's dtype.
    """
    identity = jnp.eye(F.shape[-1], dtype=F.dtype)
    return 0.5 * (right_cauchy_green(F) - identity)
20
+
21
+
22
def deformation_gradient(ctx: FormContext, u_elem: jnp.ndarray) -> jnp.ndarray:
    """
    Deformation gradient ``F = I + grad_X u`` of a 3D vector displacement.

    Parameters
    ----------
    ctx : FormContext
        Reference-configuration form context; ``ctx.trial.gradN`` is used as
        the shape-function gradients ``(n_q, n_nodes, 3)``.
    u_elem : jnp.ndarray
        ``(n_ldofs,)`` element displacement in node-major dof ordering
        ``[u0, v0, w0, u1, v1, w1, ...]``.

    Returns
    -------
    jnp.ndarray
        ``(n_q, 3, 3)`` deformation gradient per quadrature point.
    """
    nodal_disp = u_elem.reshape(-1, 3)  # (n_nodes, 3)
    # disp_grad[q, i, j] = sum_a dN_a/dX_j * u_{a, i}  (i.e. du_i/dX_j)
    disp_grad = jnp.einsum("qaj,ai->qij", ctx.trial.gradN, nodal_disp)
    return disp_grad + jnp.eye(3, dtype=u_elem.dtype)
34
+
35
+
36
def pk2_neo_hookean(F: jnp.ndarray, mu: float, lam: float) -> jnp.ndarray:
    """
    Second Piola-Kirchhoff stress of a compressible Neo-Hookean solid.

        S = mu * (I - C^{-1}) + lam * ln(J) * C^{-1},
        C = right_cauchy_green(F),  J = sqrt(det C)

    ``F`` has shape ``(..., 3, 3)``; the result has the same shape.
    """
    C = right_cauchy_green(F)
    C_inv = jnp.linalg.inv(C)
    log_J = jnp.log(jnp.sqrt(jnp.linalg.det(C)))  # per-batch scalar
    identity = jnp.eye(3, dtype=F.dtype)
    # Broadcast ln(J) over the trailing 3x3 axes.
    return mu * (identity - C_inv) + lam * log_J[..., None, None] * C_inv
47
+
48
+
49
def neo_hookean_residual_form(ctx: FormContext, u_elem: jnp.ndarray, params) -> jnp.ndarray:
    """
    Compressible Neo-Hookean residual integrand (Total Lagrangian, PK2).

    Parameters
    ----------
    ctx : FormContext
        Form context; ``ctx.trial.gradN`` supplies reference-configuration
        shape-function gradients.
    u_elem : jnp.ndarray
        ``(n_ldofs,)`` element displacement, node-major ``[u0, v0, w0, ...]``.
    params : Mapping or 2-sequence
        Either any mapping (``dict`` or other ``collections.abc.Mapping``)
        with keys ``"mu"`` and ``"lam"``, or a sequence ``(mu, lam)``.

    Returns
    -------
    jnp.ndarray
        ``(n_q, n_ldofs)`` per-quadrature integrand ``B^T S_voigt``.
    """
    from collections.abc import Mapping  # local import keeps this edit self-contained

    # Accept any mapping, as the contract says "dict-like". Previously only a
    # real dict was recognised; other mappings fell through to tuple
    # unpacking, which unpacks the mapping's KEYS rather than its values.
    if isinstance(params, Mapping):
        mu = params["mu"]
        lam = params["lam"]
    else:
        mu, lam = params

    F = deformation_gradient(ctx, u_elem)  # (n_q, 3, 3)
    S = pk2_neo_hookean(F, mu, lam)  # (n_q, 3, 3)

    # PK2 in Voigt order [xx, yy, zz, xy, yz, zx]; assumed to match the row
    # order of build_B_matrices_finite (NOTE(review): confirm).
    S_voigt = jnp.stack(
        [
            S[..., 0, 0],
            S[..., 1, 1],
            S[..., 2, 2],
            S[..., 0, 1],
            S[..., 1, 2],
            S[..., 2, 0],
        ],
        axis=-1,
    )  # (n_q, 6)

    B = build_B_matrices_finite(ctx.trial.gradN, F)  # (n_q, 6, n_ldofs)
    BT = jnp.swapaxes(B, 1, 2)  # (n_q, n_ldofs, 6)
    return jnp.einsum("qik,qk->qi", BT, S_voigt)  # (n_q, n_ldofs)
79
+
80
+
81
# Public API of this module. make_elastic_point_data and write_elastic_vtu are
# defined further down; listing them here is fine because __all__ is only
# consulted by `from ... import *` after the whole module has executed.
__all__ = [
    "right_cauchy_green",
    "green_lagrange_strain",
    "deformation_gradient",
    "pk2_neo_hookean",
    "neo_hookean_residual_form",
    "make_elastic_point_data",
    "write_elastic_vtu",
]
90
+
91
+
92
def make_elastic_point_data(mesh, space, u, *, compute_j: bool = True, deformed_scale: float = 1.0):
    """
    Backward-compatible alias of ``postprocess.make_point_data_displacement``.

    All arguments are forwarded unchanged and its return value is passed back.
    """
    options = {"compute_j": compute_j, "deformed_scale": deformed_scale}
    return make_point_data_displacement(mesh, space, u, **options)
95
+
96
+
97
def write_elastic_vtu(mesh, space, u, filepath: str, *, compute_j: bool = True, deformed_scale: float = 1.0):
    """
    Backward-compatible alias of ``postprocess.write_point_data_vtu``.

    All arguments are forwarded unchanged and its return value is passed back.
    """
    options = {"compute_j": compute_j, "deformed_scale": deformed_scale}
    return write_point_data_vtu(mesh, space, u, filepath, **options)
@@ -0,0 +1,58 @@
1
+ import jax.numpy as jnp
2
+
3
+ from ...core.assembly import assemble_linear_form
4
+ from ...core.forms import FormContext, vector_load_form
5
+ from ...core.basis import build_B_matrices
6
+ from ...physics.operators import sym_grad
7
+
8
+ # from ...mechanics.kinematics import build_B_matrices
9
+
10
+
11
+ # def linear_elasticity_form(ctx: FormContext, D: jnp.ndarray) -> jnp.ndarray:
12
+ # """3D linear elasticity bilinear form B^T D B."""
13
+ # grad_v = ctx.test.grad
14
+ # grad_u = ctx.trial.grad
15
+ # B = build_B_matrices(grad_u) # (n_q, 6, 24)
16
+ # BT = jnp.swapaxes(build_B_matrices(grad_v), 1, 2) # (n_q, 24, 6)
17
+ # BDB = jnp.einsum("qik,kl,qlm->qim", BT, D, B) # (n_q, 24, 24)
18
+ # return BDB
19
+
20
+
21
def linear_elasticity_form(ctx: FormContext, D: jnp.ndarray) -> jnp.ndarray:
    """
    Linear-elasticity bilinear-form integrand ``Bv^T D Bu`` in Voigt notation.

    ``Bu``/``Bv`` are the symmetric-gradient operators (``sym_grad``) of the
    trial and test fields, each of shape ``(n_q, 6, ndofs_e)``; ``D`` is the
    ``(6, 6)`` constitutive matrix. Returns ``(n_q, ndofs_e, ndofs_e)``.
    """
    B_trial = sym_grad(ctx.trial)
    B_test_T = jnp.swapaxes(sym_grad(ctx.test), 1, 2)  # (n_q, ndofs_e, 6)
    return jnp.einsum("qai,ij,qjb->qab", B_test_T, D, B_trial)
31
+
32
+
33
def vector_body_force_form(ctx: FormContext, load_vec: jnp.ndarray) -> jnp.ndarray:
    """
    Linear-form integrand for a spatially constant 3D body force ``f``.

    Delegates to ``vector_load_form`` using the test field of ``ctx``.
    """
    test_field = ctx.test
    return vector_load_form(test_field, load_vec)
36
+
37
+
38
def assemble_constant_body_force(space, gravity_vec, density: float, *, sparse: bool = False):
    """
    Assemble the load vector of a uniform body force ``density * g``.

    Parameters
    ----------
    space : function space passed to ``assemble_linear_form``.
    gravity_vec : array-like of length 3
        Gravity vector (direction and magnitude).
    density : float
        Scalar density in the same unit system.
    sparse : bool, keyword-only
        Forwarded unchanged to ``assemble_linear_form``.
    """
    body_force = density * jnp.asarray(gravity_vec)
    return assemble_linear_form(
        space, vector_body_force_form, params=body_force, sparse=sparse
    )
47
+
48
+
49
# Backward compatibility alias: refers to the very same function object as
# vector_body_force_form.
constant_body_force_vector_form = vector_body_force_form


# Public API of this module (the alias is exported alongside the current name).
__all__ = [
    "linear_elasticity_form",
    "vector_body_force_form",
    "constant_body_force_vector_form",
    "assemble_constant_body_force",
]
@@ -0,0 +1,32 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
# Default float dtype chosen from the "jax_enable_x64" flag. NOTE: this is a
# module-level expression, so it is evaluated once when the module is
# imported; enabling x64 afterwards does not change it.
DTYPE = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
5
+
6
+
7
def lame_parameters(E: float, nu: float) -> tuple[float, float]:
    """
    Convert Young's modulus and Poisson ratio to Lamé parameters.

    Returns ``(lambda, mu)`` as Python floats:
    ``lambda = E nu / ((1 + nu)(1 - 2 nu))`` and ``mu = E / (2 (1 + nu))``.
    """
    first_lame = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
    shear_modulus = E / (2.0 * (1.0 + nu))
    return float(first_lame), float(shear_modulus)
12
+
13
+
14
def isotropic_3d_D(E: float, nu: float) -> jnp.ndarray:
    """
    3D isotropic linear-elastic constitutive matrix in Voigt form.

    The upper-left 3x3 block is ``lam + 2 mu`` on the diagonal and ``lam``
    off-diagonal; the lower-right shear block is ``mu * I``, which is the
    convention for engineering shear strains (gamma = 2 eps).

    Parameters
    ----------
    E : float
        Young's modulus.
    nu : float
        Poisson ratio.

    Returns
    -------
    jnp.ndarray
        ``(6, 6)`` matrix, float64 if ``jax_enable_x64`` is enabled at call
        time, otherwise float32.
    """
    lam, mu = lame_parameters(E, nu)
    # Resolve the dtype at call time: the module-level DTYPE is fixed when the
    # module is imported, so enabling jax_enable_x64 afterwards would still
    # produce a float32 D and silently mix precisions with the rest of the model.
    dtype = jnp.float64 if jax.config.read("jax_enable_x64") else jnp.float32
    c11 = lam + 2 * mu

    D = jnp.array(
        [
            [c11, lam, lam, 0.0, 0.0, 0.0],
            [lam, c11, lam, 0.0, 0.0, 0.0],
            [lam, lam, c11, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, mu, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, mu, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, mu],
        ],
        dtype=dtype,
    )
    return D
30
+
31
+
32
# Public API of this module (the module-level DTYPE is not listed).
__all__ = ["lame_parameters", "isotropic_3d_D"]
@@ -0,0 +1,46 @@
1
+ from __future__ import annotations
2
+
3
+ import jax.numpy as jnp
4
+
5
+
6
+ def _sym(A: jnp.ndarray) -> jnp.ndarray:
7
+ return 0.5 * (A + jnp.swapaxes(A, -1, -2))
8
+
9
+
10
def principal_stresses(S: jnp.ndarray) -> jnp.ndarray:
    """
    Principal stresses (eigenvalues, ascending) of a 3x3 stress tensor.

    ``S`` is symmetrised over its last two axes before the symmetric eigen
    solve; leading axes are treated as batch dimensions.
    """
    S_T = jnp.swapaxes(S, -1, -2)
    return jnp.linalg.eigvalsh(0.5 * (S + S_T))
17
+
18
+
19
def principal_sum(S: jnp.ndarray) -> jnp.ndarray:
    """Sum of the principal stresses, i.e. the trace over the last two axes."""
    diagonal = jnp.diagonal(S, axis1=-2, axis2=-1)
    return diagonal.sum(axis=-1)
22
+
23
+
24
def max_shear_stress(S: jnp.ndarray) -> jnp.ndarray:
    """
    Maximum shear stress ``(sigma_max - sigma_min) / 2``.

    The principal stresses come from the symmetrised tensor, sorted ascending
    by ``eigvalsh``, so the extremes are the first and last entries.
    """
    eig = jnp.linalg.eigvalsh(0.5 * (S + jnp.swapaxes(S, -1, -2)))
    sigma_min, sigma_max = eig[..., 0], eig[..., -1]
    return 0.5 * (sigma_max - sigma_min)
30
+
31
+
32
def von_mises_stress(S: jnp.ndarray) -> jnp.ndarray:
    """
    von Mises equivalent stress ``sqrt(3/2 * dev(S) : dev(S))``.

    ``dev(S) = S - (tr(S) / 3) I`` is the deviatoric part. Leading axes of
    ``S`` (shape ``(..., 3, 3)``) are treated as batch dimensions.
    """
    S = jnp.asarray(S)
    identity = jnp.eye(S.shape[-1], dtype=S.dtype)
    mean_stress = jnp.trace(S, axis1=-2, axis2=-1)[..., None, None] / 3.0
    # Remove the hydrostatic part from the DIAGONAL only. The previous
    # `S - tr / 3` subtracted it from every entry (missing identity), so a
    # uniaxial stress sigma gave sqrt(2)*sigma and a pure pressure gave a
    # non-zero von Mises stress.
    dev = S - mean_stress * identity
    return jnp.sqrt(1.5 * jnp.sum(dev * dev, axis=(-2, -1)))
39
+
40
+
41
# Public API of this module (the _sym helper is private).
__all__ = [
    "principal_stresses",
    "principal_sum",
    "max_shear_stress",
    "von_mises_stress",
]
@@ -0,0 +1,109 @@
1
+
2
+
3
+ # fluxfem/physics/operators.py
4
+ from __future__ import annotations
5
+ import jax
6
+ import jax.numpy as jnp
7
+
8
+
9
def dot(a, b):
    """
    Batched matrix product on the last two axes, with a form-field shortcut.

    When ``a`` looks like a FormField (it has an ``N`` attribute and a
    non-None ``value_dim``), the call is dispatched to
    ``core.forms.vector_load_form(a, b)`` to build a vector-load linear-form
    contribution; otherwise ``jnp.matmul(a, b)`` is returned.
    """
    looks_like_form_field = hasattr(a, "N") and getattr(a, "value_dim", None) is not None
    if not looks_like_form_field:
        return jnp.matmul(a, b)
    # Imported lazily; only needed on the form-field path.
    from ..core.forms import vector_load_form
    return vector_load_form(a, b)
20
+
21
+
22
def ddot(a, b, c=None):
    """
    Double contraction on the last two axes, or a Voigt-style triple product.

    - ``ddot(a, b)``    -> ``sum_ij a_ij * b_ij`` (batched over leading axes)
    - ``ddot(a, b, c)`` -> ``a^T b c``, where ``b`` is an unbatched matrix and
      ``a``/``c`` may carry leading batch axes (e.g. ``B^T D B`` blocks).
    """
    if c is not None:
        return jnp.einsum("...ik,kl,...lm->...im", jnp.swapaxes(a, -1, -2), b, c)
    return jnp.einsum("...ij,...ij->...", a, b)
33
+
34
+
35
def transpose_last2(a):
    """Batched transpose: move the last axis in front of the second-to-last."""
    return jnp.moveaxis(a, -1, -2)
38
+
39
+
40
def sym_grad(field) -> jnp.ndarray:
    """
    Symmetric gradient operator for vector mechanics (small strain).

    Parameters
    ----------
    field : FormField-like
        Must provide:
        - field.gradN : (n_q, n_nodes, 3)
        - field.basis.dofs_per_node (usually 3; 3 is assumed if the
          attribute is missing)

    Returns
    -------
    B : jnp.ndarray
        (n_q, 6, dofs_per_node*n_nodes) Voigt order [xx, yy, zz, xy, yz, zx]
        Such that eps_voigt(q,:) = B(q,:,:) @ u_elem

    Notes
    -----
    The shear rows hold ENGINEERING shear strains
    gamma_ij = du_i/dx_j + du_j/dx_i (= 2 * eps_ij); no factor 1/2 is applied.
    Columns use node-major dof numbering (node a owns columns
    dofs_per_node*a ... dofs_per_node*a + 2). Only the first three dofs of each
    node are filled, so dofs_per_node >= 3 is presumably required
    (NOTE(review): confirm for other layouts).
    """
    gradN = field.gradN  # (n_q, n_nodes, 3)
    dofs = getattr(field.basis, "dofs_per_node", 3)
    n_q, n_nodes, _ = gradN.shape
    n_dofs = dofs * n_nodes

    # Build the (6, n_dofs) operator for ONE quadrature point; vmapped below.
    def B_single(dN):
        B = jnp.zeros((6, n_dofs), dtype=dN.dtype)

        # Fill the columns of node `a`; fori_loop threads B through the nodes.
        def node_fun(a, B):
            dNdx, dNdy, dNdz = dN[a, 0], dN[a, 1], dN[a, 2]
            col = dofs * a

            # eps_xx, eps_yy, eps_zz
            B = B.at[0, col + 0].set(dNdx)  # dux/dx
            B = B.at[1, col + 1].set(dNdy)  # duy/dy
            B = B.at[2, col + 2].set(dNdz)  # duz/dz

            # gamma_xy = dux/dy + duy/dx  (engineering shear, = 2*eps_xy)
            B = B.at[3, col + 0].set(dNdy)
            B = B.at[3, col + 1].set(dNdx)

            # gamma_yz = duy/dz + duz/dy  (engineering shear, = 2*eps_yz)
            B = B.at[4, col + 1].set(dNdz)
            B = B.at[4, col + 2].set(dNdy)

            # gamma_zx = duz/dx + dux/dz  (engineering shear, = 2*eps_zx)
            B = B.at[5, col + 0].set(dNdz)
            B = B.at[5, col + 2].set(dNdx)
            return B

        return jax.lax.fori_loop(0, n_nodes, node_fun, B)

    # Map over the leading (quadrature-point) axis of gradN.
    return jax.vmap(B_single)(gradN)
90
+
91
+
92
def sym_grad_u(field, u_elem: jnp.ndarray) -> jnp.ndarray:
    """
    Voigt strain of a local displacement vector: ``sym_grad(field) @ u_elem``.

    Parameters
    ----------
    field : FormField-like
        Vector field basis data accepted by ``sym_grad``.
    u_elem : jnp.ndarray
        Element displacement vector ``(dofs_per_node * n_nodes,)``.

    Returns
    -------
    jnp.ndarray
        ``(n_q, 6)`` strain per quadrature point, in ``sym_grad``'s Voigt order.
    """
    strain_operator = sym_grad(field)  # (n_q, 6, n_dofs)
    return jnp.einsum("qvd,d->qv", strain_operator, u_elem)
@@ -0,0 +1,113 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ # from ..core.assembly import build_form_contexts
8
+ from ..tools.visualizer import write_vtu
9
+ from ..core.interp import interpolate_field_at_element_points
10
+
11
+
12
def _nodal_average(elem_conns, elem_values, n_nodes: int, dofs_per_node: int = 3) -> np.ndarray:
    """Average one scalar per element onto its nodes (each element counts once per node)."""
    # Node id of every local dof under node-major numbering (dof = dofs_per_node * node + comp).
    node_ids = np.sort(np.asarray(elem_conns) // dofs_per_node, axis=1)
    # Keep only the first occurrence of each node within a row, so an element
    # contributes at most once per node (same effect as np.unique per row).
    first = np.ones(node_ids.shape, dtype=bool)
    first[:, 1:] = node_ids[:, 1:] != node_ids[:, :-1]
    values = np.broadcast_to(np.asarray(elem_values, dtype=float)[:, None], node_ids.shape)

    total = np.zeros(n_nodes, dtype=float)
    count = np.zeros(n_nodes, dtype=float)
    # np.add.at is unbuffered: contributions from different elements to the
    # same node all accumulate, in element order.
    np.add.at(total, node_ids[first], values[first])
    np.add.at(count, node_ids[first], 1.0)
    # Nodes touched by no element keep 0; max(..., 1) only avoids 0/0.
    return total / np.maximum(count, 1.0)


def make_point_data_displacement(mesh, space, u, *, compute_j: bool = True, deformed_scale: float = 1.0):
    """
    Build the point-data dictionary for a 3-component displacement field.

    Returned keys:
    - ``"displacement"``: (n_nodes, 3) nodal displacement, float64
    - ``"deformed_coords"``: (n_nodes, 3) ``X + deformed_scale * u``, float64
    - ``"J"`` (only if ``compute_j``): (n_nodes,) nodal average of each
      element's mean ``det(F)`` over its quadrature points, float64

    Parameters
    ----------
    mesh : object
        Provides ``coords`` of shape (n_nodes, 3).
    space : object
        Used only when ``compute_j`` is true (``elem_dofs`` and
        ``build_form_contexts()``).
    u : array-like, shape (3 * n_nodes,)
        Displacement in node-major ordering [u0, v0, w0, u1, v1, w1, ...].
    compute_j : bool, keyword-only
        Whether to compute the nodal Jacobian field ``"J"``.
    deformed_scale : float, keyword-only
        Scale applied to ``u`` for ``"deformed_coords"``.

    Raises
    ------
    ValueError
        If ``u`` does not contain exactly 3 entries per mesh node.
    """
    coords = np.asarray(mesh.coords)
    u_np = np.asarray(u)
    n_nodes = coords.shape[0]
    if u_np.shape[0] != 3 * n_nodes:
        raise ValueError(f"Expected 3 dof/node vector; got {u_np.shape[0]} entries for {n_nodes} nodes")

    disp = u_np.reshape(n_nodes, 3)
    point_data = {
        "displacement": disp.astype(np.float64),
        "deformed_coords": (coords + deformed_scale * disp).astype(np.float64),
    }

    if compute_j:
        elem_conns = np.asarray(space.elem_dofs)
        ctxs = space.build_form_contexts()
        u_arr = jnp.asarray(u)

        from .elasticity.hyperelastic import deformation_gradient  # local import to avoid circular

        def elem_J(ctx, conn):
            # Mean of det(F) over the element's quadrature points.
            F = deformation_gradient(ctx, u_arr[conn])
            return jnp.mean(jnp.linalg.det(F))

        J_elem = np.asarray(jax.vmap(elem_J)(ctxs, elem_conns), dtype=float)
        # Vectorised nodal averaging (replaces a per-element Python loop).
        point_data["J"] = _nodal_average(elem_conns, J_elem, n_nodes).astype(np.float64)

    return point_data
59
+
60
+
61
def write_point_data_vtu(mesh, space, u, filepath: str, *, compute_j: bool = True, deformed_scale: float = 1.0):
    """
    Write a VTU file for ``mesh`` carrying displacement point data.

    The point data comes from ``make_point_data_displacement`` (displacement,
    deformed_coords and, if ``compute_j``, J) and is handed to ``write_vtu``.
    """
    write_vtu(
        mesh,
        filepath,
        point_data=make_point_data_displacement(
            mesh, space, u, compute_j=compute_j, deformed_scale=deformed_scale
        ),
    )
65
+
66
+
67
# Public API of this module; interpolate_at_points is defined just below.
__all__ = ["make_point_data_displacement", "write_point_data_vtu", "interpolate_at_points"]
68
def interpolate_at_points(space, u, points: np.ndarray):
    """
    Interpolate the displacement field at physical points (structured Hex8 grid).

    The grid is inferred from the sorted unique nodal coordinates of
    ``space.mesh.coords``: the origin is the minimum, and the spacing is the gap
    between the first two unique values on each axis. Elements are numbered
    ``k * nx * ny + j * nx + i``. Points outside the box are assigned to
    the nearest boundary element and evaluated at local coordinates outside
    [-1, 1].

    Parameters
    ----------
    space : object
        Provides ``space.mesh`` (with ``coords`` of shape (n_nodes, 3)); it is
        also passed to ``interpolate_field_at_element_points``.
    u : array-like
        Solution vector forwarded to ``interpolate_field_at_element_points``.
    points : array-like, shape (m, 3)
        Physical coordinates.

    Returns
    -------
    np.ndarray, shape (m, 3)
        Interpolated displacement; an empty ``(0, 3)`` array for no points.

    Raises
    ------
    NotImplementedError
        If the mesh has neither an ``origin`` nor an ``lx`` attribute.
    ValueError
        If ``points`` is not (m, 3), or the grid has fewer than two distinct
        node coordinates along some axis.
    """
    pts = np.asarray(points, dtype=float)
    mesh = space.mesh
    # Heuristic guard for StructuredHexBox-backed meshes: raises only when BOTH
    # attributes are missing. NOTE(review): confirm whether `or` was intended.
    if not hasattr(mesh, "origin") and not hasattr(mesh, "lx"):
        raise NotImplementedError("interpolate_at_points currently supports StructuredHexBox meshes.")
    if pts.size == 0:
        # Previously np.stack([]) raised on empty input.
        return np.zeros((0, 3), dtype=float)
    if pts.ndim != 2 or pts.shape[1] != 3:
        raise ValueError(f"points must have shape (m, 3); got {pts.shape}")

    coords = np.asarray(mesh.coords)
    axes = [np.unique(coords[:, d]) for d in range(3)]  # sorted ascending
    if any(len(ax) < 2 for ax in axes):
        raise ValueError("interpolate_at_points needs at least two distinct node coordinates per axis.")
    origin = np.array([ax[0] for ax in axes])
    spacing = np.array([ax[1] - ax[0] for ax in axes])
    n_cells = np.array([len(ax) - 1 for ax in axes])

    # Cell index per axis (floor, then clamp to the grid) for all points at once.
    cell = np.clip(np.floor((pts - origin) / spacing).astype(np.int64), 0, n_cells - 1)
    # Local coordinates in [-1, 1] inside the selected cell.
    local_coords = 2 * ((pts - (origin + cell * spacing)) / spacing) - 1
    nx, ny = n_cells[0], n_cells[1]
    elem_indices = cell[:, 2] * (nx * ny) + cell[:, 1] * nx + cell[:, 0]

    # Expected shape (n_elem, m, 3): every local point evaluated in every
    # element; pick element elem_indices[p] for point p. Convert once instead
    # of indexing a (possibly device-resident) array per point.
    vals_per_elem = np.asarray(interpolate_field_at_element_points(space, u, local_coords))
    vals = np.zeros((pts.shape[0], 3), dtype=float)
    vals[:] = vals_per_elem[elem_indices, np.arange(pts.shape[0])]
    return vals
@@ -0,0 +1,47 @@
1
+ from .sparse import SparsityPattern, FluxSparseMatrix
2
+ from .dirichlet import (
3
+ enforce_dirichlet_dense,
4
+ enforce_dirichlet_sparse,
5
+ free_dofs,
6
+ condense_dirichlet_fluxsparse,
7
+ condense_dirichlet_dense,
8
+ expand_dirichlet_solution,
9
+ )
10
+ from .cg import cg_solve, cg_solve_jax
11
+ from .newton import newton_solve
12
+ from .solve_runner import (
13
+ NonlinearAnalysis,
14
+ NewtonLoopConfig,
15
+ LoadStepResult,
16
+ NewtonSolveRunner,
17
+ solve_nonlinear,
18
+ LinearAnalysis,
19
+ LinearSolveConfig,
20
+ LinearStepResult,
21
+ LinearSolveRunner,
22
+ )
23
+ from .solver import LinearSolver, NonlinearSolver
24
+
25
# Public API of fluxfem.solver: names re-exported from the .sparse, .dirichlet,
# .cg, .newton, .solve_runner and .solver submodules imported above.
__all__ = [
    "SparsityPattern",
    "FluxSparseMatrix",
    "enforce_dirichlet_dense",
    "enforce_dirichlet_sparse",
    "free_dofs",
    "condense_dirichlet_fluxsparse",
    "condense_dirichlet_dense",
    "expand_dirichlet_solution",
    "cg_solve",
    "cg_solve_jax",
    "newton_solve",
    "LinearAnalysis",
    "LinearSolveConfig",
    "LinearStepResult",
    "NonlinearAnalysis",
    "NewtonLoopConfig",
    "LoadStepResult",
    "NewtonSolveRunner",
    "solve_nonlinear",
    "LinearSolver",
    "NonlinearSolver",
]