probability-flow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. probability_flow/__init__.py +41 -0
  2. probability_flow/aspic/__init__.py +21 -0
  3. probability_flow/aspic/argument.py +182 -0
  4. probability_flow/aspic/calibrate.py +178 -0
  5. probability_flow/aspic/compile.py +175 -0
  6. probability_flow/aspic/generate.py +397 -0
  7. probability_flow/aspic/handle.py +281 -0
  8. probability_flow/aspic/visualization.py +128 -0
  9. probability_flow/core/__init__.py +29 -0
  10. probability_flow/core/_logmath.py +40 -0
  11. probability_flow/core/bp/__init__.py +5 -0
  12. probability_flow/core/bp/engine.py +198 -0
  13. probability_flow/core/bp/message.py +30 -0
  14. probability_flow/core/cpd/__init__.py +13 -0
  15. probability_flow/core/cpd/base.py +84 -0
  16. probability_flow/core/cpd/independent_evidence.py +155 -0
  17. probability_flow/core/cpd/noisy_and.py +113 -0
  18. probability_flow/core/cpd/noisy_or.py +109 -0
  19. probability_flow/core/cpd/tabular.py +111 -0
  20. probability_flow/core/exact.py +67 -0
  21. probability_flow/core/network.py +111 -0
  22. probability_flow/core/node.py +125 -0
  23. probability_flow/metrics/__init__.py +64 -0
  24. probability_flow/metrics/_util.py +42 -0
  25. probability_flow/metrics/difficulty.py +87 -0
  26. probability_flow/metrics/dseparation.py +83 -0
  27. probability_flow/metrics/loopiness.py +82 -0
  28. probability_flow/metrics/manipulability.py +207 -0
  29. probability_flow/metrics/structure.py +49 -0
  30. probability_flow/py.typed +0 -0
  31. probability_flow/visualization/__init__.py +11 -0
  32. probability_flow/visualization/image.py +402 -0
  33. probability_flow/visualization/style.py +58 -0
  34. probability_flow-0.1.0.dist-info/METADATA +304 -0
  35. probability_flow-0.1.0.dist-info/RECORD +37 -0
  36. probability_flow-0.1.0.dist-info/WHEEL +4 -0
  37. probability_flow-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,41 @@
1
+ """probability_flow: a from-scratch, modular discrete Bayesian-network library.
2
+
3
+ The design is in `docs/SPEC.md`; settled choices in `docs/DECISIONS.md` and
4
+ milestones in `docs/ROADMAP.md`. The public API currently lives in
5
+ `probability_flow.core` and is re-exported here for convenience:
6
+
7
+ from probability_flow import Node, ExactSolver
8
+ """
9
+ from importlib.metadata import PackageNotFoundError, version
10
+
11
+ from .core import (
12
+ CPD,
13
+ BayesianNetwork,
14
+ CompiledCPD,
15
+ ExactSolver,
16
+ IndependentEvidenceCPD,
17
+ LoopySolver,
18
+ Node,
19
+ NoisyAndCPD,
20
+ NoisyOrCPD,
21
+ TabularCPD,
22
+ )
23
+
24
def _installed_version(distribution: str = "probability-flow") -> str:
    """Return the installed version of `distribution`.

    Falls back to the placeholder `"0.0.0+unknown"` when the distribution
    metadata cannot be found, i.e. when running from a source checkout that was
    never installed.
    """
    try:
        return version(distribution)
    except PackageNotFoundError:
        return "0.0.0+unknown"


__version__ = _installed_version()

