fluxfem 0.2.1__tar.gz → 0.2.2__tar.gz

This diff shows the changes between publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
Files changed (59)
  1. {fluxfem-0.2.1 → fluxfem-0.2.2}/PKG-INFO +19 -3
  2. {fluxfem-0.2.1 → fluxfem-0.2.2}/README.md +18 -2
  3. {fluxfem-0.2.1 → fluxfem-0.2.2}/pyproject.toml +2 -2
  4. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/bc.py +116 -0
  5. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/block_matrix.py +16 -2
  6. {fluxfem-0.2.1 → fluxfem-0.2.2}/LICENSE +0 -0
  7. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/__init__.py +0 -0
  8. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/__init__.py +0 -0
  9. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/assembly.py +0 -0
  10. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/basis.py +0 -0
  11. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/context_types.py +0 -0
  12. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/data.py +0 -0
  13. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/dtypes.py +0 -0
  14. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/forms.py +0 -0
  15. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/interp.py +0 -0
  16. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/mixed_assembly.py +0 -0
  17. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/mixed_space.py +0 -0
  18. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/mixed_weakform.py +0 -0
  19. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/solver.py +0 -0
  20. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/space.py +0 -0
  21. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/core/weakform.py +0 -0
  22. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/helpers_ts.py +0 -0
  23. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/helpers_wf.py +0 -0
  24. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/__init__.py +0 -0
  25. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/base.py +0 -0
  26. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/contact.py +0 -0
  27. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/dtypes.py +0 -0
  28. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/hex.py +0 -0
  29. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/io.py +0 -0
  30. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/mortar.py +0 -0
  31. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/predicate.py +0 -0
  32. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/supermesh.py +0 -0
  33. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/surface.py +0 -0
  34. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/mesh/tet.py +0 -0
  35. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/__init__.py +0 -0
  36. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/diffusion.py +0 -0
  37. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/__init__.py +0 -0
  38. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/hyperelastic.py +0 -0
  39. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/linear.py +0 -0
  40. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/materials.py +0 -0
  41. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/elasticity/stress.py +0 -0
  42. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/operators.py +0 -0
  43. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/physics/postprocess.py +0 -0
  44. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/__init__.py +0 -0
  45. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/block_system.py +0 -0
  46. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/cg.py +0 -0
  47. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/dirichlet.py +0 -0
  48. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/history.py +0 -0
  49. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/newton.py +0 -0
  50. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/petsc.py +0 -0
  51. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/preconditioner.py +0 -0
  52. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/result.py +0 -0
  53. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/solve_runner.py +0 -0
  54. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/solver.py +0 -0
  55. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/solver/sparse.py +0 -0
  56. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/tools/__init__.py +0 -0
  57. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/tools/jit.py +0 -0
  58. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/tools/timer.py +0 -0
  59. {fluxfem-0.2.1 → fluxfem-0.2.2}/src/fluxfem/tools/visualizer.py +0 -0
--- fluxfem-0.2.1/PKG-INFO
+++ fluxfem-0.2.2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: fluxfem
-Version: 0.2.1
+Version: 0.2.2
 Summary: FluxFEM: A weak-form-centric differentiable finite element framework in JAX
 License: Apache-2.0
 Author: Kohei Watanabe
@@ -245,9 +245,8 @@ K_cc = ...
 K_uc = ...

 blocks = ff_solver.make_block_matrix(
-    diag=ff_solver.block_diag(u=K_uu, c=K_cc),
+    diag=ff_solver.block_diag(order=("u", "c"), u=K_uu, c=K_cc),
     rel={("u", "c"): K_uc},
-    sizes={"u": K_uu.shape[0], "c": K_cc.shape[0]},
     symmetric=True,
     transpose_rule="T",
 )
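
The functional change behind this hunk: `sizes` is gone because `make_block_matrix` now infers per-field sizes from the diagonal blocks (see `_infer_sizes_from_diag` in the `block_matrix.py` hunk below), and `block_diag` gains an explicit `order`. A minimal sketch of the updated call, assuming `ff_solver` aliases `fluxfem.solver` as in the README and using placeholder arrays:

```python
# Sketch only: placeholder square blocks stand in for assembled matrices.
import jax.numpy as jnp
import fluxfem.solver as ff_solver  # assumed alias, following the README

K_uu = jnp.eye(6)          # placeholder "u" diagonal block
K_cc = jnp.eye(3)          # placeholder "c" diagonal block
K_uc = jnp.zeros((6, 3))   # placeholder "u"-"c" coupling block

blocks = ff_solver.make_block_matrix(
    diag=ff_solver.block_diag(order=("u", "c"), u=K_uu, c=K_cc),
    rel={("u", "c"): K_uc},
    symmetric=True,        # sizes= is no longer needed; inferred from diag
    transpose_rule="T",
)
```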
@@ -297,8 +296,19 @@ poetry add fluxfem[cuda12]

 Optional PETSc-based solvers are available via `petsc4py`. Enable with the extra:

+
+```bash
+pip install "fluxfem[petsc]"
+or
+pip install "fluxfem[petsc,cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
 ```bash
 poetry add fluxfem --extras "petsc"
+or
+poetry add "fluxfem[petsc,cuda12]"
+or
+poetry add fluxfem --extras "petsc" --extras "cuda12"
 ```

 Note: newer `petsc4py` expects PETSc builds that include the `PetscRegressor`
@@ -306,6 +316,12 @@ API. If your PETSc build does not have it, `petsc4py` will fail to compile. In
 that case, rebuild PETSc with regressor support or pin `petsc4py` to a version
 compatible with your PETSc build.

+Important: you must match the `petsc4py` version to the PETSc version you have
+installed. The current FluxFEM extra pins `petsc4py==3.23.6` (see
+`[project.optional-dependencies]`), so make sure your PETSc install is
+compatible with that `petsc4py` release, or override it to match your PETSc
+build.
+
 GPU note: this repo currently tests CUDA via the `cuda12` extra only. Other CUDA
 versions are not covered by CI and may require manual JAX installation.

--- fluxfem-0.2.1/README.md
+++ fluxfem-0.2.2/README.md
@@ -219,9 +219,8 @@ K_cc = ...
 K_uc = ...

 blocks = ff_solver.make_block_matrix(
-    diag=ff_solver.block_diag(u=K_uu, c=K_cc),
+    diag=ff_solver.block_diag(order=("u", "c"), u=K_uu, c=K_cc),
     rel={("u", "c"): K_uc},
-    sizes={"u": K_uu.shape[0], "c": K_cc.shape[0]},
     symmetric=True,
     transpose_rule="T",
 )
@@ -271,8 +270,19 @@ poetry add fluxfem[cuda12]

 Optional PETSc-based solvers are available via `petsc4py`. Enable with the extra:

+
+```bash
+pip install "fluxfem[petsc]"
+or
+pip install "fluxfem[petsc,cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
 ```bash
 poetry add fluxfem --extras "petsc"
+or
+poetry add "fluxfem[petsc,cuda12]"
+or
+poetry add fluxfem --extras "petsc" --extras "cuda12"
 ```

 Note: newer `petsc4py` expects PETSc builds that include the `PetscRegressor`
@@ -280,6 +290,12 @@ API. If your PETSc build does not have it, `petsc4py` will fail to compile. In
 that case, rebuild PETSc with regressor support or pin `petsc4py` to a version
 compatible with your PETSc build.

+Important: you must match the `petsc4py` version to the PETSc version you have
+installed. The current FluxFEM extra pins `petsc4py==3.23.6` (see
+`[project.optional-dependencies]`), so make sure your PETSc install is
+compatible with that `petsc4py` release, or override it to match your PETSc
+build.
+
 GPU note: this repo currently tests CUDA via the `cuda12` extra only. Other CUDA
 versions are not covered by CI and may require manual JAX installation.

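A quick way to verify the match before relying on the extra (a sketch; assumes `petsc4py` is already importable): `PETSc.Sys.getVersion()` reports the PETSc release the binding was built against, which should agree with `petsc4py.__version__` at the minor-version level.

```python
# Sketch: confirm the PETSc release petsc4py was built against matches
# petsc4py's own version; both calls are existing petsc4py API.
import petsc4py
from petsc4py import PETSc

print("petsc4py:", petsc4py.__version__)    # e.g. 3.23.6
print("PETSc:   ", PETSc.Sys.getVersion())  # e.g. (3, 23, 4) -> 3.23.x, OK
```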
--- fluxfem-0.2.1/pyproject.toml
+++ fluxfem-0.2.2/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "fluxfem"
-version = "0.2.1"
+version = "0.2.2"
 description = ""
 authors = [
     {name = "Kohei Watanabe",email = "koheitech001@gmail.com"}
@@ -30,7 +30,7 @@ cuda12 = ["jax (==0.8.2)", "jaxlib (==0.8.2)", "jax-cuda12-plugin (==0.8.2)", "j

 [tool.poetry]
 name = "fluxfem"
-version = "0.2.1"
+version = "0.2.2"
 description = "FluxFEM: A weak-form-centric differentiable finite element framework in JAX"
 authors = ["Kohei Watanabe <koheitech001@gmail.com>"]
 readme = "README.md"
--- fluxfem-0.2.1/src/fluxfem/solver/bc.py
+++ fluxfem-0.2.2/src/fluxfem/solver/bc.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Optional, Sequence
 import numpy as np
 import numpy.typing as npt
+import jax
 import jax.numpy as jnp

 from ..mesh.surface import SurfaceMesh
@@ -212,6 +213,121 @@ def assemble_surface_linear_form(
     """
     Assemble a linear form over surface facets using a weak-form callback.
     """
+    def _is_jax(x) -> bool:
+        return isinstance(x, jax.Array) or isinstance(x, jax.core.Tracer)
+
+    def _assemble_surface_linear_form_jax():
+        facets = jnp.asarray(surface.conn, dtype=jnp.int32)
+        coords = jnp.asarray(surface.coords)
+        if facets.ndim != 2 or facets.shape[1] not in (3, 4):
+            raise NotImplementedError("JAX surface assembly supports only tri/quad facets.")
+
+        n_nodes = surface.n_nodes if n_total_nodes is None else int(n_total_nodes)
+        n_dofs = n_nodes * dim
+        dtype = coords.dtype if F0 is None else jnp.asarray(F0).dtype
+        F = jnp.zeros(n_dofs, dtype=dtype) if F0 is None else jnp.asarray(F0, dtype=dtype)
+        if F.shape[0] != n_dofs:
+            raise ValueError(f"F length {F.shape[0]} does not match expected {n_dofs}")
+
+        coords_f = coords[facets]
+        facet_ids = jnp.arange(facets.shape[0], dtype=jnp.int32)
+        outward_from = jnp.mean(coords, axis=0)
+
+        def facet_fe_tri(facet_id, nodes):
+            p0, p1, p2 = nodes[0], nodes[1], nodes[2]
+            cross = jnp.cross(p1 - p0, p2 - p0)
+            area = 0.5 * jnp.linalg.norm(cross)
+            centroid = (p0 + p1 + p2) / 3.0
+            norm = jnp.linalg.norm(cross)
+            norm = jnp.where(norm < 1e-12, 1e-12, norm)
+            normal = cross / norm
+            v = centroid - outward_from
+            normal = jnp.where(jnp.dot(normal, v) < 0.0, -normal, normal)
+
+            N = jnp.array([[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]], dtype=dtype)
+            w = jnp.array([1.0], dtype=dtype)
+            ctx = SurfaceFormContext(
+                v=SurfaceFormField(N=N, value_dim=dim),
+                x_q=centroid[None, :],
+                w=w,
+                detJ=jnp.array([area], dtype=dtype),
+                facet_id=facet_id,
+                normal=normal,
+            )
+            fe_q = form(ctx, params)
+            if getattr(form, "_includes_measure", False):
+                fe = jnp.einsum("qi->i", fe_q)
+            else:
+                fe = jnp.einsum("qi,q->i", fe_q, w * area)
+            return fe
+
+        def facet_fe_quad(facet_id, nodes):
+            p0, p1, p2, p3 = nodes[0], nodes[1], nodes[2], nodes[3]
+            cross = jnp.cross(p1 - p0, p3 - p0)
+            norm = jnp.linalg.norm(cross)
+            norm = jnp.where(norm < 1e-12, 1e-12, norm)
+            normal = cross / norm
+            centroid = (p0 + p1 + p2 + p3) / 4.0
+            v = centroid - outward_from
+            normal = jnp.where(jnp.dot(normal, v) < 0.0, -normal, normal)
+
+            gp = jnp.array([-1.0 / jnp.sqrt(3.0), 1.0 / jnp.sqrt(3.0)], dtype=dtype)
+            xi, eta = jnp.meshgrid(gp, gp, indexing="xy")
+            xi = xi.reshape(-1)
+            eta = eta.reshape(-1)
+            w = jnp.ones_like(xi)
+
+            N = 0.25 * jnp.stack(
+                [
+                    (1 - xi) * (1 - eta),
+                    (1 + xi) * (1 - eta),
+                    (1 + xi) * (1 + eta),
+                    (1 - xi) * (1 + eta),
+                ],
+                axis=1,
+            )
+            dN_dxi = 0.25 * jnp.stack(
+                [-(1 - eta), (1 - eta), (1 + eta), -(1 + eta)], axis=1
+            )
+            dN_deta = 0.25 * jnp.stack(
+                [-(1 - xi), -(1 + xi), (1 + xi), (1 - xi)], axis=1
+            )
+
+            dx_dxi = jnp.einsum("qa,ai->qi", dN_dxi, nodes)
+            dx_deta = jnp.einsum("qa,ai->qi", dN_deta, nodes)
+            detJ = jnp.linalg.norm(jnp.cross(dx_dxi, dx_deta), axis=1)
+            x_q = jnp.einsum("qa,ai->qi", N, nodes)
+
+            ctx = SurfaceFormContext(
+                v=SurfaceFormField(N=N, value_dim=dim),
+                x_q=x_q,
+                w=w,
+                detJ=detJ,
+                facet_id=facet_id,
+                normal=normal,
+            )
+            fe_q = form(ctx, params)
+            if getattr(form, "_includes_measure", False):
+                fe = jnp.einsum("qi->i", fe_q)
+            else:
+                fe = jnp.einsum("qi,q->i", fe_q, w * detJ)
+            return fe
+
+        if facets.shape[1] == 3:
+            fe = jax.vmap(facet_fe_tri)(facet_ids, coords_f)
+        else:
+            fe = jax.vmap(facet_fe_quad)(facet_ids, coords_f)
+
+        fe = fe.astype(dtype)
+        offsets = jnp.arange(dim, dtype=jnp.int32)
+        dofs = facets[:, :, None] * dim + offsets[None, None, :]
+        dofs = dofs.reshape(facets.shape[0], -1)
+        F = F.at[dofs].add(fe)
+        return F
+
+    if _is_jax(surface.coords) or _is_jax(surface.conn) or _is_jax(F0):
+        return _assemble_surface_linear_form_jax()
+
     facets = np.asarray(surface.conn, dtype=int)
     coords = np.asarray(surface.coords)
     n_nodes = surface.n_nodes if n_total_nodes is None else int(n_total_nodes)
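
The point of this branch is that surface assembly stays traceable, so it composes with `jax.grad`/`jax.jit` when the surface data are JAX arrays. A hedged sketch (the positional argument order of `assemble_surface_linear_form` is an assumption; the form contract, `form(ctx, params)` returning shape `(n_q, n_facet_nodes * dim)`, is inferred from the einsum patterns above):

```python
# Sketch: constant-traction surface load assembled through the JAX path,
# then differentiated with respect to the traction vector.
import jax
import jax.numpy as jnp
from fluxfem.solver.bc import assemble_surface_linear_form

def traction_form(ctx, params):
    # Shape functions (n_q, a) times traction (dim,) -> (n_q, a, dim),
    # flattened to the node-major (n_q, a * dim) layout the scatter expects.
    fe = jnp.einsum("qa,d->qad", ctx.v.N, params["traction"])
    return fe.reshape(fe.shape[0], -1)

def total_load(traction, surface):
    # The JAX branch is selected when surface.coords / surface.conn are
    # jax arrays (see the _is_jax dispatch above), so keep them as jnp.
    F = assemble_surface_linear_form(
        surface, traction_form, params={"traction": traction}, dim=3
    )
    return jnp.sum(F)

# With a JAX-backed `surface` in scope, this differentiates end to end:
# dF_dt = jax.grad(total_load)(jnp.array([0.0, 0.0, -1.0]), surface)
```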
--- fluxfem-0.2.1/src/fluxfem/solver/block_matrix.py
+++ fluxfem-0.2.2/src/fluxfem/solver/block_matrix.py
@@ -18,8 +18,22 @@ FieldKey: TypeAlias = str | int
 BlockMap: TypeAlias = dict[FieldKey, dict[FieldKey, MatrixLike]]


-def diag(**blocks: MatrixLike) -> dict[str, MatrixLike]:
-    return dict(blocks)
+def diag(*, order: Sequence[FieldKey] | None = None, **blocks: MatrixLike) -> dict[FieldKey, MatrixLike]:
+    """
+    Build a dict of diagonal blocks with an optional explicit field order.
+    """
+    if order is None:
+        return dict(blocks)
+    ordered_blocks: dict[FieldKey, MatrixLike] = {}
+    for name in order:
+        if name not in blocks:
+            raise KeyError(f"Missing block '{name}' in block_diag order")
+        ordered_blocks[name] = blocks[name]
+    extra = set(blocks) - set(order)
+    if extra:
+        extra_list = ", ".join(str(name) for name in sorted(extra, key=str))
+        raise KeyError(f"Unknown block(s) not in order: {extra_list}")
+    return ordered_blocks


 def _infer_sizes_from_diag(diag_blocks: Mapping[FieldKey, MatrixLike]) -> dict[FieldKey, int]:
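
To make the new ordering contract concrete, a small sketch against the patched `diag` (the import path follows the file location `src/fluxfem/solver/block_matrix.py`; users normally reach it as `ff_solver.block_diag` per the README):

```python
# Sketch: small placeholder arrays stand in for assembled blocks.
import numpy as np
from fluxfem.solver.block_matrix import diag

A, B = np.eye(2), np.eye(3)

d = diag(order=("u", "c"), u=A, c=B)
assert list(d) == ["u", "c"]      # iteration order follows `order`

d = diag(c=B, u=A)                # no order: keyword insertion order wins
assert list(d) == ["c", "u"]

try:
    diag(order=("u",), u=A, c=B)  # "c" is passed but absent from `order`
except KeyError as err:
    print(err)                    # -> 'Unknown block(s) not in order: c'
```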