fluxfem 0.1.4__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. fluxfem-0.2.2/PKG-INFO +330 -0
  2. fluxfem-0.2.2/README.md +303 -0
  3. {fluxfem-0.1.4 → fluxfem-0.2.2}/pyproject.toml +31 -7
  4. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/__init__.py +69 -13
  5. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/core/__init__.py +140 -53
  6. fluxfem-0.2.2/src/fluxfem/core/assembly.py +1382 -0
  7. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/core/basis.py +75 -54
  8. fluxfem-0.2.2/src/fluxfem/core/context_types.py +60 -0
  9. fluxfem-0.2.2/src/fluxfem/core/dtypes.py +12 -0
  10. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/core/forms.py +10 -0
  11. fluxfem-0.2.2/src/fluxfem/core/mixed_assembly.py +263 -0
  12. fluxfem-0.2.2/src/fluxfem/core/mixed_space.py +382 -0
  13. fluxfem-0.2.2/src/fluxfem/core/mixed_weakform.py +97 -0
  14. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/core/solver.py +2 -0
  15. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/core/space.py +315 -30
  16. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/core/weakform.py +821 -42
  17. fluxfem-0.2.2/src/fluxfem/helpers_wf.py +97 -0
  18. fluxfem-0.2.2/src/fluxfem/mesh/__init__.py +81 -0
  19. fluxfem-0.2.2/src/fluxfem/mesh/base.py +558 -0
  20. fluxfem-0.2.2/src/fluxfem/mesh/contact.py +841 -0
  21. fluxfem-0.2.2/src/fluxfem/mesh/dtypes.py +12 -0
  22. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/mesh/hex.py +17 -16
  23. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/mesh/io.py +9 -6
  24. fluxfem-0.2.2/src/fluxfem/mesh/mortar.py +3970 -0
  25. fluxfem-0.2.2/src/fluxfem/mesh/supermesh.py +318 -0
  26. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/mesh/surface.py +104 -26
  27. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/mesh/tet.py +16 -7
  28. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/diffusion.py +3 -0
  29. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/hyperelastic.py +35 -3
  30. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/linear.py +22 -4
  31. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/stress.py +9 -5
  32. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/operators.py +12 -5
  33. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/postprocess.py +29 -3
  34. fluxfem-0.2.2/src/fluxfem/solver/__init__.py +92 -0
  35. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/solver/bc.py +154 -2
  36. fluxfem-0.2.2/src/fluxfem/solver/block_matrix.py +298 -0
  37. fluxfem-0.2.2/src/fluxfem/solver/block_system.py +477 -0
  38. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/solver/cg.py +150 -55
  39. fluxfem-0.2.2/src/fluxfem/solver/dirichlet.py +479 -0
  40. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/solver/history.py +15 -3
  41. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/solver/newton.py +260 -70
  42. fluxfem-0.2.2/src/fluxfem/solver/petsc.py +445 -0
  43. fluxfem-0.2.2/src/fluxfem/solver/preconditioner.py +109 -0
  44. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/solver/result.py +18 -0
  45. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/solver/solve_runner.py +208 -23
  46. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/solver/solver.py +35 -12
  47. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/solver/sparse.py +149 -15
  48. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/tools/jit.py +19 -7
  49. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/tools/timer.py +14 -12
  50. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/tools/visualizer.py +16 -4
  51. fluxfem-0.1.4/PKG-INFO +0 -127
  52. fluxfem-0.1.4/README.md +0 -107
  53. fluxfem-0.1.4/src/fluxfem/core/assembly.py +0 -788
  54. fluxfem-0.1.4/src/fluxfem/core/context_types.py +0 -36
  55. fluxfem-0.1.4/src/fluxfem/core/dtypes.py +0 -4
  56. fluxfem-0.1.4/src/fluxfem/helpers_wf.py +0 -48
  57. fluxfem-0.1.4/src/fluxfem/mesh/__init__.py +0 -29
  58. fluxfem-0.1.4/src/fluxfem/mesh/base.py +0 -249
  59. fluxfem-0.1.4/src/fluxfem/solver/__init__.py +0 -47
  60. fluxfem-0.1.4/src/fluxfem/solver/dirichlet.py +0 -126
  61. {fluxfem-0.1.4 → fluxfem-0.2.2}/LICENSE +0 -0
  62. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/core/data.py +0 -0
  63. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/core/interp.py +0 -0
  64. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/helpers_ts.py +0 -0
  65. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/mesh/predicate.py +0 -0
  66. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/__init__.py +0 -0
  67. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/__init__.py +0 -0
  68. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/materials.py +0 -0
  69. {fluxfem-0.1.4 → fluxfem-0.2.2}/src/fluxfem/tools/__init__.py +0 -0