# Public API: the version string plus the names re-exported from `.core`.
__all__ = [
    "__version__",
    # graph container, nodes and solvers
    "Node",
    "BayesianNetwork",
    "CompiledCPD",
    "ExactSolver",
    "LoopySolver",
    # conditional probability distributions
    "CPD",
    "TabularCPD",
    "IndependentEvidenceCPD",
    "NoisyOrCPD",
    "NoisyAndCPD",
]
@@ -0,0 +1,21 @@
1
+ """probability_flow.aspic: the ASPIC argument-compilation layer.
2
+
3
+ The first of several planned domain wrappers (legal, medical, AI-safety) over the
4
+ pure-BN `core`. Build an argument out of premises, conclusions, and attacks, then
5
+ `compile()` the target to an ordinary `BayesianNetwork`. See `docs/aspic.md`.
6
+
7
+ from probability_flow.aspic import Premise, Axiom, Conclusion
8
+ """
9
+ from .argument import ArgumentWarning, Axiom, Conclusion, Premise
10
+ from .generate import (
11
+ ArgumentGenerator,
12
+ DifficultyTargets,
13
+ StructuralParams,
14
+ generate,
15
+ )
16
+ from .handle import Argument
17
+
18
__all__ = [
    # authoring constructors and the soft-validation warning category
    "Premise",
    "Axiom",
    "Conclusion",
    "ArgumentWarning",
    # handle bundling an argument for serialization and metrics
    "Argument",
    # argument generation
    "ArgumentGenerator",
    "StructuralParams",
    "DifficultyTargets",
    "generate",
]
@@ -0,0 +1,182 @@
1
+ """The ASPIC authoring layer: free constructors over the core BN.
2
+
3
+ `Premise`, `Axiom`, and `Conclusion` are core `Node`s (so a compiled argument is
4
+ queried by passing the object straight to a solver, exactly as in core). They are
5
+ one underlying node type; the distinction is intention plus light validation. The
6
+ argumentative structure is *declared* with fluent methods on the downstream node
7
+ and only *lowered* to core CPDs by `compile()` (see `compile.py`). Keeping the
8
+ declaration separate from the core CPD makes `compile()` a pure function of the
9
+ graph, so it is repeatable.
10
+
11
+ Every argumentative edge is a method on its downstream node: `support` / `rebut` /
12
+ `strict` add an upstream argument to a conclusion, `undermine` attacks a premise,
13
+ `undercut` attacks an inference. The whole graph is therefore reachable by walking
14
+ inputs from the root target, the traversal `compile()` performs. See
15
+ `docs/aspic.md`.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import warnings
20
+ from typing import TYPE_CHECKING
21
+
22
+ from ..core import Node
23
+
24
+ if TYPE_CHECKING:
25
+ from ..core import BayesianNetwork
26
+ from .handle import Argument
27
+
28
+
29
class ArgumentWarning(UserWarning):
    """Warning category for soft validation of an argument.

    Issued for a conclusion with no incoming argument, a rebut aimed at a
    premise, and an undercut whose undercutter shares premises with the
    inference it attacks. None of these stop compilation (matching core's
    permissiveness); each only flags a likely modelling slip.
    """
33
+
34
+
35
class _Claim(Node):
    """A proposition in the argument: a core `Node` plus declared, not-yet-lowered
    argumentative edges. Not constructed directly; use `Premise` / `Axiom` /
    `Conclusion`.

    Declared structure lives in three lists that `compile()` later lowers:
    `_edges` holds defeasible `(source, lr, kind)` triples (`kind` is
    `"support"`, `"rebut"` or `"undermine"`), `_strict` holds strict sources,
    and `_undercuts` holds `(attacked source, undercutter)` pairs.
    """

    role = "claim"

    def __init__(self, name: str, prior: float = 0.5):
        super().__init__(name, prior=prior)
        self._edges: list[tuple["Node", float, str]] = []  # (src, lr, kind)
        self._strict: list["Node"] = []  # strict sources
        self._undercuts: list[tuple["Node", "Node"]] = []  # (attacked source, undercutter)
        # Set to True by `Axiom`; `undermine` and `rebut` refuse to attack such a claim.
        self._no_undermine = False
        # opaque, serialized but uninterpreted by the library (see docs/aspic.md):
        self.desc: str | None = None  # a longer description
        self.revealable_to_judge = None  # benchmark metadata, semantics TBD

    def _require_new_edge(self, src: "Node") -> None:
        """At most one argumentative edge joins a pair of claims. This is what lets
        an undercut be addressed by its endpoints and keeps serialized edge ids
        (`source->target`) unique.

        Raises:
            ValueError: `src` already feeds this claim via a defeasible or strict edge.
        """
        if any(s is src for s, _lr, _kind in self._edges) or any(
            s is src for s in self._strict
        ):
            raise ValueError(
                f"{self.name!r} already has an edge from {src.name!r}; at most one "
                "argumentative edge may join a pair of claims"
            )

    # --- defeasible and strict support (methods on the downstream conclusion) --

    def support(self, src: "Node", lr: float) -> "Node":
        """Add a defeasible argument *for* this conclusion (`lr > 1`). Returns
        `src`, so an inline source can be built further upstream, as with core
        `add_input`.

        Raises:
            ValueError: `lr` is not greater than 1, or `src` already has an edge here.
        """
        if not lr > 1:  # written this way so a NaN lr is rejected too
            raise ValueError(
                f"support lr must be > 1 (got {lr}); use rebut for an argument against"
            )
        self._require_new_edge(src)
        self._edges.append((src, float(lr), "support"))
        return src

    def rebut(self, src: "Node", lr: float) -> "Node":
        """Add a defeasible argument *against* this conclusion (`0 < lr < 1`).

        Returns `src`. Rebutting an ordinary premise is allowed but emits an
        `ArgumentWarning`. An axiom cannot be rebutted: `undermine` already
        refuses to attack an axiom, and a rebut is lowered to exactly the same
        `lr < 1` input, so allowing it here would bypass that guard.

        Raises:
            ValueError: `lr` is outside `(0, 1)`, this claim is an axiom, or
                `src` already has an edge here.
        """
        if not 0 < lr < 1:
            raise ValueError(
                f"rebut lr must be in (0, 1) (got {lr}); use support for an argument for"
            )
        if self._no_undermine:
            raise ValueError(
                f"cannot rebut {self.name!r}: an axiom is held with certainty and "
                "cannot be attacked"
            )
        if self.role == "premise":
            warnings.warn(
                f"rebutting {self.name!r}, a premise: a rebutted node is really a "
                "conclusion. Either model it as one, or use undermine to attack a premise.",
                ArgumentWarning,
                stacklevel=2,
            )
        self._require_new_edge(src)
        self._edges.append((src, float(lr), "rebut"))
        return src

    def strict(self, src: "Node") -> "Node":
        """Add a strict argument: `src` true forces this conclusion (a rebut cannot
        pull it down). Lowered by parent-divorcing. Returns `src`."""
        self._require_new_edge(src)
        self._strict.append(src)
        return src

    # --- attacks (also methods on the downstream node) ------------------------

    def undermine(self, by: "Node", lr: float) -> "Node":
        """Attack this premise with `by` (`0 < lr < 1`). Mechanically a rebut into
        the attacked node, named distinctly because it is the ASPIC-correct verb
        for attacking a premise. An axiom cannot be undermined. Returns `by`.

        Raises:
            ValueError: this claim is an axiom, `lr` is outside `(0, 1)`, or `by`
                already has an edge here.
        """
        if self._no_undermine:
            raise ValueError(
                f"cannot undermine {self.name!r}: an axiom has no defeasible premise to attack"
            )
        if not 0 < lr < 1:
            raise ValueError(f"undermine lr must be in (0, 1) (got {lr})")
        self._require_new_edge(by)
        self._edges.append((by, float(lr), "undermine"))
        return by

    def undercut(self, source: "Node", by: "Node") -> "Node":
        """Attack the `source -> self` inference with `by`: where `source`
        supports (or strictly entails) this conclusion, `by` asserts the inference
        does not apply. Full neutralization (`tau = 1`) is hard-coded. Returns the
        undercutter `by`. Whether `source` really feeds this claim is checked at
        compile time, not here."""
        self._undercuts.append((source, by))
        return by

    def compile(self) -> "BayesianNetwork":
        """Lower this argument (as the target) to a core `BayesianNetwork`."""
        from .compile import compile_argument

        return compile_argument(self)

    def assemble(self) -> "Argument":
        """Bundle this argument (as the target) into an `Argument` handle: the home
        for serialization (`to_json` / `save`), the compiled network (`.bn`), and
        the argument-level metrics. Authoring is unchanged; this just hands back an
        object to serialize and measure from."""
        from .handle import Argument

        return Argument(self)

    @property
    def bn(self) -> "BayesianNetwork":
        """The compiled network for this argument (a fresh compile each access)."""
        return self.compile()

    def render(self, **kwargs):
        """Draw the argument at the argument level (premises, conclusions, and
        attacks), hiding the lowered splice machinery. See `aspic.visualization`.
        For the full compiled network instead, use `self.bn.render()`."""
        from .visualization import render_argument

        return render_argument(self, **kwargs)
153
+
154
+
155
class Premise(_Claim):
    """A proposition asserted outright rather than argued for.

    Intended as a leaf: it carries a credence (its `prior`) and no incoming
    argument. It can be attacked with `undermine`.
    """

    role = "premise"

    def __init__(self, name: str, prior: float = 0.5):
        # `prior` is the premise's credence; everything else is `_Claim`'s setup.
        super().__init__(name=name, prior=prior)
163
+
164
+
165
class Axiom(Premise):
    """A premise held with certainty: its prior is fixed at `1.0`, and
    `undermine` refuses to attack it."""

    role = "axiom"

    def __init__(self, name: str):
        super().__init__(name=name, prior=1.0)
        # `_Claim.undermine` checks this flag and raises instead of attacking.
        self._no_undermine = True
173
+
174
+
175
class Conclusion(_Claim):
    """A proposition that is argued for.

    It starts from a prior (default `0.5`) and accumulates `support`, `rebut`
    and `strict` arguments from upstream claims.
    """

    role = "conclusion"

    def __init__(self, name: str, prior: float = 0.5):
        super().__init__(name=name, prior=prior)
@@ -0,0 +1,178 @@
1
+ """JAX calibration of a built argument (optional — needs the `[jax]` extra).
2
+
3
+ The generator's structure choices (depth, branching, edge kinds) are discrete and
4
+ non-differentiable, so the core stays plain numpy. But for a *fixed* structure the
5
+ continuous parameters — node priors and edge likelihood ratios — are smooth, and a
6
+ generated graph is a polytree, so its root posterior is the topo-ordered forward
7
+ pass of the same message math. This module expresses that forward pass in JAX, so:
8
+
9
+ - `sensitivities(arg)` returns `d p_root / d parameter` for every prior and LR (via
10
+ `jax.grad`) — a per-parameter importance signal; and
11
+ - `calibrate_posterior(arg, target)` solves the root's prior and edge LRs for a
12
+ target posterior by projected gradient descent and writes them back. This is the
13
+ multi-parameter upgrade of the dependency-free bisection in `generate` (which
14
+ tunes the root prior alone); it keeps every other subtree's marginals fixed.
15
+
16
+ Defined for the CPD shapes the ASPIC compiler emits on a polytree
17
+ (`IndependentEvidence`, the strict noisy-OR, the undercut AND-NOT splice). A
18
+ non-polytree (shared / non-disjoint structure) is out of scope — the forward
19
+ marginal pass is only exact on a tree.
20
+ """
21
+ from __future__ import annotations
22
+
23
+ import math
24
+ from typing import TYPE_CHECKING
25
+
26
+ if TYPE_CHECKING:
27
+ from .handle import Argument
28
+
29
+
30
+ def _require_jax():
31
+ try:
32
+ import jax
33
+ import jax.numpy as jnp
34
+ except ImportError as exc: # pragma: no cover - exercised only without the extra
35
+ raise ImportError(
36
+ "calibrate needs JAX; install the optional extra: "
37
+ "pip install 'probability-flow[jax]'"
38
+ ) from exc
39
+ return jax, jnp
40
+
41
+
42
+ def _logit(p: float) -> float:
43
+ p = min(1 - 1e-6, max(1e-6, p))
44
+ return math.log(p / (1 - p))
45
+
46
+
47
def _is_and_not_splice(cpd) -> bool:
    """Recognise the undercut splice CPD (`X = E AND NOT U`).

    True only for a two-input `TabularCPD` whose `P(X = 1 | e, u)` table, read
    from `log_table[..., 1]`, is `[[0, 0], [1, 0]]`, i.e. only `E = 1, U = 0`
    yields `X = 1`.
    """
    from ..core import TabularCPD

    if isinstance(cpd, TabularCPD) and len(cpd.inputs) == 2:
        import numpy as np

        fire_prob = np.exp(cpd.log_table[..., 1])
        return np.allclose(fire_prob, np.array([[0.0, 0.0], [1.0, 0.0]]))
    return False
56
+
57
+
58
+ def _combos(k: int):
59
+ import numpy as np
60
+
61
+ return np.array([[(c >> i) & 1 for i in range(k)] for c in range(2**k)], dtype=float)
62
+
63
+
64
def _forward(arg: "Argument"):
    """Build the polytree forward pass. Returns `(p_root, theta0, meta)` where
    `p_root(theta)` is a JAX function of the flat parameter vector, `theta0` are the
    current values (logit priors, log LRs), and `meta` labels each entry.

    Only nodes whose CPD is an `IndependentEvidenceCPD` contribute parameters: one
    logit prior each, followed by one log LR per input. The noisy-OR and the
    AND-NOT splice nodes are evaluated with their fixed parameters.

    Raises:
        NotImplementedError: (when `p_root` is called) a node's CPD is none of the
            supported shapes.
    """
    jax, jnp = _require_jax()
    from ..core import IndependentEvidenceCPD, NoisyOrCPD

    bn = arg.bn
    nodes = list(bn.nodes) # topological (inputs before node)
    cpd_of = {n: bn.compiled_cpd(n).cpd for n in nodes}
    ins_of = {n: tuple(bn.compiled_cpd(n).inputs) for n in nodes}

    # Lay out the flat parameter vector: for each IndependentEvidence node, its
    # logit prior, then the log LR of each of its inputs. The *_idx dicts map a
    # node (or a (node, input) pair) to its slot in theta.
    prior_idx, lr_idx, theta0, meta = {}, {}, [], []
    for n in nodes:
        cpd = cpd_of[n]
        if isinstance(cpd, IndependentEvidenceCPD):
            prior_idx[n] = len(theta0)
            theta0.append(_logit(n.prior))
            meta.append(("prior", n.name))
            for inp, lr in zip(cpd.inputs, cpd.lrs):
                lr_idx[(n, inp)] = len(theta0)
                theta0.append(math.log(lr))
                meta.append(("lr", (inp.name, n.name)))
    # Precompute the 0/1 joint-input table once per node that has inputs; it does
    # not depend on theta.
    combos = {n: jnp.asarray(_combos(len(ins_of[n]))) for n in nodes
              if isinstance(cpd_of[n], IndependentEvidenceCPD) and ins_of[n]}

    def p_root(theta):
        # m[n] holds the marginal P(n = 1). Inputs are treated as independent
        # (exact only on a polytree), and `nodes` is topological, so every
        # input's marginal is already in m when its child is reached.
        m = {}
        for n in nodes:
            cpd, ins = cpd_of[n], ins_of[n]
            if isinstance(cpd, IndependentEvidenceCPD):
                b = theta[prior_idx[n]]
                if not ins:
                    m[n] = jax.nn.sigmoid(b)
                else:
                    # NOTE(review): lr_idx was keyed with `cpd.inputs`; this lookup
                    # uses the compiled inputs `ins` and assumes they are the same
                    # node objects -- confirm.
                    w = jnp.stack([theta[lr_idx[(n, i)]] for i in ins])
                    pin = jnp.stack([m[i] for i in ins])
                    C = combos[n]
                    # Probability of each joint input assignment (row of C)
                    # under independent input marginals.
                    wts = jnp.prod(jnp.where(C == 1, pin, 1 - pin), axis=1)
                    # P(n=1 | assignment) = sigmoid(logit prior + sum of active
                    # log LRs), averaged over the assignments.
                    m[n] = jnp.sum(jax.nn.sigmoid(b + C @ w) * wts)
            elif isinstance(cpd, NoisyOrCPD):
                # NOTE(review): `cpd.activations` is paired with `ins` (the
                # compiled input order); this assumes that order matches
                # `cpd.inputs` -- confirm.
                a = jnp.asarray(cpd.activations)
                pin = jnp.stack([m[i] for i in ins])
                m[n] = 1 - (1 - cpd.leak) * jnp.prod(1 - a * pin)
            elif _is_and_not_splice(cpd):
                # X = E AND NOT U; ins is assumed to be (E, U) in that order.
                e, u = ins
                m[n] = m[e] * (1 - m[u])
            else:
                raise NotImplementedError(
                    f"calibrate does not handle {type(cpd).__name__} on {n.name!r}"
                )
        return m[arg.target]

    return p_root, jnp.asarray(theta0), meta
118
+
119
+
120
def sensitivities(arg: "Argument") -> dict:
    """Gradient of the root posterior with respect to every prior and LR.

    Keys are `("prior", node_name)` and `("lr", (source_name, target_name))`;
    values are `d p_root / d parameter` at the current parameter values, taken in
    the unconstrained space the forward pass uses (logit prior, log LR). A large
    magnitude marks a parameter the verdict is sensitive to, which makes it a
    manipulability / importance signal.
    """
    jax, _jnp = _require_jax()
    p_root, theta0, meta = _forward(arg)
    gradient = jax.grad(p_root)(theta0)
    return dict(zip(meta, map(float, gradient)))
131
+
132
+
133
def calibrate_posterior(arg: "Argument", target: float, *,
                        steps: int = 500, step_size: float = 0.3) -> float:
    """Solve the root's prior and edge LRs so its posterior hits `target`, by
    projected gradient descent, then write the result back onto the argument.

    The feasible set is: root prior in `[0.02, 0.98]`; each edge LR on its
    original side of 1 with `log 1.01 <= |log lr| <= log 50`. Every other
    subtree's marginals are held fixed (the inputs' current marginals come from
    `LoopySolver`). Use the bisection in `generate` for the dependency-free,
    single-knob version.

    Args:
        arg: the argument whose target (root) is calibrated.
        target: desired root posterior.
        steps: number of gradient steps.
        step_size: gradient step size.

    Returns:
        The achieved root posterior after the final projected step.

    Raises:
        ValueError: `target` is not a probability in `[0, 1]`.
        NotImplementedError: the root's CPD is not an `IndependentEvidenceCPD`.
    """
    if not 0.0 <= target <= 1.0:  # also rejects NaN
        raise ValueError(f"target must be a probability in [0, 1] (got {target})")
    jax, jnp = _require_jax()
    from ..core import IndependentEvidenceCPD, LoopySolver

    bn = arg.bn
    root = arg.target
    cpd = bn.compiled_cpd(root).cpd
    if not isinstance(cpd, IndependentEvidenceCPD):
        raise NotImplementedError("calibrate_posterior targets a defeasible root")
    ins = list(cpd.inputs)
    solver = LoopySolver(bn)
    pin = jnp.asarray([solver.prob(i, 1) for i in ins])  # fixed input marginals
    k = len(ins)
    C = jnp.asarray(_combos(k)) if k else None
    w0 = jnp.asarray([math.log(lr) for lr in cpd.lrs])
    signs = jnp.sign(w0)  # +1 for support (lr > 1), -1 for rebut/undermine (lr < 1)
    lo_b, hi_b = _logit(0.02), _logit(0.98)
    lo_w, hi_w = math.log(1.01), math.log(50.0)

    def predict(theta):
        # Root posterior for theta = [logit prior, log LRs...], averaging over all
        # joint input assignments weighted by the fixed independent marginals.
        b, w = theta[0], theta[1:]
        if k == 0:
            return jax.nn.sigmoid(b)
        wts = jnp.prod(jnp.where(C == 1, pin, 1 - pin), axis=1)
        return jnp.sum(jax.nn.sigmoid(b + C @ w) * wts)

    grad = jax.grad(lambda th: (predict(th) - target) ** 2)
    theta = jnp.concatenate([jnp.asarray([_logit(root.prior)]), w0])
    for _ in range(steps):
        theta = theta - step_size * grad(theta)
        b = jnp.clip(theta[0], lo_b, hi_b)
        # Project each log-LR onto [lo_w, hi_w] measured along its ORIGINAL sign.
        # `signs * theta` is the signed magnitude in the allowed direction, so a
        # step that overshoots past zero lands on the boundary lo_w (the nearest
        # feasible point) instead of being mirrored to an equally large value.
        w = signs * jnp.clip(signs * theta[1:], lo_w, hi_w)
        theta = jnp.concatenate([jnp.asarray([b]), w])

    root.prior = float(jax.nn.sigmoid(theta[0]))
    new_lrs = [float(math.exp(x)) for x in theta[1:]]
    for j, (src, _lr, kind) in enumerate(root._edges):  # aligned with cpd.inputs
        root._edges[j] = (src, new_lrs[j], kind)
    return float(predict(theta))
@@ -0,0 +1,175 @@
1
+ """Lowering an ASPIC argument to a core `BayesianNetwork`. See `docs/aspic.md`.
2
+
3
+ `compile_argument(target)` collects everything reachable from the target, runs
4
+ validation (raising on sign / axiom errors, warning on soft slips), lowers each
5
+ claim to core CPDs, then defers to the core compiler. The lowerings:
6
+
7
+ - premise / conclusion with only defeasible edges -> `IndependentEvidenceCPD`
8
+ (a premise is just the zero-support case; an undermined premise gains an
9
+ `lr < 1` input).
10
+ - a conclusion with strict edges -> parent-divorcing: a hidden `D` carries the
11
+ defeasible part (or just the prior), and the conclusion becomes a leak-0
12
+ `NoisyOr` over its strict inputs and `D`, each at activation 1.
13
+ - an undercut on the `source -> C` edge -> a spliced `X = source AND NOT U`
14
+ (`tau = 1`) replacing `source` on that edge, with multiple undercutters first
15
+ combined by an upstream leak-0 `NoisyOr` into one `U`.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import warnings
20
+
21
+ from ..core import BayesianNetwork, IndependentEvidenceCPD, Node, TabularCPD
22
+ from .argument import ArgumentWarning, _Claim
23
+
24
+
25
def compile_argument(target: "_Claim") -> "BayesianNetwork":
    """Lower the argument rooted at `target` to a core `BayesianNetwork`.

    Collects every node reachable from `target` through declared argument edges,
    validates the claims (raising `ValueError` on an undercut of an undeclared
    inference, warning `ArgumentWarning` on soft slips), lowers each `_Claim` to
    core CPDs in place, and finally builds the network from the target with
    `BayesianNetwork.from_nodes`.
    """
    reachable = _reachable(target)
    _validate(reachable)
    for claim in (node for node in reachable if isinstance(node, _Claim)):
        _lower(claim)
    return BayesianNetwork.from_nodes([target])
