fluxfem 0.1.4__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. fluxfem-0.2.1/PKG-INFO +314 -0
  2. fluxfem-0.2.1/README.md +287 -0
  3. {fluxfem-0.1.4 → fluxfem-0.2.1}/pyproject.toml +31 -7
  4. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/__init__.py +69 -13
  5. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/core/__init__.py +140 -53
  6. fluxfem-0.2.1/src/fluxfem/core/assembly.py +1382 -0
  7. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/core/basis.py +75 -54
  8. fluxfem-0.2.1/src/fluxfem/core/context_types.py +60 -0
  9. fluxfem-0.2.1/src/fluxfem/core/dtypes.py +12 -0
  10. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/core/forms.py +10 -0
  11. fluxfem-0.2.1/src/fluxfem/core/mixed_assembly.py +263 -0
  12. fluxfem-0.2.1/src/fluxfem/core/mixed_space.py +382 -0
  13. fluxfem-0.2.1/src/fluxfem/core/mixed_weakform.py +97 -0
  14. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/core/solver.py +2 -0
  15. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/core/space.py +315 -30
  16. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/core/weakform.py +821 -42
  17. fluxfem-0.2.1/src/fluxfem/helpers_wf.py +97 -0
  18. fluxfem-0.2.1/src/fluxfem/mesh/__init__.py +81 -0
  19. fluxfem-0.2.1/src/fluxfem/mesh/base.py +558 -0
  20. fluxfem-0.2.1/src/fluxfem/mesh/contact.py +841 -0
  21. fluxfem-0.2.1/src/fluxfem/mesh/dtypes.py +12 -0
  22. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/mesh/hex.py +17 -16
  23. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/mesh/io.py +9 -6
  24. fluxfem-0.2.1/src/fluxfem/mesh/mortar.py +3970 -0
  25. fluxfem-0.2.1/src/fluxfem/mesh/supermesh.py +318 -0
  26. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/mesh/surface.py +104 -26
  27. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/mesh/tet.py +16 -7
  28. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/diffusion.py +3 -0
  29. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/elasticity/hyperelastic.py +35 -3
  30. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/elasticity/linear.py +22 -4
  31. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/elasticity/stress.py +9 -5
  32. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/operators.py +12 -5
  33. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/postprocess.py +29 -3
  34. fluxfem-0.2.1/src/fluxfem/solver/__init__.py +92 -0
  35. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/solver/bc.py +38 -2
  36. fluxfem-0.2.1/src/fluxfem/solver/block_matrix.py +284 -0
  37. fluxfem-0.2.1/src/fluxfem/solver/block_system.py +477 -0
  38. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/solver/cg.py +150 -55
  39. fluxfem-0.2.1/src/fluxfem/solver/dirichlet.py +479 -0
  40. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/solver/history.py +15 -3
  41. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/solver/newton.py +260 -70
  42. fluxfem-0.2.1/src/fluxfem/solver/petsc.py +445 -0
  43. fluxfem-0.2.1/src/fluxfem/solver/preconditioner.py +109 -0
  44. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/solver/result.py +18 -0
  45. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/solver/solve_runner.py +208 -23
  46. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/solver/solver.py +35 -12
  47. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/solver/sparse.py +149 -15
  48. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/tools/jit.py +19 -7
  49. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/tools/timer.py +14 -12
  50. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/tools/visualizer.py +16 -4
  51. fluxfem-0.1.4/PKG-INFO +0 -127
  52. fluxfem-0.1.4/README.md +0 -107
  53. fluxfem-0.1.4/src/fluxfem/core/assembly.py +0 -788
  54. fluxfem-0.1.4/src/fluxfem/core/context_types.py +0 -36
  55. fluxfem-0.1.4/src/fluxfem/core/dtypes.py +0 -4
  56. fluxfem-0.1.4/src/fluxfem/helpers_wf.py +0 -48
  57. fluxfem-0.1.4/src/fluxfem/mesh/__init__.py +0 -29
  58. fluxfem-0.1.4/src/fluxfem/mesh/base.py +0 -249
  59. fluxfem-0.1.4/src/fluxfem/solver/__init__.py +0 -47
  60. fluxfem-0.1.4/src/fluxfem/solver/dirichlet.py +0 -126
  61. {fluxfem-0.1.4 → fluxfem-0.2.1}/LICENSE +0 -0
  62. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/core/data.py +0 -0
  63. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/core/interp.py +0 -0
  64. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/helpers_ts.py +0 -0
  65. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/mesh/predicate.py +0 -0
  66. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/__init__.py +0 -0
  67. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/elasticity/__init__.py +0 -0
  68. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/physics/elasticity/materials.py +0 -0
  69. {fluxfem-0.1.4 → fluxfem-0.2.1}/src/fluxfem/tools/__init__.py +0 -0
