fluxfem 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (45)
  1. fluxfem/__init__.py +68 -0
  2. fluxfem/core/__init__.py +115 -10
  3. fluxfem/core/assembly.py +676 -91
  4. fluxfem/core/basis.py +73 -52
  5. fluxfem/core/dtypes.py +9 -1
  6. fluxfem/core/forms.py +10 -0
  7. fluxfem/core/mixed_assembly.py +263 -0
  8. fluxfem/core/mixed_space.py +348 -0
  9. fluxfem/core/mixed_weakform.py +97 -0
  10. fluxfem/core/solver.py +2 -0
  11. fluxfem/core/space.py +262 -17
  12. fluxfem/core/weakform.py +768 -7
  13. fluxfem/helpers_wf.py +49 -0
  14. fluxfem/mesh/__init__.py +54 -2
  15. fluxfem/mesh/base.py +316 -7
  16. fluxfem/mesh/contact.py +825 -0
  17. fluxfem/mesh/dtypes.py +12 -0
  18. fluxfem/mesh/hex.py +17 -16
  19. fluxfem/mesh/io.py +6 -4
  20. fluxfem/mesh/mortar.py +3907 -0
  21. fluxfem/mesh/supermesh.py +316 -0
  22. fluxfem/mesh/surface.py +22 -4
  23. fluxfem/mesh/tet.py +10 -4
  24. fluxfem/physics/diffusion.py +3 -0
  25. fluxfem/physics/elasticity/hyperelastic.py +3 -0
  26. fluxfem/physics/elasticity/linear.py +9 -2
  27. fluxfem/solver/__init__.py +42 -2
  28. fluxfem/solver/bc.py +38 -2
  29. fluxfem/solver/block_matrix.py +132 -0
  30. fluxfem/solver/block_system.py +454 -0
  31. fluxfem/solver/cg.py +115 -33
  32. fluxfem/solver/dirichlet.py +334 -4
  33. fluxfem/solver/newton.py +237 -60
  34. fluxfem/solver/petsc.py +439 -0
  35. fluxfem/solver/preconditioner.py +106 -0
  36. fluxfem/solver/result.py +18 -0
  37. fluxfem/solver/solve_runner.py +168 -1
  38. fluxfem/solver/solver.py +12 -1
  39. fluxfem/solver/sparse.py +124 -9
  40. fluxfem-0.2.0.dist-info/METADATA +303 -0
  41. fluxfem-0.2.0.dist-info/RECORD +59 -0
  42. fluxfem-0.1.4.dist-info/METADATA +0 -127
  43. fluxfem-0.1.4.dist-info/RECORD +0 -48
  44. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/LICENSE +0 -0
  45. {fluxfem-0.1.4.dist-info → fluxfem-0.2.0.dist-info}/WHEEL +0 -0
fluxfem/solver/newton.py CHANGED
@@ -14,6 +14,7 @@ from ..core.assembly import (
14
14
  )
15
15
  from ..core.solver import spdirect_solve_cpu, spdirect_solve_gpu
16
16
  from .cg import cg_solve, cg_solve_jax
17
+ from .preconditioner import make_block_jacobi_preconditioner
17
18
  from .result import SolverResult
18
19
  from .sparse import SparsityPattern, FluxSparseMatrix
19
20
  from .dirichlet import _normalize_dirichlet