32
+
33
+
34
+ # --- traversal -------------------------------------------------------------
35
+
36
+ def _upstream(c: "Node") -> list["Node"]:
37
+ """The declared upstream sources of a claim: edge and strict sources plus any
38
+ undercutters. Plain (non-claim) nodes are treated as leaves."""
39
+ out: list["Node"] = []
40
+ out.extend(src for src, _lr, _kind in getattr(c, "_edges", []))
41
+ out.extend(getattr(c, "_strict", []))
42
+ out.extend(by for _source, by in getattr(c, "_undercuts", []))
43
+ return out
44
+
45
+
46
def _reachable(target: "Node") -> list["Node"]:
    """Return every node reachable from `target` through declared argument edges.

    The walk is an iterative depth-first traversal from `target`, so the order is
    deterministic, with each node listed once (identity-based) and `target`
    first. It is NOT a dependency order; the core compiler re-orders the nodes
    itself.
    """
    visited: set[int] = set()
    ordered: list["Node"] = []
    pending = [target]
    while pending:
        node = pending.pop()
        if id(node) not in visited:
            visited.add(id(node))
            ordered.append(node)  # also keeps `node` alive, so its id stays unique
            pending.extend(_upstream(node))
    return ordered
61
+
62
+
63
def _leaf_premises(c: "Node") -> set["Node"]:
    """Return the leaves (nodes with no declared upstream source) above `c`.

    If `c` itself has no upstream source it is its own only leaf. Used by the
    undercut disjointness check in validation.
    """
    leaves: set["Node"] = set()
    visited: set[int] = set()
    frontier = [c]
    while frontier:
        node = frontier.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        parents = _upstream(node)
        if parents:
            frontier.extend(parents)
        else:
            leaves.add(node)
    return leaves