fluxfem-0.2.1/PKG-INFO ADDED
@@ -0,0 +1,314 @@
1
+ Metadata-Version: 2.1
2
+ Name: fluxfem
3
+ Version: 0.2.1
4
+ Summary: FluxFEM: A weak-form-centric differentiable finite element framework in JAX
5
+ License: Apache-2.0
6
+ Author: Kohei Watanabe
7
+ Author-email: koheitech001@gmail.com
8
+ Requires-Python: >=3.12,<3.14
9
+ Classifier: License :: OSI Approved :: Apache Software License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Provides-Extra: cpu
13
+ Provides-Extra: cuda12
14
+ Provides-Extra: petsc
15
+ Requires-Dist: jax (>=0.8.2,<0.9.0) ; extra == "cpu" or extra == "cuda12"
16
+ Requires-Dist: jax-cuda12-pjrt (==0.8.2) ; extra == "cuda12"
17
+ Requires-Dist: jax-cuda12-plugin (==0.8.2) ; extra == "cuda12"
18
+ Requires-Dist: jaxlib (>=0.8.2,<0.9.0) ; extra == "cpu" or extra == "cuda12"
19
+ Requires-Dist: matplotlib (>=3.10.7,<4.0.0)
20
+ Requires-Dist: meshio (>=5.3.5,<6.0.0)
21
+ Requires-Dist: petsc4py (==3.23.6) ; extra == "petsc"
22
+ Requires-Dist: pyproject (>=1!0.1.2,<1!0.2.0)
23
+ Requires-Dist: pyproject-toml (>=0.1.0,<0.2.0)
24
+ Requires-Dist: pyvista (>=0.46.4,<0.47.0)
25
+ Description-Content-Type: text/markdown
26
+
27
+ [![PyPI version](https://img.shields.io/pypi/v/fluxfem.svg?cacheSeconds=60)](https://pypi.org/project/fluxfem/)
28
+ [![License: Apache-2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
29
+ [![Python Version](https://img.shields.io/pypi/pyversions/fluxfem.svg)](https://pypi.org/project/fluxfem/)
30
+ ![CI](https://github.com/kevin-tofu/fluxfem/actions/workflows/python-tests.yml/badge.svg)
31
+ ![CI](https://github.com/kevin-tofu/fluxfem/actions/workflows/sphinx.yml/badge.svg)
32
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.18055465.svg)](https://doi.org/10.5281/zenodo.18055465)
33
+
34
+
35
+ # FluxFEM
36
+ A weak-form-centric differentiable finite element framework in JAX,
37
+ where variational forms are treated as first-class, differentiable programs.
38
+
39
+ ## Examples and Features
40
+ <table>
41
+ <tr>
42
+ <td align="center"><b>Example 1: Diffusion</b></td>
43
+ <td align="center"><b>Example 2: Neo-Hookean Hyperelasticity</b></td>
44
+ </tr>
45
+ <tr>
46
+ <td align="center">
47
+ <img src="https://media.githubusercontent.com/media/kevin-tofu/fluxfem/main/assets/diffusion_mms_timeseries.gif" alt="Diffusion-mms" width="400">
48
+ </td>
49
+ <td align="center">
50
+ <img src="https://media.githubusercontent.com/media/kevin-tofu/fluxfem/main/assets/Neo-Hookean-deformedx20000.png" alt="Neo-Hookean" width="400">
51
+ </td>
52
+ </tr>
53
+ </table>
54
+
55
+
56
+ ## Features
57
+ - Built on JAX, enabling automatic differentiation with grad, jit, vmap, and related transformations.
58
+ - Weak-form–centric API that keeps formulations close to code; weak forms are represented as expression trees and compiled into element kernels, enabling automatic differentiation of residuals, tangents, and objectives.
59
+ - Two assembly approaches: tensor-based (scikit-fem–style) assembly and weak-form-based assembly.
60
+ - Handles both linear and nonlinear analyses with AD in JAX.
61
+ - Optional PETSc/PETSc-shell solvers via `petsc4py` for scalable linear solves (add `fluxfem[petsc]`).
62
+
63
+ ## Usage
64
+
65
+ This library provides two assembly approaches.
66
+
67
+ - A tensor-based assembly, where trial and test functions are represented explicitly as element-level tensors and assembled accordingly (in the style of scikit-fem).
68
+ - A weak-form-based assembly, where the variational form is written symbolically and compiled before assembly.
69
+
70
+ The two approaches are functionally equivalent and share the same element-level execution model,
71
+ but they differ in how you author the weak form. The example below mirrors the paper's diffusion
72
+ case and makes the distinction explicit with `jnp`.
73
+
74
+
75
+ ## Assembly Flow
76
+ All expressions are first compiled into an element-level evaluation plan,
77
+ which operates on quadrature-point–major tensors.
78
+ This plan is then executed independently for each element during assembly.
79
+
80
+ As a result, both assembly approaches:
81
+ - use the same quadrature-major (q, a, i) data layout,
82
+ - perform element-local tensor contractions,
83
+ - and are fully compatible with JAX transformations such as `jit`, `vmap`, and automatic differentiation.
84
+
85
+ ### kernel-based assembly (explicit JIT units)
86
+ If you want to control JIT boundaries explicitly, build a JIT-compiled element kernel
87
+ and pass it to `space.assemble`. The kernel must return the integrated element
88
+ contribution (not the quadrature integrand). For untagged raw kernels, pass `kind=`.
89
+
90
+ ```Python
91
+ import fluxfem as ff
92
+ import jax
93
+ import jax.numpy as jnp
94
+
95
+ space = ff.make_hex_space(mesh, dim=1, intorder=2)
96
+
97
+ # bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
98
+ ker_K = ff.make_element_bilinear_kernel(ff.diffusion_form, 1.0, jit=True)
99
+ K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)
100
+
101
+ # linear: kernel(ctx) -> (n_ldofs,)
102
+ def linear_kernel(ctx):
103
+ integrand = ff.scalar_body_force_form(ctx, 2.0)
104
+ wJ = ctx.w * ctx.test.detJ
105
+ return (integrand * wJ[:, None]).sum(axis=0)
106
+
107
+ ker_F = jax.jit(linear_kernel)
108
+ F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)
109
+ ```
110
+
111
+ ### tensor-based vs weak-form-based (diffusion example)
112
+
113
+ #### tensor-based assembly
114
+ The tensor-based assembly provides an explicit, low-level formulation with element kernels written using `jax.numpy` (`jnp`).
115
+ ```Python
116
+ import fluxfem as ff
117
+ import jax.numpy as jnp
118
+
119
+ @ff.kernel(kind="bilinear", domain="volume")
120
+ def diffusion_form(ctx: ff.FormContext, kappa):
121
+ # ctx.test.gradN / ctx.trial.gradN: (n_qp, n_nodes, dim)
122
+ # output tensor: (n_qp, n_nodes, n_nodes)
123
+ return kappa * jnp.einsum("qia,qja->qij", ctx.test.gradN, ctx.trial.gradN)
124
+
125
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
126
+ params = ff.Params(kappa=1.0)
127
+ K_ts = space.assemble(diffusion_form, params=params.kappa)
128
+ ```
129
+
130
+ #### weak-form-based assembly
131
+ In the weak-form-based assembly, the variational formulation itself is the primary object.
132
+ The expression below defines a symbolic computation graph, which is later compiled and executed at the element level.
133
+
134
+ ```Python
135
+ import fluxfem as ff
136
+ import fluxfem.helpers_wf as h_wf
137
+
138
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
139
+ params = ff.Params(kappa=1.0)
140
+
141
+ # u, v are symbolic trial/test fields (weak-form DSL objects).
142
+ # u.grad / v.grad are symbolic nodes (expression tree), not numeric arrays.
143
+ # dOmega() is the integral measure; the whole expression is compiled before assembly.
144
+ form_wf = ff.BilinearForm.volume(
145
+ lambda u, v, p: p.kappa * (v.grad @ u.grad) * h_wf.dOmega()
146
+ ).get_compiled()
147
+
148
+ K_wf = space.assemble(form_wf, params=params)
149
+ ```
150
+
151
+ ### Linear Elasticity assembly (weak-form based assembly)
152
+
153
+ ```Python
154
+ import fluxfem as ff
155
+ import fluxfem.helpers_wf as h_wf
156
+
157
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
158
+ D = ff.isotropic_3d_D(1.0, 0.3)
159
+
160
+ form_wf = ff.BilinearForm.volume(
161
+ lambda u, v, D: h_wf.ddot(v.sym_grad, D @ u.sym_grad) * h_wf.dOmega()
162
+ ).get_compiled()
163
+
164
+ K = space.assemble(form_wf, params=D)
165
+ ```
166
+
167
+ ### Neo-Hookean residual assembly (weak-form DSL)
168
+ Below is a Neo-Hookean hyperelasticity example written in weak form.
169
+ The residual is expressed symbolically and compiled into element-level kernels executed per element.
170
+ No manual derivation of tangent operators is required; consistent tangents (Jacobians) for Newton-type solvers are obtained automatically via JAX AD.
171
+
172
+ ```Python
173
+ def neo_hookean_residual_wf(v, u, params):
174
+ mu = params["mu"]
175
+ lam = params["lam"]
176
+ F = h_wf.I(3) + h_wf.grad(u) # deformation gradient
177
+ C = h_wf.matmul(h_wf.transpose(F), F)
178
+ C_inv = h_wf.inv(C)
179
+ J = h_wf.det(F)
180
+
181
+ S = mu * (h_wf.I(3) - C_inv) + lam * h_wf.log(J) * C_inv
182
+ dE = 0.5 * (h_wf.matmul(h_wf.grad(v), F) + h_wf.transpose(h_wf.matmul(h_wf.grad(v), F)))
183
+ return h_wf.ddot(S, dE) * h_wf.dOmega()
184
+
185
+ res_form = ff.ResidualForm.volume(neo_hookean_residual_wf).get_compiled()
186
+ ```
187
+
188
+
189
+ ### autodiff + jit compile
190
+
191
+ You can differentiate through the solve and JIT compile the hot path.
192
+ The inverse diffusion tutorial shows this pattern:
193
+
194
+ ```Python
195
+ def loss_theta(theta):
196
+ kappa = jnp.exp(theta)
197
+ u = solve_u_jit(kappa, traction_true)
198
+ diff = u[obs_idx_j] - u_obs[obs_idx_j]
199
+ return 0.5 * jnp.mean(diff * diff)
200
+
201
+ solve_u_jit = jax.jit(solve_u)
202
+ loss_theta_jit = jax.jit(loss_theta)
203
+ grad_fn = jax.jit(jax.grad(loss_theta))
204
+ ```
205
+
206
+ ### FESpace vs FESpacePytree
207
+
208
+ Use `FESpace` for standard workflows with a fixed mesh. When you need to carry
209
+ the space through JAX transformations (e.g., shape optimization where mesh
210
+ coordinates are part of the computation), use `FESpacePytree` via
211
+ `make_*_space_pytree(...)`. This keeps the mesh/basis in the pytree so
212
+ `jax.jit`/`jax.grad` can see geometry changes.
213
+
214
+ ### Mixed systems
215
+
216
+ Mixed problems can be assembled from residual blocks and solved as a coupled system.
217
+
218
+ ```Python
219
+ import fluxfem as ff
220
+ import jax.numpy as jnp
221
+
222
+ mixed = ff.MixedFESpace({"u": space_u, "p": space_p})
223
+ residuals = ff.make_mixed_residuals(
224
+ u=res_u, # (v, u, params) -> Expr
225
+ p=res_p, # (q, u, params) -> Expr
226
+ )
227
+ problem = ff.MixedProblem(mixed, residuals, params=ff.Params(alpha=1.0))
228
+
229
+ u0 = jnp.zeros(mixed.n_dofs)
230
+ R = problem.assemble_residual(u0)
231
+ J = problem.assemble_jacobian(u0, return_flux_matrix=True)
232
+ ```
233
+
234
+ ### Block assembly
235
+
236
+ For constraints like contact problems (e.g., adding Lagrange multipliers), build
237
+ a block matrix explicitly:
238
+
239
+ ```Python
240
+ from fluxfem import solver as ff_solver
241
+
242
+ # Example blocks from contact coupling
243
+ K_uu = ...
244
+ K_cc = ...
245
+ K_uc = ...
246
+
247
+ blocks = ff_solver.make_block_matrix(
248
+ diag=ff_solver.block_diag(u=K_uu, c=K_cc),
249
+ rel={("u", "c"): K_uc},
250
+ sizes={"u": K_uu.shape[0], "c": K_cc.shape[0]},
251
+ symmetric=True,
252
+ transpose_rule="T",
253
+ )
254
+
255
+ # Lazy container; assemble when you need the global matrix.
256
+ K = blocks.assemble()
257
+ ```
258
+
259
+ FluxFEM also provides contact utilities like `ContactSurfaceSpace` to build constraint contributions.
260
+
261
+
262
+ ## Documentation
263
+
264
+ Full documentation, tutorials, and API reference are hosted at [this site](https://fluxfem.readthedocs.io/en/latest/).
265
+
266
+ ## Tutorials
267
+
268
+ - `tutorials/linearelastic_tensile_bar.py` (linear elasticity, weak-form assembly)
269
+ - `tutorials/neo_hookean_cantilever.py` (nonlinear hyperelasticity)
270
+ - `tutorials/thermoelastic_bar_1d.py` / `tutorials/thermoelastic_bar_1d_mixed.py` (thermoelastic coupling)
271
+ - `tutorials/petsc_shell_poisson_demo.py` (PETSc shell solver integration; see also `tutorials/petsc_shell_poisson_pmat_demo.py`)
272
+
273
+ ## Setup
274
+
275
+ You can install **FluxFEM** either via **pip** or **Poetry**.
276
+
277
+ #### Supported Python Versions
278
+
279
+ FluxFEM supports **Python 3.12–3.13**.
280
+
281
+
282
+ **Choose one of the following methods:**
283
+
284
+ ### Using pip
285
+ ```bash
286
+ pip install fluxfem
287
+ pip install "fluxfem[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
288
+ ```
289
+
290
+ ### Using poetry
291
+ ```bash
292
+ poetry add fluxfem
293
+ poetry add fluxfem[cuda12]
294
+ ```
295
+
296
+ ## PETSc Integration
297
+
298
+ Optional PETSc-based solvers are available via `petsc4py`. Enable with the extra:
299
+
300
+ ```bash
301
+ poetry add fluxfem --extras "petsc"
302
+ ```
303
+
304
+ Note: newer `petsc4py` expects PETSc builds that include the `PetscRegressor`
305
+ API. If your PETSc build does not have it, `petsc4py` will fail to compile. In
306
+ that case, rebuild PETSc with regressor support or pin `petsc4py` to a version
307
+ compatible with your PETSc build.
308
+
309
+ GPU note: this repo currently tests CUDA via the `cuda12` extra only. Other CUDA
310
+ versions are not covered by CI and may require manual JAX installation.
311
+
312
+ ## Acknowledgements
313
+ I acknowledge the open-source software, libraries, and communities that made this work possible.
314
+
@@ -0,0 +1,287 @@
1
+ [![PyPI version](https://img.shields.io/pypi/v/fluxfem.svg?cacheSeconds=60)](https://pypi.org/project/fluxfem/)
2
+ [![License: Apache-2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
3
+ [![Python Version](https://img.shields.io/pypi/pyversions/fluxfem.svg)](https://pypi.org/project/fluxfem/)
4
+ ![CI](https://github.com/kevin-tofu/fluxfem/actions/workflows/python-tests.yml/badge.svg)
5
+ ![CI](https://github.com/kevin-tofu/fluxfem/actions/workflows/sphinx.yml/badge.svg)
6
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.18055465.svg)](https://doi.org/10.5281/zenodo.18055465)
7
+
8
+
9
+ # FluxFEM
10
+ A weak-form-centric differentiable finite element framework in JAX,
11
+ where variational forms are treated as first-class, differentiable programs.
12
+
13
+ ## Examples and Features
14
+ <table>
15
+ <tr>
16
+ <td align="center"><b>Example 1: Diffusion</b></td>
17
+ <td align="center"><b>Example 2: Neo-Hookean Hyperelasticity</b></td>
18
+ </tr>
19
+ <tr>
20
+ <td align="center">
21
+ <img src="https://media.githubusercontent.com/media/kevin-tofu/fluxfem/main/assets/diffusion_mms_timeseries.gif" alt="Diffusion-mms" width="400">
22
+ </td>
23
+ <td align="center">
24
+ <img src="https://media.githubusercontent.com/media/kevin-tofu/fluxfem/main/assets/Neo-Hookean-deformedx20000.png" alt="Neo-Hookean" width="400">
25
+ </td>
26
+ </tr>
27
+ </table>
28
+
29
+
30
+ ## Features
31
+ - Built on JAX, enabling automatic differentiation with grad, jit, vmap, and related transformations.
32
+ - Weak-form–centric API that keeps formulations close to code; weak forms are represented as expression trees and compiled into element kernels, enabling automatic differentiation of residuals, tangents, and objectives.
33
+ - Two assembly approaches: tensor-based (scikit-fem–style) assembly and weak-form-based assembly.
34
+ - Handles both linear and nonlinear analyses with AD in JAX.
35
+ - Optional PETSc/PETSc-shell solvers via `petsc4py` for scalable linear solves (add `fluxfem[petsc]`).
36
+
37
+ ## Usage
38
+
39
+ This library provides two assembly approaches.
40
+
41
+ - A tensor-based assembly, where trial and test functions are represented explicitly as element-level tensors and assembled accordingly (in the style of scikit-fem).
42
+ - A weak-form-based assembly, where the variational form is written symbolically and compiled before assembly.
43
+
44
+ The two approaches are functionally equivalent and share the same element-level execution model,
45
+ but they differ in how you author the weak form. The example below mirrors the paper's diffusion
46
+ case and makes the distinction explicit with `jnp`.
47
+
48
+
49
+ ## Assembly Flow
50
+ All expressions are first compiled into an element-level evaluation plan,
51
+ which operates on quadrature-point–major tensors.
52
+ This plan is then executed independently for each element during assembly.
53
+
54
+ As a result, both assembly approaches:
55
+ - use the same quadrature-major (q, a, i) data layout,
56
+ - perform element-local tensor contractions,
57
+ - and are fully compatible with JAX transformations such as `jit`, `vmap`, and automatic differentiation.
58
+
59
+ ### kernel-based assembly (explicit JIT units)
60
+ If you want to control JIT boundaries explicitly, build a JIT-compiled element kernel
61
+ and pass it to `space.assemble`. The kernel must return the integrated element
62
+ contribution (not the quadrature integrand). For untagged raw kernels, pass `kind=`.
63
+
64
+ ```Python
65
+ import fluxfem as ff
66
+ import jax
67
+ import jax.numpy as jnp
68
+
69
+ space = ff.make_hex_space(mesh, dim=1, intorder=2)
70
+
71
+ # bilinear: kernel(ctx) -> (n_ldofs, n_ldofs)
72
+ ker_K = ff.make_element_bilinear_kernel(ff.diffusion_form, 1.0, jit=True)
73
+ K = space.assemble(ff.diffusion_form, 1.0, kernel=ker_K)
74
+
75
+ # linear: kernel(ctx) -> (n_ldofs,)
76
+ def linear_kernel(ctx):
77
+ integrand = ff.scalar_body_force_form(ctx, 2.0)
78
+ wJ = ctx.w * ctx.test.detJ
79
+ return (integrand * wJ[:, None]).sum(axis=0)
80
+
81
+ ker_F = jax.jit(linear_kernel)
82
+ F = space.assemble(ff.scalar_body_force_form, 2.0, kernel=ker_F)
83
+ ```
84
+
85
+ ### tensor-based vs weak-form-based (diffusion example)
86
+
87
+ #### tensor-based assembly
88
+ The tensor-based assembly provides an explicit, low-level formulation with element kernels written using `jax.numpy` (`jnp`).
89
+ ```Python
90
+ import fluxfem as ff
91
+ import jax.numpy as jnp
92
+
93
+ @ff.kernel(kind="bilinear", domain="volume")
94
+ def diffusion_form(ctx: ff.FormContext, kappa):
95
+ # ctx.test.gradN / ctx.trial.gradN: (n_qp, n_nodes, dim)
96
+ # output tensor: (n_qp, n_nodes, n_nodes)
97
+ return kappa * jnp.einsum("qia,qja->qij", ctx.test.gradN, ctx.trial.gradN)
98
+
99
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
100
+ params = ff.Params(kappa=1.0)
101
+ K_ts = space.assemble(diffusion_form, params=params.kappa)
102
+ ```
103
+
104
+ #### weak-form-based assembly
105
+ In the weak-form-based assembly, the variational formulation itself is the primary object.
106
+ The expression below defines a symbolic computation graph, which is later compiled and executed at the element level.
107
+
108
+ ```Python
109
+ import fluxfem as ff
110
+ import fluxfem.helpers_wf as h_wf
111
+
112
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
113
+ params = ff.Params(kappa=1.0)
114
+
115
+ # u, v are symbolic trial/test fields (weak-form DSL objects).
116
+ # u.grad / v.grad are symbolic nodes (expression tree), not numeric arrays.
117
+ # dOmega() is the integral measure; the whole expression is compiled before assembly.
118
+ form_wf = ff.BilinearForm.volume(
119
+ lambda u, v, p: p.kappa * (v.grad @ u.grad) * h_wf.dOmega()
120
+ ).get_compiled()
121
+
122
+ K_wf = space.assemble(form_wf, params=params)
123
+ ```
124
+
125
+ ### Linear Elasticity assembly (weak-form based assembly)
126
+
127
+ ```Python
128
+ import fluxfem as ff
129
+ import fluxfem.helpers_wf as h_wf
130
+
131
+ space = ff.make_hex_space(mesh, dim=3, intorder=2)
132
+ D = ff.isotropic_3d_D(1.0, 0.3)
133
+
134
+ form_wf = ff.BilinearForm.volume(
135
+ lambda u, v, D: h_wf.ddot(v.sym_grad, D @ u.sym_grad) * h_wf.dOmega()
136
+ ).get_compiled()
137
+
138
+ K = space.assemble(form_wf, params=D)
139
+ ```
140
+
141
+ ### Neo-Hookean residual assembly (weak-form DSL)
142
+ Below is a Neo-Hookean hyperelasticity example written in weak form.
143
+ The residual is expressed symbolically and compiled into element-level kernels executed per element.
144
+ No manual derivation of tangent operators is required; consistent tangents (Jacobians) for Newton-type solvers are obtained automatically via JAX AD.
145
+
146
+ ```Python
147
+ def neo_hookean_residual_wf(v, u, params):
148
+ mu = params["mu"]
149
+ lam = params["lam"]
150
+ F = h_wf.I(3) + h_wf.grad(u) # deformation gradient
151
+ C = h_wf.matmul(h_wf.transpose(F), F)
152
+ C_inv = h_wf.inv(C)
153
+ J = h_wf.det(F)
154
+
155
+ S = mu * (h_wf.I(3) - C_inv) + lam * h_wf.log(J) * C_inv
156
+ dE = 0.5 * (h_wf.matmul(h_wf.grad(v), F) + h_wf.transpose(h_wf.matmul(h_wf.grad(v), F)))
157
+ return h_wf.ddot(S, dE) * h_wf.dOmega()
158
+
159
+ res_form = ff.ResidualForm.volume(neo_hookean_residual_wf).get_compiled()
160
+ ```
161
+
162
+
163
+ ### autodiff + jit compile
164
+
165
+ You can differentiate through the solve and JIT compile the hot path.
166
+ The inverse diffusion tutorial shows this pattern:
167
+
168
+ ```Python
169
+ def loss_theta(theta):
170
+ kappa = jnp.exp(theta)
171
+ u = solve_u_jit(kappa, traction_true)
172
+ diff = u[obs_idx_j] - u_obs[obs_idx_j]
173
+ return 0.5 * jnp.mean(diff * diff)
174
+
175
+ solve_u_jit = jax.jit(solve_u)
176
+ loss_theta_jit = jax.jit(loss_theta)
177
+ grad_fn = jax.jit(jax.grad(loss_theta))
178
+ ```
179
+
180
+ ### FESpace vs FESpacePytree
181
+
182
+ Use `FESpace` for standard workflows with a fixed mesh. When you need to carry
183
+ the space through JAX transformations (e.g., shape optimization where mesh
184
+ coordinates are part of the computation), use `FESpacePytree` via
185
+ `make_*_space_pytree(...)`. This keeps the mesh/basis in the pytree so
186
+ `jax.jit`/`jax.grad` can see geometry changes.
187
+
188
+ ### Mixed systems
189
+
190
+ Mixed problems can be assembled from residual blocks and solved as a coupled system.
191
+
192
+ ```Python
193
+ import fluxfem as ff
194
+ import jax.numpy as jnp
195
+
196
+ mixed = ff.MixedFESpace({"u": space_u, "p": space_p})
197
+ residuals = ff.make_mixed_residuals(
198
+ u=res_u, # (v, u, params) -> Expr
199
+ p=res_p, # (q, u, params) -> Expr
200
+ )
201
+ problem = ff.MixedProblem(mixed, residuals, params=ff.Params(alpha=1.0))
202
+
203
+ u0 = jnp.zeros(mixed.n_dofs)
204
+ R = problem.assemble_residual(u0)
205
+ J = problem.assemble_jacobian(u0, return_flux_matrix=True)
206
+ ```
207
+
208
+ ### Block assembly
209
+
210
+ For constraints like contact problems (e.g., adding Lagrange multipliers), build
211
+ a block matrix explicitly:
212
+
213
+ ```Python
214
+ from fluxfem import solver as ff_solver
215
+
216
+ # Example blocks from contact coupling
217
+ K_uu = ...
218
+ K_cc = ...
219
+ K_uc = ...
220
+
221
+ blocks = ff_solver.make_block_matrix(
222
+ diag=ff_solver.block_diag(u=K_uu, c=K_cc),
223
+ rel={("u", "c"): K_uc},
224
+ sizes={"u": K_uu.shape[0], "c": K_cc.shape[0]},
225
+ symmetric=True,
226
+ transpose_rule="T",
227
+ )
228
+
229
+ # Lazy container; assemble when you need the global matrix.
230
+ K = blocks.assemble()
231
+ ```
232
+
233
+ FluxFEM also provides contact utilities like `ContactSurfaceSpace` to build constraint contributions.
234
+
235
+
236
+ ## Documentation
237
+
238
+ Full documentation, tutorials, and API reference are hosted at [this site](https://fluxfem.readthedocs.io/en/latest/).
239
+
240
+ ## Tutorials
241
+
242
+ - `tutorials/linearelastic_tensile_bar.py` (linear elasticity, weak-form assembly)
243
+ - `tutorials/neo_hookean_cantilever.py` (nonlinear hyperelasticity)
244
+ - `tutorials/thermoelastic_bar_1d.py` / `tutorials/thermoelastic_bar_1d_mixed.py` (thermoelastic coupling)
245
+ - `tutorials/petsc_shell_poisson_demo.py` (PETSc shell solver integration; see also `tutorials/petsc_shell_poisson_pmat_demo.py`)
246
+
247
+ ## Setup
248
+
249
+ You can install **FluxFEM** either via **pip** or **Poetry**.
250
+
251
+ #### Supported Python Versions
252
+
253
+ FluxFEM supports **Python 3.12–3.13**.
254
+
255
+
256
+ **Choose one of the following methods:**
257
+
258
+ ### Using pip
259
+ ```bash
260
+ pip install fluxfem
261
+ pip install "fluxfem[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
262
+ ```
263
+
264
+ ### Using poetry
265
+ ```bash
266
+ poetry add fluxfem
267
+ poetry add fluxfem[cuda12]
268
+ ```
269
+
270
+ ## PETSc Integration
271
+
272
+ Optional PETSc-based solvers are available via `petsc4py`. Enable with the extra:
273
+
274
+ ```bash
275
+ poetry add fluxfem --extras "petsc"
276
+ ```
277
+
278
+ Note: newer `petsc4py` expects PETSc builds that include the `PetscRegressor`
279
+ API. If your PETSc build does not have it, `petsc4py` will fail to compile. In
280
+ that case, rebuild PETSc with regressor support or pin `petsc4py` to a version
281
+ compatible with your PETSc build.
282
+
283
+ GPU note: this repo currently tests CUDA via the `cuda12` extra only. Other CUDA
284
+ versions are not covered by CI and may require manual JAX installation.
285
+
286
+ ## Acknowledgements
287
+ I acknowledge the open-source software, libraries, and communities that made this work possible.
@@ -1,12 +1,12 @@
1
1
  [project]
2
2
  name = "fluxfem"
3
- version = "0.1.4"
3
+ version = "0.2.1"
4
4
  description = ""
5
5
  authors = [
6
6
  {name = "Kohei Watanabe",email = "koheitech001@gmail.com"}
7
7
  ]
8
8
  readme = "README.md"
9
- requires-python = ">=3.11,<3.14"
9
+ requires-python = ">=3.12,<3.14"
10
10
  classifiers = [
11
11
  "Programming Language :: Python :: 3",
12
12
  "Programming Language :: Python :: 3.11",
@@ -17,14 +17,20 @@ dependencies = [
17
17
  "pyvista (>=0.46.4,<0.47.0)",
18
18
  "meshio (>=5.3.5,<6.0.0)",
19
19
  "matplotlib (>=3.10.7,<4.0.0)",
20
- # Note: jax versioning currently around 0.8.x. Allow up to next minor.
21
- "jax (>=0.8.2,<0.9.0)",
22
- "jaxlib (>=0.8.2,<0.9.0)"
20
+ # CPU JAX by default; CUDA builds are available via the cuda12 extra.
21
+ "jax (==0.8.2)",
22
+ "jaxlib (==0.8.2)",
23
23
  ]
24
24
 
25
+ [project.optional-dependencies]
26
+ # Keep petsc4py aligned with the PETSc build referenced by PETSC_DIR.
27
+ petsc = ["petsc4py (==3.23.6)"]
28
+ cpu = ["jax (==0.8.2)", "jaxlib (==0.8.2)"]
29
+ cuda12 = ["jax (==0.8.2)", "jaxlib (==0.8.2)", "jax-cuda12-plugin (==0.8.2)", "jax-cuda12-pjrt (==0.8.2)"]
30
+
25
31
  [tool.poetry]
26
32
  name = "fluxfem"
27
- version = "0.1.4"
33
+ version = "0.2.1"
28
34
  description = "FluxFEM: A weak-form-centric differentiable finite element framework in JAX"
29
35
  authors = ["Kohei Watanabe <koheitech001@gmail.com>"]
30
36
  readme = "README.md"
@@ -35,12 +41,17 @@ packages = [
35
41
 
36
42
 
37
43
  [tool.poetry.dependencies]
38
- python = ">=3.11,<3.14"
44
+ python = ">=3.12,<3.14"
39
45
  pyvista = ">=0.46.4,<0.47.0"
40
46
  meshio = ">=5.3.5,<6.0.0"
41
47
  matplotlib = ">=3.10.7,<4.0.0"
42
48
  jax = ">=0.8.2,<0.9.0"
43
49
  jaxlib = ">=0.8.2,<0.9.0"
50
+ jax-cuda12-plugin = { version = "==0.8.2", optional = true }
51
+ jax-cuda12-pjrt = { version = "==0.8.2", optional = true }
52
+ petsc4py = { version = "==3.23.6", optional = true }
53
+ pyproject-toml = "^0.1.0"
54
+ pyproject = "^1!0.1.2"
44
55
 
45
56
 
46
57
  [tool.poetry.group.dev]
@@ -59,6 +70,19 @@ sphinx-rtd-theme = { version = "^3.0.2", python = ">=3.11" }
59
70
  myst-parser = { version = "^4.0.1", python = ">=3.11" }
60
71
  sphinx-autodoc-typehints = { version = "^3.2.0", python = ">=3.11" }
61
72
  sphinx-sitemap = { version = "^2.6.0", python = ">=3.11" }
73
+ pydata-sphinx-theme = { version = "^0.16.1", python = ">=3.11" }
74
+ shapely = "^2.1.2"
75
+ mypy = "^1.19.1"
76
+
77
+ [tool.poetry.extras]
78
+ petsc = ["petsc4py"]
79
+ cpu = ["jax", "jaxlib"]
80
+ cuda12 = ["jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt"]
81
+
82
+ [[tool.poetry.source]]
83
+ name = "jax"
84
+ url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
85
+ priority = "supplemental"
62
86
 
63
87
  [build-system]
64
88
  requires = ["poetry-core>=2.0.0,<3.0.0"]