@@ -28,10 +29,12 @@ def newton_solve(
28
29
  tol: float = 1e-8,
29
30
  atol: float = 0.0,
30
31
  maxiter: int = 20,
31
- linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg" (jax), "cg_jax", or "cg_custom"
32
+ linear_solver: str = "spsolve", # "spsolve", "spdirect_solve_gpu", "cg", "cg_jax", "cg_custom", or "cg_matfree"
32
33
  linear_maxiter: int | None = None,
33
34
  linear_tol: float | None = None,
34
35
  linear_preconditioner=None,
36
+ matfree_mode: str = "linearize",
37
+ matfree_cache: dict | None = None,
35
38
  dirichlet=None,
36
39
  callback=None,
37
40
  line_search: bool = False,
@@ -39,6 +42,7 @@ def newton_solve(
39
42
  ls_c: float = 1e-4,
40
43
  external_vector=None,
41
44
  jacobian_pattern=None,
45
+ extra_terms=None,
42
46
  ):
43
47
  """
44
48
  Gridap-style Newton–Raphson solver on free DOFs only.
@@ -48,9 +52,12 @@ def newton_solve(
48
52
  - Convergence: ||R_free||_inf < max(atol, tol * ||R_free0||_inf).
49
53
  - external_vector: optional global RHS (internal - external).
50
54
  - CG path accepts an operator with matvec that acts on free DOFs via a wrapper.
51
- - linear_preconditioner: forwarded to cg_solve/cg_solve_jax (None | "jacobi" | "block_jacobi" | callable).
55
+ - cg_matfree uses JVP/linearize to form a matrix-free matvec (no global Jacobian).
56
+ - linear_preconditioner: forwarded to cg_solve/cg_solve_jax (None | "jacobi" | "block_jacobi" | "diag0" | callable).
57
+ - matfree_cache: optional dict for reusing matrix-free preconditioners across calls.
52
58
  - linear_tol: CG tolerance (defaults to 0.1 * tol if not provided).
53
59
  - jacobian_pattern: optional SparsityPattern to reuse sparsity across load steps.
60
+ - extra_terms: optional list of callbacks returning (K, f[, metrics]) for extra terms.
54
61
  """
55
62
 
56
63
  if dirichlet is not None:
@@ -66,6 +73,13 @@ def newton_solve(
66
73
  dir_dofs = dir_vals = None
67
74
  free_dofs = np.arange(space.n_dofs, dtype=int)
68
75
 
76
+ use_matfree = linear_solver in ("cg_matfree", "cg_jvp")
77
+ if use_matfree and matfree_mode not in ("linearize", "jvp"):
78
+ raise ValueError("matfree_mode must be 'linearize' or 'jvp'")
79
+
80
+ if extra_terms is not None and linear_solver in ("cg", "cg_jax", "cg_custom", "cg_matfree", "cg_jvp"):
81
+ raise ValueError("extra_terms may yield nonsymmetric K; avoid CG-based solvers")
82
+
69
83
  free_dofs_j = jnp.asarray(free_dofs, dtype=jnp.int32)
70
84
  # For block-Jacobi (3x3 per node) we keep node ids of free dofs.
71
85
  node_ids = free_dofs // 3
@@ -109,33 +123,7 @@ def newton_solve(
109
123
  Build 3x3 block-Jacobi inverse per free node.
110
124
  Assumes DOF ordering per node is [ux, uy, uz].
111
125
  """
112
- if len(free_dofs) % 3 != 0:
113
- raise ValueError("block_jacobi assumes 3 DOFs per node.")
114
- rows = np.asarray(J_free.rows)
115
- cols = np.asarray(J_free.cols)
116
- data = np.asarray(J_free.data)
117
- block_rows = node_ids_inv[rows]
118
- block_cols = node_ids_inv[cols]
119
- local_r = rows % 3
120
- local_c = cols % 3
121
- mask_blk = block_rows == block_cols
122
- blk_rows = block_rows[mask_blk]
123
- blk_lr = local_r[mask_blk]
124
- blk_lc = local_c[mask_blk]
125
- blk_data = data[mask_blk]
126
- inv_blocks = np.zeros((n_block, 3, 3), dtype=blk_data.dtype)
127
- inv_blocks[blk_rows, blk_lr, blk_lc] += blk_data
128
- inv_blocks = jnp.asarray(inv_blocks)
129
- # Add tiny damping to avoid singular blocks
130
- inv_blocks = inv_blocks + 1e-12 * jnp.eye(3)[None, :, :]
131
- inv_blocks = jnp.linalg.inv(inv_blocks)
132
-
133
- def precon(r):
134
- r_blocks = r.reshape((n_block, 3))
135
- z_blocks = jnp.einsum("bij,bj->bi", inv_blocks, r_blocks)
136
- return z_blocks.reshape((-1,))
137
-
138
- return precon
126
+ return make_block_jacobi_preconditioner(J_free, dof_per_node=3)
139
127
 
140
128
  def expand_full(u_free: jnp.ndarray) -> jnp.ndarray:
141
129
  if dir_dofs is None:
@@ -145,6 +133,39 @@ def newton_solve(
145
133
  u_full = u_full.at[dir_dofs_j].set(dir_vals_j)
146
134
  return u_full
147
135
 
136
+ extra_metrics = None
137
+
138
+ def _call_extra(u_full_vec):
139
+ if extra_terms is None:
140
+ return None
141
+ K_sum = None
142
+ f_sum = None
143
+ metrics_sum = {}
144
+ for term in extra_terms:
145
+ out = term(np.asarray(u_full_vec))
146
+ if out is None:
147
+ continue
148
+ if len(out) == 2:
149
+ Kc, fc = out
150
+ metrics = None
151
+ else:
152
+ Kc, fc, metrics = out
153
+ Kc = np.asarray(Kc, dtype=float)
154
+ fc = np.asarray(fc, dtype=float)
155
+ if K_sum is None:
156
+ K_sum = Kc
157
+ else:
158
+ K_sum = K_sum + Kc
159
+ if f_sum is None:
160
+ f_sum = fc
161
+ else:
162
+ f_sum = f_sum + fc
163
+ if isinstance(metrics, dict):
164
+ metrics_sum.update(metrics)
165
+ if K_sum is None or f_sum is None:
166
+ return None
167
+ return K_sum, f_sum, (metrics_sum or None)
168
+
148
169
  def eval_residual(u_free_vec):
149
170
  """Residual on free DOFs only."""
150
171
  u_full = expand_full(u_free_vec)
@@ -156,17 +177,32 @@ def newton_solve(
156
177
  res_two = float(jnp.linalg.norm(R_free, ord=2))
157
178
  return R_free, res_inf, res_two, u_full
158
179
 
180
+ def residual_free(u_free_vec):
181
+ u_full = expand_full(u_free_vec)
182
+ R_full = assemble_R(u_full)
183
+ if external_vector is not None:
184
+ R_full = R_full - external_vector
185
+ return R_full[free_dofs_j]
186
+
159
187
  # Pre-jitted element kernels to avoid recompiling inside Newton
160
188
  res_kernel = make_element_residual_kernel(res_form, params)
161
189
  jac_kernel = make_element_jacobian_kernel(res_form, params)
162
190
 
163
191
  def assemble_R(u_full_vec):
164
- return assemble_residual_scatter(space, res_form, u_full_vec, params, kernel=res_kernel)
192
+ nonlocal extra_metrics
193
+ R = assemble_residual_scatter(space, res_form, u_full_vec, params, kernel=res_kernel)
194
+ extra_out = _call_extra(u_full_vec)
195
+ if extra_out is not None:
196
+ _Kc, fc, metrics = extra_out
197
+ extra_metrics = metrics
198
+ R = R + jnp.asarray(fc, dtype=R.dtype)
199
+ return R
165
200
 
166
201
  eff_linear_tol = linear_tol if linear_tol is not None else max(0.1 * tol, 1e-12)
167
202
 
168
203
  def assemble_J(u_full_vec):
169
- return assemble_jacobian_scatter(
204
+ nonlocal extra_metrics
205
+ J = assemble_jacobian_scatter(
170
206
  space,
171
207
  res_form,
172
208
  u_full_vec,
@@ -176,6 +212,40 @@ def newton_solve(
176
212
  return_flux_matrix=True,
177
213
  pattern=J_pattern,
178
214
  )
215
+ extra_out = _call_extra(u_full_vec)
216
+ if extra_out is not None:
217
+ Kc, _fc, metrics = extra_out
218
+ extra_metrics = metrics
219
+ rows = np.asarray(J.pattern.rows, dtype=int)
220
+ cols = np.asarray(J.pattern.cols, dtype=int)
221
+ data = jnp.asarray(J.data) + jnp.asarray(Kc[rows, cols], dtype=J.data.dtype)
222
+ J = J.with_data(data)
223
+ return J
224
+
225
+ matfree_precon = None
226
+ if use_matfree and linear_preconditioner == "diag0":
227
+ cached = matfree_cache.get("inv_diag0") if matfree_cache is not None else None
228
+ if cached is not None:
229
+ inv_diag0 = cached
230
+ matfree_precon = lambda r: inv_diag0 * r
231
+ print("[PRECOND] reuse diag0", flush=True)
232
+ else:
233
+ print("[PRECOND] build diag0", flush=True)
234
+ t_pre0 = time.perf_counter()
235
+ J0 = assemble_J(expand_full(u))
236
+ J0_free = restrict_free_matrix(J0)
237
+ diag0 = jnp.asarray(J0_free.diag(), dtype=u.dtype)
238
+ diag0 = jax.block_until_ready(diag0)
239
+ inv_diag0 = jnp.where(diag0 != 0.0, 1.0 / diag0, 0.0)
240
+
241
+ def precon(r):
242
+ return inv_diag0 * r
243
+
244
+ matfree_precon = precon
245
+ if matfree_cache is not None:
246
+ matfree_cache["inv_diag0"] = inv_diag0
247
+ pre_dt0 = time.perf_counter() - t_pre0
248
+ print(f"[PRECOND] diag0 ready dt={pre_dt0:.3f}s", flush=True)
179
249
 
180
250
  # Initial residual/Jacobian
181
251
  R_full_init = assemble_R(expand_full(u))
@@ -210,14 +280,25 @@ def newton_solve(
210
280
  )
211
281
 
212
282
  if callback is not None:
213
- callback({"iter": 0, "res_inf": res0_inf, "res_two": res0_two, "rel_residual": 1.0, "alpha": 1.0, "step_norm": np.nan})
214
-
215
- J = assemble_J(u_full)
216
- finite_j = jnp.all(jnp.isfinite(J.data))
217
- if not bool(jax.block_until_ready(finite_j)):
218
- n_bad = int(jnp.size(J.data) - jnp.count_nonzero(jnp.isfinite(J.data)))
219
- raise RuntimeError(f"[newton] init Jacobian has non-finite entries: {n_bad}")
220
- J_free = restrict_free_matrix(J)
283
+ payload = {"iter": 0, "res_inf": res0_inf, "res_two": res0_two, "rel_residual": 1.0, "alpha": 1.0, "step_norm": np.nan}
284
+ if extra_metrics is not None:
285
+ payload["extra_metrics"] = extra_metrics
286
+ callback(payload)
287
+
288
+ if not use_matfree:
289
+ J = assemble_J(u_full)
290
+ finite_j = jnp.all(jnp.isfinite(J.data))
291
+ if not bool(jax.block_until_ready(finite_j)):
292
+ n_bad = int(jnp.size(J.data) - jnp.count_nonzero(jnp.isfinite(J.data)))
293
+ raise RuntimeError(f"[newton] init Jacobian has non-finite entries: {n_bad}")
294
+ J_free = restrict_free_matrix(J)
295
+ else:
296
+ J = None
297
+ J_free = None
298
+ lin_info = {}
299
+ step_norm = float("nan")
300
+ linear_converged = True
301
+ lr = None
221
302
  for k in range(maxiter):
222
303
  # --- Newton residual (iteration start) ---
223
304
  t_iter0 = time.perf_counter()
@@ -232,13 +313,34 @@ def newton_solve(
232
313
  raise RuntimeError("[newton] residual became non-finite; aborting.")
233
314
 
234
315
  crit = max(atol, tol * res0_inf)
316
+ contact_log = ""
317
+ if extra_metrics is not None and isinstance(extra_metrics, dict):
318
+ min_g = extra_metrics.get("min_g")
319
+ pen = extra_metrics.get("penetration")
320
+ if min_g is not None and pen is not None:
321
+ contact_log = f" min_g={float(min_g):.3e} pen={float(pen):.3e}"
235
322
  print(
236
- f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}",
323
+ f"[newton] k={k:02d} START |R|inf={res_prev_inf_f:.3e} |R|2={res_prev_two_f:.3e} crit={crit:.3e}{contact_log}",
237
324
  flush=True,
238
325
  )
239
326
 
240
327
  # --- Linear solve (J_free * du = -R_free) ---
328
+ t_rhs0 = time.perf_counter()
241
329
  rhs = jnp.asarray(-R_free, dtype=u.dtype)
330
+ rhs_norm = jnp.linalg.norm(rhs)
331
+ rhs_norm_f = float(jax.block_until_ready(rhs_norm))
332
+ rhs_dt = time.perf_counter() - t_rhs0
333
+ if rhs_norm_f <= atol:
334
+ print(
335
+ f"[newton] k={k:02d} CONVERGED rhs<=atol ({rhs_norm_f:.3e} <= {atol:.3e})",
336
+ flush=True,
337
+ )
338
+ return expand_full(u), SolverResult(
339
+ converged=True,
340
+ iters=k,
341
+ stop_reason="rhs_atol",
342
+ nan_detected=False,
343
+ )
242
344
 
243
345
  # Separate preconditioner build time from linear solve time.
244
346
  t_pre0 = time.perf_counter()
@@ -247,7 +349,54 @@ def newton_solve(
247
349
  linear_residual = None
248
350
  lin_iters = None
249
351
 
250
- if linear_solver in ("cg", "cg_jax", "cg_custom"):
352
+ linearize_dt = 0.0
353
+ if use_matfree:
354
+ if linear_preconditioner in ("jacobi", "block_jacobi"):
355
+ raise ValueError("cg_matfree does not support jacobi preconditioners")
356
+ if linear_preconditioner == "diag0":
357
+ cg_precon = matfree_precon
358
+ elif linear_preconditioner is not None and not callable(linear_preconditioner):
359
+ raise ValueError("cg_matfree preconditioner must be callable or None")
360
+ pre_dt = 0.0
361
+ if linear_preconditioner not in ("diag0", None):
362
+ cg_precon = linear_preconditioner
363
+ print(f"[linear] k={k:02d} {linear_solver}: linearize...", flush=True)
364
+ t_lin0 = time.perf_counter()
365
+ if matfree_mode == "linearize":
366
+ _res, lin_fun = jax.linearize(residual_free, u)
367
+ mv = lambda v: lin_fun(v)
368
+ else:
369
+ mv = lambda v: jax.jvp(residual_free, (u,), (v,))[1]
370
+ linearize_dt = time.perf_counter() - t_lin0
371
+ t_mv0 = 0.0
372
+ mv0_norm_f = None
373
+ t_mv0_0 = time.perf_counter()
374
+ mv0 = mv(rhs)
375
+ mv0_norm = jnp.linalg.norm(mv0)
376
+ mv0_norm_f = float(jax.block_until_ready(mv0_norm))
377
+ t_mv0 = time.perf_counter() - t_mv0_0
378
+ print(
379
+ f"[linear] k={k:02d} {linear_solver}: rhs_dt={rhs_dt:.3f}s "
380
+ f"mv0_dt={t_mv0:.3f}s ||b||={rhs_norm_f:.3e} ||Jb||={mv0_norm_f:.3e}",
381
+ flush=True,
382
+ )
383
+ cg_solver = cg_solve
384
+ print(f"[linear] k={k:02d} {linear_solver}: solve...", flush=True)
385
+ t_cg0 = time.perf_counter()
386
+ du_free, lin_info = cg_solver(
387
+ mv,
388
+ rhs,
389
+ tol=eff_linear_tol,
390
+ maxiter=linear_maxiter,
391
+ preconditioner=cg_precon,
392
+ )
393
+ du_free = jax.block_until_ready(du_free)
394
+ lin_dt = time.perf_counter() - t_cg0
395
+ linear_residual = lin_info.get("residual_norm")
396
+ linear_converged = bool(lin_info.get("converged", True))
397
+ lin_iters = lin_info.get("iters", None)
398
+
399
+ elif linear_solver in ("cg", "cg_jax", "cg_custom"):
251
400
  # Preconditioner build
252
401
  if linear_preconditioner == "jacobi":
253
402
  print(f"[newton] k={k:02d} PRECOND jacobi: diag...", flush=True)
@@ -307,11 +456,18 @@ def newton_solve(
307
456
  raise ValueError(f"Unknown linear solver: {linear_solver}")
308
457
 
309
458
  lr = float(linear_residual) if linear_residual is not None else float("nan")
310
- print(
311
- f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
312
- f"pre_dt={pre_dt:.3f}s lin_dt={lin_dt:.3f}s",
313
- flush=True,
314
- )
459
+ if use_matfree:
460
+ print(
461
+ f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
462
+ f"linz_dt={linearize_dt:.3f}s cg_dt={lin_dt:.3f}s",
463
+ flush=True,
464
+ )
465
+ else:
466
+ print(
467
+ f"[linear] k={k:02d} done iters={lin_iters} conv={linear_converged} lin_res={lr:.3e} "
468
+ f"pre_dt={pre_dt:.3f}s lin_dt={lin_dt:.3f}s",
469
+ flush=True,
470
+ )
315
471
 
316
472
  # --- Trial update & residual evaluation ---
317
473
  # Start with alpha=1 and eval_residual (if heavy, assemble_R is heavy/compiled).
@@ -362,20 +518,21 @@ def newton_solve(
362
518
 
363
519
  # callback
364
520
  if callback is not None:
365
- callback(
366
- {
367
- "iter": k + 1,
368
- "res_inf": res_trial_inf,
369
- "res_two": res_trial_two,
370
- "rel_residual": res_trial_inf / res0_inf,
371
- "alpha": alpha,
372
- "step_norm": step_norm,
373
- "linear_iters": lin_info.get("iters"),
374
- "linear_converged": linear_converged,
375
- "linear_residual": lr,
376
- "nan_detected": bool(np.isnan(res_trial_inf)),
377
- }
378
- )
521
+ payload = {
522
+ "iter": k + 1,
523
+ "res_inf": res_trial_inf,
524
+ "res_two": res_trial_two,
525
+ "rel_residual": res_trial_inf / res0_inf,
526
+ "alpha": alpha,
527
+ "step_norm": step_norm,
528
+ "linear_iters": lin_info.get("iters"),
529
+ "linear_converged": linear_converged,
530
+ "linear_residual": lr,
531
+ "nan_detected": bool(np.isnan(res_trial_inf)),
532
+ }
533
+ if extra_metrics is not None:
534
+ payload["extra_metrics"] = extra_metrics
535
+ callback(payload)
379
536
 
380
537
  # --- Convergence check ---
381
538
  if res_trial_inf < crit and linear_converged and not np.isnan(res_trial_inf):
@@ -398,3 +555,23 @@ def newton_solve(
398
555
  stop_reason="converged",
399
556
  nan_detected=bool(np.isnan(res_trial_inf)),
400
557
  )
558
+
559
+ res_final_inf = float(jnp.linalg.norm(R_free, ord=jnp.inf))
560
+ res_final_two = float(jnp.linalg.norm(R_free, ord=2))
561
+ return u_full, SolverResult(
562
+ converged=False,
563
+ iters=maxiter,
564
+ residual_norm=res_final_inf,
565
+ residual0=res0_inf,
566
+ rel_residual=(res_final_inf / res0_inf if res0_inf != 0.0 else float("inf")),
567
+ line_search_steps=0,
568
+ linear_iters=lin_info.get("iters"),
569
+ linear_converged=linear_converged,
570
+ linear_residual=lr,
571
+ tol=tol,
572
+ atol=atol,
573
+ stopping_criterion=crit,
574
+ step_norm=step_norm,
575
+ stop_reason="maxiter",
576
+ nan_detected=bool(np.isnan(res_final_inf)),
577
+ )