80
+
81
+
82
+ # --- validation ------------------------------------------------------------
83
+
84
def _validate(claims: list["Node"]) -> None:
    """Check the declared argument structure before lowering.

    Plain (non-claim) nodes are skipped. For each claim:

    - a conclusion with no defeasible or strict incoming edge triggers an
      `ArgumentWarning` (a leaf should be a `Premise`);
    - an undercut whose attacked source is neither an edge nor a strict source of
      the claim raises `ValueError`;
    - an undercut whose undercutter shares leaf premises with the attacked source
      triggers an `ArgumentWarning` (the splice is exact only for disjoint
      branches).

    Raises:
        ValueError: an undercut targets an inference that was never declared.
    """
    for claim in (node for node in claims if isinstance(node, _Claim)):
        if claim.role == "conclusion" and not (claim._edges or claim._strict):
            warnings.warn(
                f"{claim.name!r} is a conclusion with no incoming argument; a leaf "
                "should be a Premise.",
                ArgumentWarning,
                stacklevel=2,
            )

        # ids of every node with an inference (defeasible or strict) into claim
        inference_sources = {id(src) for src, _lr, _kind in claim._edges}
        inference_sources.update(id(src) for src in claim._strict)

        for source, by in claim._undercuts:
            if id(source) not in inference_sources:
                raise ValueError(
                    f"undercut on {claim.name!r}: {source.name!r} is not a source of any "
                    "edge into it; undercut attacks an existing inference"
                )
            shared = _leaf_premises(source) & _leaf_premises(by)
            if not shared:
                continue
            names = ", ".join(sorted(p.name for p in shared))
            warnings.warn(
                f"undercut on the {source.name!r} -> {claim.name!r} edge shares "
                f"premise(s) {{{names}}} with its undercutter {by.name!r}; the "
                "splice is exact only for disjoint branches.",
                ArgumentWarning,
                stacklevel=2,
            )