fluxfem-0.2.2/PKG-INFO ADDED
@@ -0,0 +1,330 @@
1
+ Metadata-Version: 2.1
2
+ Name: fluxfem
3
+ Version: 0.2.2
4
+ Summary: FluxFEM: A weak-form-centric differentiable finite element framework in JAX
5
+ License: Apache-2.0
6
+ Author: Kohei Watanabe
7
+ Author-email: koheitech001@gmail.com
8
+ Requires-Python: >=3.12,<3.14
9
+ Classifier: License :: OSI Approved :: Apache Software License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Provides-Extra: cpu
13
+ Provides-Extra: cuda12
14
+ Provides-Extra: petsc
15
+ Requires-Dist: jax (>=0.8.2,<0.9.0) ; extra == "cpu" or extra == "cuda12"
16
+ Requires-Dist: jax-cuda12-pjrt (==0.8.2) ; extra == "cuda12"
17
+ Requires-Dist: jax-cuda12-plugin (==0.8.2) ; extra == "cuda12"
18
+ Requires-Dist: jaxlib (>=0.8.2,<0.9.0) ; extra == "cpu" or extra == "cuda12"
19
+ Requires-Dist: matplotlib (>=3.10.7,<4.0.0)
20
+ Requires-Dist: meshio (>=5.3.5,<6.0.0)
21
+ Requires-Dist: petsc4py (==3.23.6) ; extra == "petsc"
22
+ Requires-Dist: pyproject (>=1!0.1.2,<1!0.2.0)
23
+ Requires-Dist: pyproject-toml (>=0.1.0,<0.2.0)
24
+ Requires-Dist: pyvista (>=0.46.4,<0.47.0)
25
+ Description-Content-Type: text/markdown
26
+
27
+ [![PyPI version](https://img.shields.io/pypi/v/fluxfem.svg?cacheSeconds=60)](https://pypi.org/project/fluxfem/)
28
+ [![License: Apache-2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
29
+ [![Python Version](https://img.shields.io/pypi/pyversions/fluxfem.svg)](https://pypi.org/project/fluxfem/)
30
+ ![CI](https://github.com/kevin-tofu/fluxfem/actions/workflows/python-tests.yml/badge.svg)
31
+ ![CI](https://github.com/kevin-tofu/fluxfem/actions/workflows/sphinx.yml/badge.svg)
32
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.18055465.svg)](https://doi.org/10.5281/zenodo.18055465)
33
+
34
+
35
+ # FluxFEM
36
+ A weak-form-centric differentiable finite element framework in JAX,
37
+ where variational forms are treated as first-class, differentiable programs.
38
+
39
+ ## Examples and Features
40
+ <table>
41
+ <tr>
42
+ <td align="center"><b>Example 1: Diffusion</b></td>
43
+ <td align="center"><b>Example 2: Neo-Hookean Hyperelasticity</b></td>
44
+ </tr>
45
+ <tr>
46
+ <td align="center">
47
+ <img src="https://media.githubusercontent.com/media/kevin-tofu/fluxfem/main/assets/diffusion_mms_timeseries.gif" alt="Diffusion-mms" width="400">
48
+ </td>
49
+ <td align="center">
50
+ <img src="https://media.githubusercontent.com/media/kevin-tofu/fluxfem/main/assets/Neo-Hookean-deformedx20000.png" alt="Neo-Hookean" width="400">
51
+ </td>
52
+ </tr>
53
+ </table>
54
+
55
+
56
+ ## Features
57
+ - Built on JAX, enabling automatic differentiation with grad, jit, vmap, and related transformations.
58
+ - Weak-form–centric API that keeps formulations close to code; weak forms are represented as expression trees and compiled into element kernels, enabling automatic differentiation of residuals, tangents, and objectives.
59
+ - Two assembly approaches: tensor-based (scikit-fem–style) assembly and weak-form-based assembly.
60
+ - Handles both linear and nonlinear analyses with AD in JAX.
61
+ - Optional PETSc/PETSc-shell solvers via `petsc4py` for scalable linear solves (add `fluxfem[petsc]`).
62
+
63
+ ## Usage
64
+
65
+ This library provides two assembly approaches.
66
+
67
+ - A tensor-based assembly, where trial and test functions are represented explicitly as element-level tensors and assembled accordingly (in the style of scikit-fem).
68
+ - A weak-form-based assembly, where the variational form is written symbolically and compiled before assembly.
69
+
70
+ The two approaches are functionally equivalent and share the same element-level execution model,
71
+ but they differ in how you author the weak form. The example below mirrors the paper's diffusion
72
+ case and makes the distinction explicit with `jnp`.
73
+
74
+
75
+ ## Assembly Flow
76
+ All expressions are first compiled into an element-level evaluation plan,
77
+ which operates on quadrature-point–major tensors.
78
+ This plan is then executed independently for each element during assembly.
79
+
80
+ As a result, both assembly approaches:
81
+ - use the same quadrature-major (q, a, i) data layout,
82
+ - perform element-local tensor contractions,
83
+ - and are fully compatible with JAX transformations such as `jit`, `vmap`, and automatic differentiation.
84
+
85
+ ### kernel-based assembly (explicit JIT units)
86
+ If you want to control JIT boundaries explicitly, build a JIT-compiled element kernel
87
+ and pass it to `space.assemble`. The kernel must return the integrated element
88
+ contribution (not the quadrature integrand). For untagged raw kernels, pass `kind=`.
89
+
90
+ ```Python
91
+ import fluxfem as ff
92
+ import jax
93
+ import jax.numpy as jnp
94
+
95
+ space = ff.make_hex_space(mesh, dim=1, intorder=2)
96
+
97
+ # bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
98
+ ker_K = ff.make_element_bilinear_kernel(ff.diffusion_form, 1.0, jit=True)
99
+ K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)
100
+
101
+ # linear: kernel(ctx) -> (n_ldofs,)
102
+ def linear_kernel(ctx):
103
+ integrand = ff.scalar_body_force_form(ctx, 2.0)
104
+ wJ = ctx.w * ctx.test.detJ
105
+ return (integrand * wJ[:, None]).sum(axis=0)
106
+
107
+ ker_F = jax.jit(linear_kernel)
108
+ F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)
109
+ ```
110
+
111
+ ### tensor-based vs weak-form-based (diffusion example)
112
+
113
+ #### tensor-based assembly
114
+ The tensor-based assembly provides an explicit, low-level formulation with element kernels written using `jax.numpy` (`jnp`).
115
+ ```Python
116
+ import fluxfem as ff
117
+ import jax.numpy as jnp
118
+
119
+ @ff.kernel(kind="bilinear", domain="volume")
120
+ def diffusion_form(ctx: ff.FormContext, kappa):
121
+ # ctx.test.gradN / ctx.trial.gradN: (n_qp, n_nodes, dim)
122
+ # output tensor: (n_qp, n_nodes, n_nodes)
123
+ return kappa * jnp.einsum("qia,qja->qij", ctx.test.gradN, ctx.trial.gradN)
124
+
125
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
126
+ params = ff.Params(kappa=1.0)
127
+ K_ts = space.assemble(diffusion_form, params=params.kappa)
128
+ ```
129
+
130
+ #### weak-form-based assembly
131
+ In the weak-form-based assembly, the variational formulation itself is the primary object.
132
+ The expression below defines a symbolic computation graph, which is later compiled and executed at the element level.
133
+
134
+ ```Python
135
+ import fluxfem as ff
136
+ import fluxfem.helpers_wf as h_wf
137
+
138
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
139
+ params = ff.Params(kappa=1.0)
140
+
141
+ # u, v are symbolic trial/test fields (weak-form DSL objects).
142
+ # u.grad / v.grad are symbolic nodes (expression tree), not numeric arrays.
143
+ # dOmega() is the integral measure; the whole expression is compiled before assembly.
144
+ form_wf = ff.BilinearForm.volume(
145
+ lambda u, v, p: p.kappa * (v.grad @ u.grad) * h_wf.dOmega()
146
+ ).get_compiled()
147
+
148
+ K_wf = space.assemble(form_wf, params=params)
149
+ ```
150
+
151
+ ### Linear Elasticity assembly (weak-form based assembly)
152
+
153
+ ```Python
154
+ import fluxfem as ff
155
+ import fluxfem.helpers_wf as h_wf
156
+
157
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
158
+ D = ff.isotropic_3d_D(1.0, 0.3)
159
+
160
+ form_wf = ff.BilinearForm.volume(
161
+ lambda u, v, D: h_wf.ddot(v.sym_grad, D @ u.sym_grad) * h_wf.dOmega()
162
+ ).get_compiled()
163
+
164
+ K = space.assemble(form_wf, params=D)
165
+ ```
166
+
167
+ ### Neo-Hookean residual assembly (weak-form DSL)
168
+ Below is a Neo-Hookean hyperelasticity example written in weak form.
169
+ The residual is expressed symbolically and compiled into element-level kernels executed per element.
170
+ No manual derivation of tangent operators is required; consistent tangents (Jacobians) for Newton-type solvers are obtained automatically via JAX AD.
171
+
172
+ ```Python
173
+ def neo_hookean_residual_wf(v, u, params):
174
+ mu = params["mu"]
175
+ lam = params["lam"]
176
+ F = h_wf.I(3) + h_wf.grad(u) # deformation gradient
177
+ C = h_wf.matmul(h_wf.transpose(F), F)
178
+ C_inv = h_wf.inv(C)
179
+ J = h_wf.det(F)
180
+
181
+ S = mu * (h_wf.I(3) - C_inv) + lam * h_wf.log(J) * C_inv
182
+ dE = 0.5 * (h_wf.matmul(h_wf.grad(v), F) + h_wf.transpose(h_wf.matmul(h_wf.grad(v), F)))
183
+ return h_wf.ddot(S, dE) * h_wf.dOmega()
184
+
185
+ res_form = ff.ResidualForm.volume(neo_hookean_residual_wf).get_compiled()
186
+ ```
187
+
188
+
189
+ ### autodiff + jit compile
190
+
191
+ You can differentiate through the solve and JIT compile the hot path.
192
+ The inverse diffusion tutorial shows this pattern:
193
+
194
+ ```Python
195
+ def loss_theta(theta):
196
+ kappa = jnp.exp(theta)
197
+ u = solve_u_jit(kappa, traction_true)
198
+ diff = u[obs_idx_j] - u_obs[obs_idx_j]
199
+ return 0.5 * jnp.mean(diff * diff)
200
+
201
+ solve_u_jit = jax.jit(solve_u)
202
+ loss_theta_jit = jax.jit(loss_theta)
203
+ grad_fn = jax.jit(jax.grad(loss_theta))
204
+ ```
205
+
206
+ ### FESpace vs FESpacePytree
207
+
208
+ Use `FESpace` for standard workflows with a fixed mesh. When you need to carry
209
+ the space through JAX transformations (e.g., shape optimization where mesh
210
+ coordinates are part of the computation), use `FESpacePytree` via
211
+ `make_*_space_pytree(...)`. This keeps the mesh/basis in the pytree so
212
+ `jax.jit`/`jax.grad` can see geometry changes.
213
+
214
+ ### Mixed systems
215
+
216
+ Mixed problems can be assembled from residual blocks and solved as a coupled system.
217
+
218
+ ```Python
219
+ import fluxfem as ff
220
+ import jax.numpy as jnp
221
+
222
+ mixed = ff.MixedFESpace({"u": space_u, "p": space_p})
223
+ residuals = ff.make_mixed_residuals(
224
+ u=res_u, # (v, u, params) -> Expr
225
+ p=res_p, # (q, u, params) -> Expr
226
+ )
227
+ problem = ff.MixedProblem(mixed, residuals, params=ff.Params(alpha=1.0))
228
+
229
+ u0 = jnp.zeros(mixed.n_dofs)
230
+ R = problem.assemble_residual(u0)
231
+ J = problem.assemble_jacobian(u0, return_flux_matrix=True)
232
+ ```
233
+
234
+ ### Block assembly
235
+
236
+ For constraints like contact problems (e.g., adding Lagrange multipliers), build
237
+ a block matrix explicitly:
238
+
239
+ ```Python
240
+ from fluxfem import solver as ff_solver
241
+
242
+ # Example blocks from contact coupling
243
+ K_uu = ...
244
+ K_cc = ...
245
+ K_uc = ...
246
+
247
+ blocks = ff_solver.make_block_matrix(
248
+ diag=ff_solver.block_diag(order=("u", "c"), u=K_uu, c=K_cc),
249
+ rel={("u", "c"): K_uc},
250
+ symmetric=True,
251
+ transpose_rule="T",
252
+ )
253
+
254
+ # Lazy container; assemble when you need the global matrix.
255
+ K = blocks.assemble()
256
+ ```
257
+
258
+ FluxFEM also provides contact utilities like `ContactSurfaceSpace` to build constraint contributions.
259
+
260
+
261
+ ## Documentation
262
+
263
+ Full documentation, tutorials, and API reference are hosted at [this site](https://fluxfem.readthedocs.io/en/latest/).
264
+
265
+ ## Tutorials
266
+
267
+ - `tutorials/linearelastic_tensile_bar.py` (linear elasticity, weak-form assembly)
268
+ - `tutorials/neo_hookean_cantilever.py` (nonlinear hyperelasticity)
269
+ - `tutorials/thermoelastic_bar_1d.py` / `tutorials/thermoelastic_bar_1d_mixed.py` (thermoelastic coupling)
270
+ - `tutorials/petsc_shell_poisson_demo.py` (PETSc shell solver integration; see also `tutorials/petsc_shell_poisson_pmat_demo.py`)
271
+
272
+ ## Setup
273
+
274
+ You can install **FluxFEM** either via **pip** or **Poetry**.
275
+
276
+ #### Supported Python Versions
277
+
278
+ FluxFEM supports **Python 3.12–3.13**.
279
+
280
+
281
+ **Choose one of the following methods:**
282
+
283
+ ### Using pip
284
+ ```bash
285
+ pip install fluxfem
286
+ pip install "fluxfem[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
287
+ ```
288
+
289
+ ### Using poetry
290
+ ```bash
291
+ poetry add fluxfem
292
+ poetry add "fluxfem[cuda12]"
293
+ ```
294
+
295
+ ## PETSc Integration
296
+
297
+ Optional PETSc-based solvers are available via `petsc4py`. Enable with the extra:
298
+
299
+
300
+ ```bash
301
+ pip install "fluxfem[petsc]"
302
+ # or
303
+ pip install "fluxfem[petsc,cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
304
+ ```
305
+
306
+ ```bash
307
+ poetry add fluxfem --extras "petsc"
308
+ # or
309
+ poetry add "fluxfem[petsc,cuda12]"
310
+ # or
311
+ poetry add fluxfem --extras "petsc" --extras "cuda12"
312
+ ```
313
+
314
+ Note: newer `petsc4py` expects PETSc builds that include the `PetscRegressor`
315
+ API. If your PETSc build does not have it, `petsc4py` will fail to compile. In
316
+ that case, rebuild PETSc with regressor support or pin `petsc4py` to a version
317
+ compatible with your PETSc build.
318
+
319
+ Important: you must match the `petsc4py` version to the PETSc version you have
320
+ installed. The current FluxFEM extra pins `petsc4py==3.23.6` (see
321
+ `[project.optional-dependencies]`), so make sure your PETSc install is
322
+ compatible with that `petsc4py` release, or override it to match your PETSc
323
+ build.
324
+
325
+ GPU note: this repo currently tests CUDA via the `cuda12` extra only. Other CUDA
326
+ versions are not covered by CI and may require manual JAX installation.
327
+
328
+ ## Acknowledgements
329
+ I acknowledge the open-source software, libraries, and communities that made this work possible.
330
+
@@ -0,0 +1,303 @@
1
+ [![PyPI version](https://img.shields.io/pypi/v/fluxfem.svg?cacheSeconds=60)](https://pypi.org/project/fluxfem/)
2
+ [![License: Apache-2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
3
+ [![Python Version](https://img.shields.io/pypi/pyversions/fluxfem.svg)](https://pypi.org/project/fluxfem/)
4
+ ![CI](https://github.com/kevin-tofu/fluxfem/actions/workflows/python-tests.yml/badge.svg)
5
+ ![CI](https://github.com/kevin-tofu/fluxfem/actions/workflows/sphinx.yml/badge.svg)
6
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.18055465.svg)](https://doi.org/10.5281/zenodo.18055465)
7
+
8
+
9
+ # FluxFEM
10
+ A weak-form-centric differentiable finite element framework in JAX,
11
+ where variational forms are treated as first-class, differentiable programs.
12
+
13
+ ## Examples and Features
14
+ <table>
15
+ <tr>
16
+ <td align="center"><b>Example 1: Diffusion</b></td>
17
+ <td align="center"><b>Example 2: Neo-Hookean Hyperelasticity</b></td>
18
+ </tr>
19
+ <tr>
20
+ <td align="center">
21
+ <img src="https://media.githubusercontent.com/media/kevin-tofu/fluxfem/main/assets/diffusion_mms_timeseries.gif" alt="Diffusion-mms" width="400">
22
+ </td>
23
+ <td align="center">
24
+ <img src="https://media.githubusercontent.com/media/kevin-tofu/fluxfem/main/assets/Neo-Hookean-deformedx20000.png" alt="Neo-Hookean" width="400">
25
+ </td>
26
+ </tr>
27
+ </table>
28
+
29
+
30
+ ## Features
31
+ - Built on JAX, enabling automatic differentiation with grad, jit, vmap, and related transformations.
32
+ - Weak-form–centric API that keeps formulations close to code; weak forms are represented as expression trees and compiled into element kernels, enabling automatic differentiation of residuals, tangents, and objectives.
33
+ - Two assembly approaches: tensor-based (scikit-fem–style) assembly and weak-form-based assembly.
34
+ - Handles both linear and nonlinear analyses with AD in JAX.
35
+ - Optional PETSc/PETSc-shell solvers via `petsc4py` for scalable linear solves (add `fluxfem[petsc]`).
36
+
37
+ ## Usage
38
+
39
+ This library provides two assembly approaches.
40
+
41
+ - A tensor-based assembly, where trial and test functions are represented explicitly as element-level tensors and assembled accordingly (in the style of scikit-fem).
42
+ - A weak-form-based assembly, where the variational form is written symbolically and compiled before assembly.
43
+
44
+ The two approaches are functionally equivalent and share the same element-level execution model,
45
+ but they differ in how you author the weak form. The example below mirrors the paper's diffusion
46
+ case and makes the distinction explicit with `jnp`.
47
+
48
+
49
+ ## Assembly Flow
50
+ All expressions are first compiled into an element-level evaluation plan,
51
+ which operates on quadrature-point–major tensors.
52
+ This plan is then executed independently for each element during assembly.
53
+
54
+ As a result, both assembly approaches:
55
+ - use the same quadrature-major (q, a, i) data layout,
56
+ - perform element-local tensor contractions,
57
+ - and are fully compatible with JAX transformations such as `jit`, `vmap`, and automatic differentiation.
58
+
59
+ ### kernel-based assembly (explicit JIT units)
60
+ If you want to control JIT boundaries explicitly, build a JIT-compiled element kernel
61
+ and pass it to `space.assemble`. The kernel must return the integrated element
62
+ contribution (not the quadrature integrand). For untagged raw kernels, pass `kind=`.
63
+
64
+ ```Python
65
+ import fluxfem as ff
66
+ import jax
67
+ import jax.numpy as jnp
68
+
69
+ space = ff.make_hex_space(mesh, dim=1, intorder=2)
70
+
71
+ # bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
72
+ ker_K = ff.make_element_bilinear_kernel(ff.diffusion_form, 1.0, jit=True)
73
+ K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)
74
+
75
+ # linear: kernel(ctx) -> (n_ldofs,)
76
+ def linear_kernel(ctx):
77
+ integrand = ff.scalar_body_force_form(ctx, 2.0)
78
+ wJ = ctx.w * ctx.test.detJ
79
+ return (integrand * wJ[:, None]).sum(axis=0)
80
+
81
+ ker_F = jax.jit(linear_kernel)
82
+ F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)
83
+ ```
84
+
85
+ ### tensor-based vs weak-form-based (diffusion example)
86
+
87
+ #### tensor-based assembly
88
+ The tensor-based assembly provides an explicit, low-level formulation with element kernels written using `jax.numpy` (`jnp`).
89
+ ```Python
90
+ import fluxfem as ff
91
+ import jax.numpy as jnp
92
+
93
+ @ff.kernel(kind="bilinear", domain="volume")
94
+ def diffusion_form(ctx: ff.FormContext, kappa):
95
+ # ctx.test.gradN / ctx.trial.gradN: (n_qp, n_nodes, dim)
96
+ # output tensor: (n_qp, n_nodes, n_nodes)
97
+ return kappa * jnp.einsum("qia,qja->qij", ctx.test.gradN, ctx.trial.gradN)
98
+
99
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
100
+ params = ff.Params(kappa=1.0)
101
+ K_ts = space.assemble(diffusion_form, params=params.kappa)
102
+ ```
103
+
104
+ #### weak-form-based assembly
105
+ In the weak-form-based assembly, the variational formulation itself is the primary object.
106
+ The expression below defines a symbolic computation graph, which is later compiled and executed at the element level.
107
+
108
+ ```Python
109
+ import fluxfem as ff
110
+ import fluxfem.helpers_wf as h_wf
111
+
112
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
113
+ params = ff.Params(kappa=1.0)
114
+
115
+ # u, v are symbolic trial/test fields (weak-form DSL objects).
116
+ # u.grad / v.grad are symbolic nodes (expression tree), not numeric arrays.
117
+ # dOmega() is the integral measure; the whole expression is compiled before assembly.
118
+ form_wf = ff.BilinearForm.volume(
119
+ lambda u, v, p: p.kappa * (v.grad @ u.grad) * h_wf.dOmega()
120
+ ).get_compiled()
121
+
122
+ K_wf = space.assemble(form_wf, params=params)
123
+ ```
124
+
125
+ ### Linear Elasticity assembly (weak-form based assembly)
126
+
127
+ ```Python
128
+ import fluxfem as ff
129
+ import fluxfem.helpers_wf as h_wf
130
+
131
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
132
+ D = ff.isotropic_3d_D(1.0, 0.3)
133
+
134
+ form_wf = ff.BilinearForm.volume(
135
+ lambda u, v, D: h_wf.ddot(v.sym_grad, D @ u.sym_grad) * h_wf.dOmega()
136
+ ).get_compiled()
137
+
138
+ K = space.assemble(form_wf, params=D)
139
+ ```
140
+
141
+ ### Neo-Hookean residual assembly (weak-form DSL)
142
+ Below is a Neo-Hookean hyperelasticity example written in weak form.
143
+ The residual is expressed symbolically and compiled into element-level kernels executed per element.
144
+ No manual derivation of tangent operators is required; consistent tangents (Jacobians) for Newton-type solvers are obtained automatically via JAX AD.
145
+
146
+ ```Python
147
+ def neo_hookean_residual_wf(v, u, params):
148
+ mu = params["mu"]
149
+ lam = params["lam"]
150
+ F = h_wf.I(3) + h_wf.grad(u) # deformation gradient
151
+ C = h_wf.matmul(h_wf.transpose(F), F)
152
+ C_inv = h_wf.inv(C)
153
+ J = h_wf.det(F)
154
+
155
+ S = mu * (h_wf.I(3) - C_inv) + lam * h_wf.log(J) * C_inv
156
+ dE = 0.5 * (h_wf.matmul(h_wf.grad(v), F) + h_wf.transpose(h_wf.matmul(h_wf.grad(v), F)))
157
+ return h_wf.ddot(S, dE) * h_wf.dOmega()
158
+
159
+ res_form = ff.ResidualForm.volume(neo_hookean_residual_wf).get_compiled()
160
+ ```
161
+
162
+
163
+ ### autodiff + jit compile
164
+
165
+ You can differentiate through the solve and JIT compile the hot path.
166
+ The inverse diffusion tutorial shows this pattern:
167
+
168
+ ```Python
169
+ def loss_theta(theta):
170
+ kappa = jnp.exp(theta)
171
+ u = solve_u_jit(kappa, traction_true)
172
+ diff = u[obs_idx_j] - u_obs[obs_idx_j]
173
+ return 0.5 * jnp.mean(diff * diff)
174
+
175
+ solve_u_jit = jax.jit(solve_u)
176
+ loss_theta_jit = jax.jit(loss_theta)
177
+ grad_fn = jax.jit(jax.grad(loss_theta))
178
+ ```
179
+
180
+ ### FESpace vs FESpacePytree
181
+
182
+ Use `FESpace` for standard workflows with a fixed mesh. When you need to carry
183
+ the space through JAX transformations (e.g., shape optimization where mesh
184
+ coordinates are part of the computation), use `FESpacePytree` via
185
+ `make_*_space_pytree(...)`. This keeps the mesh/basis in the pytree so
186
+ `jax.jit`/`jax.grad` can see geometry changes.
187
+
188
+ ### Mixed systems
189
+
190
+ Mixed problems can be assembled from residual blocks and solved as a coupled system.
191
+
192
+ ```Python
193
+ import fluxfem as ff
194
+ import jax.numpy as jnp
195
+
196
+ mixed = ff.MixedFESpace({"u": space_u, "p": space_p})
197
+ residuals = ff.make_mixed_residuals(
198
+ u=res_u, # (v, u, params) -> Expr
199
+ p=res_p, # (q, u, params) -> Expr
200
+ )
201
+ problem = ff.MixedProblem(mixed, residuals, params=ff.Params(alpha=1.0))
202
+
203
+ u0 = jnp.zeros(mixed.n_dofs)
204
+ R = problem.assemble_residual(u0)
205
+ J = problem.assemble_jacobian(u0, return_flux_matrix=True)
206
+ ```
207
+
208
+ ### Block assembly
209
+
210
+ For constraints like contact problems (e.g., adding Lagrange multipliers), build
211
+ a block matrix explicitly:
212
+
213
+ ```Python
214
+ from fluxfem import solver as ff_solver
215
+
216
+ # Example blocks from contact coupling
217
+ K_uu = ...
218
+ K_cc = ...
219
+ K_uc = ...
220
+
221
+ blocks = ff_solver.make_block_matrix(
222
+ diag=ff_solver.block_diag(order=("u", "c"), u=K_uu, c=K_cc),
223
+ rel={("u", "c"): K_uc},
224
+ symmetric=True,
225
+ transpose_rule="T",
226
+ )
227
+
228
+ # Lazy container; assemble when you need the global matrix.
229
+ K = blocks.assemble()
230
+ ```
231
+
232
+ FluxFEM also provides contact utilities like `ContactSurfaceSpace` to build constraint contributions.
233
+
234
+
235
+ ## Documentation
236
+
237
+ Full documentation, tutorials, and API reference are hosted at [this site](https://fluxfem.readthedocs.io/en/latest/).
238
+
239
+ ## Tutorials
240
+
241
+ - `tutorials/linearelastic_tensile_bar.py` (linear elasticity, weak-form assembly)
242
+ - `tutorials/neo_hookean_cantilever.py` (nonlinear hyperelasticity)
243
+ - `tutorials/thermoelastic_bar_1d.py` / `tutorials/thermoelastic_bar_1d_mixed.py` (thermoelastic coupling)
244
+ - `tutorials/petsc_shell_poisson_demo.py` (PETSc shell solver integration; see also `tutorials/petsc_shell_poisson_pmat_demo.py`)
245
+
246
+ ## Setup
247
+
248
+ You can install **FluxFEM** either via **pip** or **Poetry**.
249
+
250
+ #### Supported Python Versions
251
+
252
+ FluxFEM supports **Python 3.12–3.13**.
253
+
254
+
255
+ **Choose one of the following methods:**
256
+
257
+ ### Using pip
258
+ ```bash
259
+ pip install fluxfem
260
+ pip install "fluxfem[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
261
+ ```
262
+
263
+ ### Using poetry
264
+ ```bash
265
+ poetry add fluxfem
266
+ poetry add "fluxfem[cuda12]"
267
+ ```
268
+
269
+ ## PETSc Integration
270
+
271
+ Optional PETSc-based solvers are available via `petsc4py`. Enable with the extra:
272
+
273
+
274
+ ```bash
275
+ pip install "fluxfem[petsc]"
276
+ # or
277
+ pip install "fluxfem[petsc,cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
278
+ ```
279
+
280
+ ```bash
281
+ poetry add fluxfem --extras "petsc"
282
+ # or
283
+ poetry add "fluxfem[petsc,cuda12]"
284
+ # or
285
+ poetry add fluxfem --extras "petsc" --extras "cuda12"
286
+ ```
287
+
288
+ Note: newer `petsc4py` expects PETSc builds that include the `PetscRegressor`
289
+ API. If your PETSc build does not have it, `petsc4py` will fail to compile. In
290
+ that case, rebuild PETSc with regressor support or pin `petsc4py` to a version
291
+ compatible with your PETSc build.
292
+
293
+ Important: you must match the `petsc4py` version to the PETSc version you have
294
+ installed. The current FluxFEM extra pins `petsc4py==3.23.6` (see
295
+ `[project.optional-dependencies]`), so make sure your PETSc install is
296
+ compatible with that `petsc4py` release, or override it to match your PETSc
297
+ build.
298
+
299
+ GPU note: this repo currently tests CUDA via the `cuda12` extra only. Other CUDA
300
+ versions are not covered by CI and may require manual JAX installation.
301
+
302
+ ## Acknowledgements
303
+ I acknowledge the open-source software, libraries, and communities that made this work possible.