115
+
116
+
117
+ # --- lowering --------------------------------------------------------------
118
+
119
def _and_not_cpd(x: "Node", e: "Node", u: "Node") -> "TabularCPD":
    """Build the `tau = 1` undercut splice: `X = 1` iff `E = 1` and `U = 0`.

    The table is indexed `table[e][u] = [P(X=0 | e, u), P(X=1 | e, u)]`.
    """
    def row(e_state: int, u_state: int) -> list[float]:
        # The inference survives only when its source holds and nothing undercuts it.
        survives = e_state == 1 and u_state == 0
        return [0.0, 1.0] if survives else [1.0, 0.0]

    table = [[row(e_state, u_state) for u_state in (0, 1)] for e_state in (0, 1)]
    return TabularCPD(x, [e, u], table=table)
127
+
128
+
129
def _splice_for(c: "_Claim") -> dict[int, "Node"]:
    """Build the undercut splice node for each attacked source on `c`.

    Returns a dict keyed by the attacked source's `id`, mapping to a fresh node
    `X = source AND NOT U`. When one edge has several distinct undercutters they
    are first OR-combined by a hidden leak-0 noisy-OR node named
    `"<claim>/U[<source>]"`; a single undercutter is used directly as `U`.
    An undercutter registered more than once on the same edge is counted once:
    OR is idempotent (u OR u == u), so a repeat would only give the combining
    node a duplicated parent.
    """
    grouped: dict[int, tuple["Node", list["Node"]]] = {}
    for source, u in c._undercuts:
        _source, undercutters = grouped.setdefault(id(source), (source, []))
        if not any(existing is u for existing in undercutters):
            undercutters.append(u)

    splice: dict[int, "Node"] = {}
    for sid, (source, undercutters) in grouped.items():
        if len(undercutters) == 1:
            u_node = undercutters[0]
        else:
            u_node = Node(f"{c.name}/U[{source.name}]")
            u_node.noisy_or(leak=0.0)
            for u in undercutters:
                u_node.add_input(u, activation=1.0)
        x = Node(f"{source.name} (spliced)")
        x.set_cpd(_and_not_cpd(x, source, u_node))
        splice[sid] = x
    return splice
149
+
150
+
151
def _lower(c: "_Claim") -> None:
    """Set `c`'s core CPD (and any hidden helper nodes) from its declared edges.

    The CPD is reset first, so lowering the same claim again rebuilds it from
    scratch. Any source under an undercut is first swapped for its splice node.
    Then:

    - without strict edges, each defeasible edge becomes an `lr` input of the
      claim's `IndependentEvidenceCPD`;
    - with strict edges, the claim is parent-divorced: a hidden
      `"<name>/defeasible"` node carries the claim's prior plus the defeasible
      inputs, and the claim becomes a leak-0 noisy-OR over the strict sources and
      that hidden node, each at activation 1.
    """
    c._cpd = IndependentEvidenceCPD(c)  # reset for a clean, repeatable lowering

    spliced = _splice_for(c)
    defeasible = [(spliced.get(id(src), src), lr) for src, lr, _kind in c._edges]
    strict = [spliced.get(id(src), src) for src in c._strict]

    if not strict:
        for src, lr in defeasible:
            c.add_input(src, lr=lr)
        return

    # parent-divorcing: the hidden node holds the defeasible part (or just the prior)
    hidden = Node(f"{c.name}/defeasible", prior=c.prior)
    for src, lr in defeasible:
        hidden.add_input(src, lr=lr)
    c.noisy_or(leak=0.0)
    for src in strict:
        c.add_input(src, activation=1.0)
    c.add_input(hidden, activation=1.0)