probability-flow 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- probability_flow/__init__.py +41 -0
- probability_flow/aspic/__init__.py +21 -0
- probability_flow/aspic/argument.py +182 -0
- probability_flow/aspic/calibrate.py +178 -0
- probability_flow/aspic/compile.py +175 -0
- probability_flow/aspic/generate.py +397 -0
- probability_flow/aspic/handle.py +281 -0
- probability_flow/aspic/visualization.py +128 -0
- probability_flow/core/__init__.py +29 -0
- probability_flow/core/_logmath.py +40 -0
- probability_flow/core/bp/__init__.py +5 -0
- probability_flow/core/bp/engine.py +198 -0
- probability_flow/core/bp/message.py +30 -0
- probability_flow/core/cpd/__init__.py +13 -0
- probability_flow/core/cpd/base.py +84 -0
- probability_flow/core/cpd/independent_evidence.py +155 -0
- probability_flow/core/cpd/noisy_and.py +113 -0
- probability_flow/core/cpd/noisy_or.py +109 -0
- probability_flow/core/cpd/tabular.py +111 -0
- probability_flow/core/exact.py +67 -0
- probability_flow/core/network.py +111 -0
- probability_flow/core/node.py +125 -0
- probability_flow/metrics/__init__.py +64 -0
- probability_flow/metrics/_util.py +42 -0
- probability_flow/metrics/difficulty.py +87 -0
- probability_flow/metrics/dseparation.py +83 -0
- probability_flow/metrics/loopiness.py +82 -0
- probability_flow/metrics/manipulability.py +207 -0
- probability_flow/metrics/structure.py +49 -0
- probability_flow/py.typed +0 -0
- probability_flow/visualization/__init__.py +11 -0
- probability_flow/visualization/image.py +402 -0
- probability_flow/visualization/style.py +58 -0
- probability_flow-0.1.0.dist-info/METADATA +304 -0
- probability_flow-0.1.0.dist-info/RECORD +37 -0
- probability_flow-0.1.0.dist-info/WHEEL +4 -0
- probability_flow-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""probability_flow: a from-scratch, modular discrete Bayesian-network library.
|
|
2
|
+
|
|
3
|
+
The design is in `docs/SPEC.md`; settled choices in `docs/DECISIONS.md` and
|
|
4
|
+
milestones in `docs/ROADMAP.md`. The public API currently lives in
|
|
5
|
+
`probability_flow.core` and is re-exported here for convenience:
|
|
6
|
+
|
|
7
|
+
from probability_flow import Node, ExactSolver
|
|
8
|
+
"""
|
|
9
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
10
|
+
|
|
11
|
+
from .core import (
|
|
12
|
+
CPD,
|
|
13
|
+
BayesianNetwork,
|
|
14
|
+
CompiledCPD,
|
|
15
|
+
ExactSolver,
|
|
16
|
+
IndependentEvidenceCPD,
|
|
17
|
+
LoopySolver,
|
|
18
|
+
Node,
|
|
19
|
+
NoisyAndCPD,
|
|
20
|
+
NoisyOrCPD,
|
|
21
|
+
TabularCPD,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Version of the installed distribution. A plain source checkout (not
# pip-installed) has no package metadata, so it keeps the PEP 440
# local-version placeholder assigned first.
__version__ = "0.0.0+unknown"
try:
    __version__ = version("probability-flow")
except PackageNotFoundError:
    pass  # deliberate: fall back to the placeholder above
|
|
28
|
+
|
|
29
|
+
# Public API: the package version plus the core names re-exported from
# `probability_flow.core` (imported above).
__all__ = [
    "__version__",
    "Node",
    "BayesianNetwork",
    "CompiledCPD",
    "ExactSolver",
    "LoopySolver",
    "CPD",
    "TabularCPD",
    "IndependentEvidenceCPD",
    "NoisyOrCPD",
    "NoisyAndCPD",
]
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""probability_flow.aspic: the ASPIC argument-compilation layer.
|
|
2
|
+
|
|
3
|
+
The first of several planned domain wrappers (legal, medical, AI-safety) over the
|
|
4
|
+
pure-BN `core`. Build an argument out of premises, conclusions, and attacks, then
|
|
5
|
+
`compile()` the target to an ordinary `BayesianNetwork`. See `docs/aspic.md`.
|
|
6
|
+
|
|
7
|
+
from probability_flow.aspic import Premise, Axiom, Conclusion
|
|
8
|
+
"""
|
|
9
|
+
from .argument import ArgumentWarning, Axiom, Conclusion, Premise
|
|
10
|
+
from .generate import (
|
|
11
|
+
ArgumentGenerator,
|
|
12
|
+
DifficultyTargets,
|
|
13
|
+
StructuralParams,
|
|
14
|
+
generate,
|
|
15
|
+
)
|
|
16
|
+
from .handle import Argument
|
|
17
|
+
|
|
18
|
+
# Public API of the ASPIC layer: the claim constructors and soft-validation
# warning from `argument`, the `Argument` handle from `handle`, and the names
# re-exported from `generate`.
__all__ = [
    "Premise", "Axiom", "Conclusion", "ArgumentWarning", "Argument",
    "ArgumentGenerator", "StructuralParams", "DifficultyTargets", "generate",
]
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""The ASPIC authoring layer: free constructors over the core BN.
|
|
2
|
+
|
|
3
|
+
`Premise`, `Axiom`, and `Conclusion` are core `Node`s (so a compiled argument is
|
|
4
|
+
queried by passing the object straight to a solver, exactly as in core). They are
|
|
5
|
+
one underlying node type; the distinction is intention plus light validation. The
|
|
6
|
+
argumentative structure is *declared* with fluent methods on the downstream node
|
|
7
|
+
and only *lowered* to core CPDs by `compile()` (see `compile.py`). Keeping the
|
|
8
|
+
declaration separate from the core CPD makes `compile()` a pure function of the
|
|
9
|
+
graph, so it is repeatable.
|
|
10
|
+
|
|
11
|
+
Every argumentative edge is a method on its downstream node: `support` / `rebut` /
|
|
12
|
+
`strict` add an upstream argument to a conclusion, `undermine` attacks a premise,
|
|
13
|
+
`undercut` attacks an inference. The whole graph is therefore reachable by walking
|
|
14
|
+
inputs from the root target, the traversal `compile()` performs. See
|
|
15
|
+
`docs/aspic.md`.
|
|
16
|
+
"""
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import warnings
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
21
|
+
|
|
22
|
+
from ..core import Node
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from ..core import BayesianNetwork
|
|
26
|
+
from .handle import Argument
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ArgumentWarning(UserWarning):
    """Soft-validation warning for the ASPIC layer.

    Emitted for likely modelling slips rather than hard errors: a conclusion
    with no incoming argument, a rebut aimed at a premise, or an undercut whose
    branch shares premises with the inference it attacks. None of these stops
    compilation, matching core's permissiveness.
    """
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class _Claim(Node):
    """A proposition in the argument: a core `Node` plus declared, not-yet-lowered
    argumentative edges. Not constructed directly; use `Premise` / `Axiom` /
    `Conclusion`.

    The declared structure lives in three lists that `compile()` lowers:
    `_edges` holds `(source, lr, kind)` triples with `kind` one of
    `"support"` / `"rebut"` / `"undermine"`; `_strict` holds strict sources; and
    `_undercuts` holds `(attacked source, undercutter)` pairs. Every fluent
    method returns the upstream node it attached, so sources can be built inline.
    """

    role = "claim"

    def __init__(self, name: str, prior: float = 0.5):
        super().__init__(name, prior=prior)
        self._edges: list[tuple["Node", float, str]] = []  # (src, lr, kind)
        self._strict: list["Node"] = []  # strict sources
        self._undercuts: list[tuple["Node", "Node"]] = []  # (attacked source, undercutter)
        self._no_undermine = False  # set by `Axiom`; checked by `undermine`
        # opaque, serialized but uninterpreted by the library (see docs/aspic.md):
        self.desc: str | None = None  # a longer description
        self.revealable_to_judge = None  # benchmark metadata, semantics TBD

    def _require_new_edge(self, src: "Node") -> None:
        """At most one argumentative edge joins a pair of claims. This is what lets
        an undercut be addressed by its endpoints and keeps serialized edge ids
        (`source->target`) unique.

        Raises:
            ValueError: if `src` already has a defeasible or strict edge into
                this claim.
        """
        if any(s is src for s, _lr, _kind in self._edges) or any(
            s is src for s in self._strict
        ):
            raise ValueError(
                f"{self.name!r} already has an edge from {src.name!r}; at most one "
                "argumentative edge may join a pair of claims"
            )

    # --- defeasible and strict support (methods on the downstream conclusion) --

    def support(self, src: "Node", lr: float) -> "Node":
        """Add a defeasible argument *for* this conclusion (`lr > 1`). Returns
        `src`, so an inline source can be built further upstream, as with core
        `add_input`.

        Raises:
            ValueError: if `lr <= 1`, or `src` already has an edge into this claim.
        """
        if not lr > 1:
            raise ValueError(
                f"support lr must be > 1 (got {lr}); use rebut for an argument against"
            )
        self._require_new_edge(src)
        self._edges.append((src, float(lr), "support"))
        return src

    def rebut(self, src: "Node", lr: float) -> "Node":
        """Add a defeasible argument *against* this conclusion (`0 < lr < 1`).
        Returns `src`.

        Rebutting a premise (an `Axiom` included) emits an `ArgumentWarning`: a
        rebutted node is really a conclusion.

        Raises:
            ValueError: if `lr` is outside `(0, 1)`, or `src` already has an edge
                into this claim.
        """
        if not 0 < lr < 1:
            raise ValueError(
                f"rebut lr must be in (0, 1) (got {lr}); use support for an argument for"
            )
        # `Axiom` subclasses `Premise` but carries its own role string, so both
        # roles must be checked; a rebut is mechanically the same as an
        # undermine, which an axiom forbids, so it must not pass silently.
        if self.role in ("premise", "axiom"):
            warnings.warn(
                f"rebutting {self.name!r}, a premise: a rebutted node is really a "
                "conclusion. Either model it as one, or use undermine to attack a premise.",
                ArgumentWarning,
                stacklevel=2,
            )
        self._require_new_edge(src)
        self._edges.append((src, float(lr), "rebut"))
        return src

    def strict(self, src: "Node") -> "Node":
        """Add a strict argument: `src` true forces this conclusion (a rebut cannot
        pull it down). Lowered by parent-divorcing. Returns `src`.

        Raises:
            ValueError: if `src` already has an edge into this claim.
        """
        self._require_new_edge(src)
        self._strict.append(src)
        return src

    # --- attacks (also methods on the downstream node) ------------------------

    def undermine(self, by: "Node", lr: float) -> "Node":
        """Attack this premise with `by` (`0 < lr < 1`). Mechanically a rebut into
        the attacked node, named distinctly because it is the ASPIC-correct verb
        for attacking a premise. An axiom cannot be undermined. Returns `by`.

        Raises:
            ValueError: on an axiom, if `lr` is outside `(0, 1)`, or if `by`
                already has an edge into this claim.
        """
        if self._no_undermine:
            raise ValueError(
                f"cannot undermine {self.name!r}: an axiom has no defeasible premise to attack"
            )
        if not 0 < lr < 1:
            raise ValueError(f"undermine lr must be in (0, 1) (got {lr})")
        self._require_new_edge(by)
        self._edges.append((by, float(lr), "undermine"))
        return by

    def undercut(self, source: "Node", by: "Node") -> "Node":
        """Attack the `source -> self` inference with `by`: where `source`
        supports (or strictly entails) this conclusion, `by` asserts the inference
        does not apply. Full neutralization (`tau = 1`) is hard-coded. Returns the
        undercutter `by`.

        Only recorded here; that `source` really has an edge into this claim is
        checked when the argument is compiled."""
        self._undercuts.append((source, by))
        return by

    def compile(self) -> "BayesianNetwork":
        """Lower this argument (as the target) to a core `BayesianNetwork`."""
        from .compile import compile_argument

        return compile_argument(self)

    def assemble(self) -> "Argument":
        """Bundle this argument (as the target) into an `Argument` handle: the home
        for serialization (`to_json` / `save`), the compiled network (`.bn`), and
        the argument-level metrics. Authoring is unchanged; this just hands back an
        object to serialize and measure from."""
        from .handle import Argument

        return Argument(self)

    @property
    def bn(self) -> "BayesianNetwork":
        """The compiled network for this argument (a fresh compile each access)."""
        return self.compile()

    def render(self, **kwargs):
        """Draw the argument at the argument level (premises, conclusions, and
        attacks), hiding the lowered splice machinery. See `aspic.visualization`.
        For the full compiled network instead, use `self.bn.render()`."""
        from .visualization import render_argument

        return render_argument(self, **kwargs)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class Premise(_Claim):
    """An asserted proposition: a leaf with no incoming argument, whose `prior`
    is its credence. Unlike an `Axiom`, a premise may be undermined."""

    role = "premise"

    def __init__(self, name: str, prior: float = 0.5):
        super().__init__(name=name, prior=prior)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class Axiom(Premise):
    """A premise held with certainty: its prior is fixed at `1.0`, and
    `undermine` refuses to attack it."""

    role = "axiom"

    def __init__(self, name: str):
        super().__init__(name=name, prior=1.0)
        # read by `undermine`, which raises for an axiom
        self._no_undermine = True
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class Conclusion(_Claim):
    """A proposition that is argued for. It starts from `prior` (default `0.5`)
    and collects support, rebut, and strict arguments through the inherited
    claim methods."""

    role = "conclusion"

    def __init__(self, name: str, prior: float = 0.5):
        super().__init__(name=name, prior=prior)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""JAX calibration of a built argument (optional — needs the `[jax]` extra).
|
|
2
|
+
|
|
3
|
+
The generator's structure choices (depth, branching, edge kinds) are discrete and
|
|
4
|
+
non-differentiable, so the core stays plain numpy. But for a *fixed* structure the
|
|
5
|
+
continuous parameters — node priors and edge likelihood ratios — are smooth, and a
|
|
6
|
+
generated graph is a polytree, so its root posterior is the topo-ordered forward
|
|
7
|
+
pass of the same message math. This module expresses that forward pass in JAX, so:
|
|
8
|
+
|
|
9
|
+
- `sensitivities(arg)` returns `d p_root / d parameter` for every prior and LR (via
|
|
10
|
+
`jax.grad`) — a per-parameter importance signal; and
|
|
11
|
+
- `calibrate_posterior(arg, target)` solves the root's prior and edge LRs for a
|
|
12
|
+
target posterior by projected gradient descent and writes them back. This is the
|
|
13
|
+
multi-parameter upgrade of the dependency-free bisection in `generate` (which
|
|
14
|
+
tunes the root prior alone); it keeps every other subtree's marginals fixed.
|
|
15
|
+
|
|
16
|
+
Defined for the CPD shapes the ASPIC compiler emits on a polytree
|
|
17
|
+
(`IndependentEvidence`, the strict noisy-OR, the undercut AND-NOT splice). A
|
|
18
|
+
non-polytree (shared / non-disjoint structure) is out of scope — the forward
|
|
19
|
+
marginal pass is only exact on a tree.
|
|
20
|
+
"""
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import math
|
|
24
|
+
from typing import TYPE_CHECKING
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from .handle import Argument
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _require_jax():
    """Import JAX on demand and return the pair `(jax, jax.numpy)`.

    Raises:
        ImportError: carrying an install hint, when the optional `[jax]` extra
            is not installed.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as exc:  # pragma: no cover - exercised only without the extra
        hint = (
            "calibrate needs JAX; install the optional extra: "
            "pip install 'probability-flow[jax]'"
        )
        raise ImportError(hint) from exc
    else:
        return jax, jnp
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _logit(p: float) -> float:
    """Log-odds of `p`. `p` is first clamped to `[1e-6, 1 - 1e-6]`, so 0 and 1
    map to large finite values instead of infinities."""
    eps = 1e-6
    q = max(eps, min(p, 1 - eps))
    return math.log(q / (1 - q))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _is_and_not_splice(cpd) -> bool:
    """Whether `cpd` is the `tau = 1` undercut splice `X = E AND NOT U` that the
    ASPIC compiler emits: a two-input `TabularCPD` in which only `(E=1, U=0)`
    gives `X = 1`, and with probability one.

    A table of any other shape (e.g. a non-binary variable) is simply not the
    splice, so this returns `False`. Without that guard, `np.allclose` would
    raise a broadcast error from what is meant to be a predicate."""
    from ..core import TabularCPD

    if not isinstance(cpd, TabularCPD) or len(cpd.inputs) != 2:
        return False
    import numpy as np

    table = np.asarray(cpd.log_table)
    # NOTE(review): assumes log_table axes are (input 0, input 1, child state),
    # matching how `_and_not_cpd` nests its table; the splice is binary throughout.
    if table.shape != (2, 2, 2):
        return False
    p1 = np.exp(table[..., 1])  # P(X = 1 | e, u)
    return bool(np.allclose(p1, np.array([[0.0, 0.0], [1.0, 0.0]])))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _combos(k: int):
    """All `2**k` binary input configurations as a float matrix of shape
    `(2**k, k)`. Row `c` holds the bits of the integer `c`, least-significant
    first, so column `i` is the state of input `i`."""
    import numpy as np

    rows = np.arange(2**k)[:, None]
    bit_positions = np.arange(k)
    return ((rows >> bit_positions) & 1).astype(float)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _forward(arg: "Argument"):
    """Build the polytree forward pass. Returns `(p_root, theta0, meta)` where
    `p_root(theta)` is a JAX function of the flat parameter vector, `theta0` are the
    current values (logit priors, log LRs), and `meta` labels each entry.

    `p_root` visits nodes in topological order and computes each marginal
    `P(n = 1)` from its inputs' marginals, treating the inputs as independent.
    That is exact only on a polytree.
      - `IndependentEvidenceCPD`: the sum over all 0/1 input configurations of
        `sigmoid(logit prior + sum of the active inputs' log LRs)`, weighted by
        the product of input marginals;
      - `NoisyOrCPD`: `1 - (1 - leak) * prod(1 - activation * p_input)`;
      - the undercut AND-NOT splice: `p_source * (1 - p_undercutter)`.
    Evaluating `p_root` on any other CPD raises `NotImplementedError`.
    """
    jax, jnp = _require_jax()
    from ..core import IndependentEvidenceCPD, NoisyOrCPD

    bn = arg.bn
    nodes = list(bn.nodes)  # topological (inputs before node)
    cpd_of = {n: bn.compiled_cpd(n).cpd for n in nodes}
    ins_of = {n: tuple(bn.compiled_cpd(n).inputs) for n in nodes}

    # Only IndependentEvidence nodes contribute parameters: one logit prior per
    # node, followed by one log LR per input. The index maps locate each in theta.
    prior_idx, lr_idx, theta0, meta = {}, {}, [], []
    for n in nodes:
        cpd = cpd_of[n]
        if isinstance(cpd, IndependentEvidenceCPD):
            prior_idx[n] = len(theta0)
            theta0.append(_logit(n.prior))
            meta.append(("prior", n.name))
            for inp, lr in zip(cpd.inputs, cpd.lrs):
                lr_idx[(n, inp)] = len(theta0)
                theta0.append(math.log(lr))
                meta.append(("lr", (inp.name, n.name)))
    # 0/1 input-configuration matrices (one row per configuration), built once
    # outside `p_root` so they are not rebuilt on every evaluation.
    combos = {n: jnp.asarray(_combos(len(ins_of[n]))) for n in nodes
              if isinstance(cpd_of[n], IndependentEvidenceCPD) and ins_of[n]}

    def p_root(theta):
        m = {}  # node -> marginal P(node = 1), filled in topological order
        for n in nodes:
            cpd, ins = cpd_of[n], ins_of[n]
            if isinstance(cpd, IndependentEvidenceCPD):
                b = theta[prior_idx[n]]
                if not ins:
                    m[n] = jax.nn.sigmoid(b)
                else:
                    w = jnp.stack([theta[lr_idx[(n, i)]] for i in ins])
                    pin = jnp.stack([m[i] for i in ins])
                    C = combos[n]
                    # probability of each input configuration under independence
                    wts = jnp.prod(jnp.where(C == 1, pin, 1 - pin), axis=1)
                    m[n] = jnp.sum(jax.nn.sigmoid(b + C @ w) * wts)
            elif isinstance(cpd, NoisyOrCPD):
                a = jnp.asarray(cpd.activations)
                pin = jnp.stack([m[i] for i in ins])
                m[n] = 1 - (1 - cpd.leak) * jnp.prod(1 - a * pin)
            elif _is_and_not_splice(cpd):
                # the splice table fixes the input order (source E, undercutter U)
                e, u = ins
                m[n] = m[e] * (1 - m[u])
            else:
                raise NotImplementedError(
                    f"calibrate does not handle {type(cpd).__name__} on {n.name!r}"
                )
        return m[arg.target]

    return p_root, jnp.asarray(theta0), meta
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def sensitivities(arg: "Argument") -> dict:
    """Gradient of the root posterior with respect to every prior and LR,
    evaluated at the argument's current values.

    Keys are `("prior", node_name)` and `("lr", (source_name, target_name))`.
    Each value is the derivative taken in the unconstrained parameterization
    (logit prior, log LR). A large magnitude flags a parameter the verdict is
    sensitive to, so it serves as a manipulability / importance signal."""
    jax, _jnp = _require_jax()
    p_root, theta0, meta = _forward(arg)
    gradient = jax.grad(p_root)(theta0)
    result = {}
    for label, g in zip(meta, gradient):
        result[label] = float(g)
    return result
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def calibrate_posterior(arg: "Argument", target: float, *,
                        steps: int = 500, step_size: float = 0.3) -> float:
    """Solve the root's prior and edge LRs so its posterior hits `target`, by
    projected gradient descent on the squared error, then write the result back
    onto the argument. Priors are kept in `[0.02, 0.98]`. Each LR is kept on its
    original side of 1 with `log 1.01 <= |log lr| <= log 50`. Every other
    subtree's marginals are held fixed. Use the bisection in `generate` for the
    dependency-free single-knob version.

    Args:
        arg: the assembled argument. Its target must lower to an
            `IndependentEvidenceCPD`, i.e. a defeasible root.
        target: the desired root posterior.
        steps: number of gradient steps.
        step_size: step size in the unconstrained space (logit prior, log LR).

    Returns:
        The root posterior predicted at the calibrated parameters, under the
        fixed input marginals.

    Raises:
        ImportError: if JAX is not installed.
        NotImplementedError: if the root's compiled CPD is not an
            `IndependentEvidenceCPD` (e.g. the root has strict edges).
    """
    jax, jnp = _require_jax()
    from ..core import IndependentEvidenceCPD, LoopySolver

    bn = arg.bn
    root = arg.target
    cpd = bn.compiled_cpd(root).cpd
    if not isinstance(cpd, IndependentEvidenceCPD):
        raise NotImplementedError("calibrate_posterior targets a defeasible root")
    ins = list(cpd.inputs)
    solver = LoopySolver(bn)
    pin = jnp.asarray([solver.prob(i, 1) for i in ins])  # fixed input marginals
    k = len(ins)
    C = jnp.asarray(_combos(k)) if k else None
    w0 = jnp.asarray([math.log(lr) for lr in cpd.lrs])
    signs = jnp.sign(w0)  # +1 for an lr > 1 edge, -1 for an lr < 1 edge
    lo_b, hi_b = _logit(0.02), _logit(0.98)
    lo_w, hi_w = math.log(1.01), math.log(50.0)

    def predict(theta):
        b, w = theta[0], theta[1:]
        if k == 0:
            return jax.nn.sigmoid(b)
        wts = jnp.prod(jnp.where(C == 1, pin, 1 - pin), axis=1)
        return jnp.sum(jax.nn.sigmoid(b + C @ w) * wts)

    def project(theta):
        # Euclidean projection onto the feasible box. A log LR that a step pushes
        # across 0 is pinned at its lower bound on the original side. Taking
        # `abs` instead would reflect it, moving it against the gradient.
        b = jnp.clip(theta[0], lo_b, hi_b)
        w = signs * jnp.clip(signs * theta[1:], lo_w, hi_w)
        return jnp.concatenate([jnp.asarray([b]), w])

    # jit once: the loss closes over fixed arrays, so every step reuses one trace
    # instead of re-tracing the gradient `steps` times.
    grad = jax.jit(jax.grad(lambda th: (predict(th) - target) ** 2))
    theta = jnp.concatenate([jnp.asarray([_logit(root.prior)]), w0])
    for _ in range(steps):
        theta = project(theta - step_size * grad(theta))

    root.prior = float(jax.nn.sigmoid(theta[0]))
    new_lrs = [math.exp(float(x)) for x in theta[1:]]
    # `_lower` adds the root's defeasible inputs in `_edges` order, so the new
    # LRs line up with `cpd.inputs` / `cpd.lrs`.
    for j, (src, _lr, kind) in enumerate(root._edges):
        root._edges[j] = (src, new_lrs[j], kind)
    return float(predict(theta))
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Lowering an ASPIC argument to a core `BayesianNetwork`. See `docs/aspic.md`.
|
|
2
|
+
|
|
3
|
+
`compile_argument(target)` collects everything reachable from the target, runs
|
|
4
|
+
validation (raising on sign / axiom errors, warning on soft slips), lowers each
|
|
5
|
+
claim to core CPDs, then defers to the core compiler. The lowerings:
|
|
6
|
+
|
|
7
|
+
- premise / conclusion with only defeasible edges -> `IndependentEvidenceCPD`
|
|
8
|
+
(a premise is just the zero-support case; an undermined premise gains an
|
|
9
|
+
`lr < 1` input).
|
|
10
|
+
- a conclusion with strict edges -> parent-divorcing: a hidden `D` carries the
|
|
11
|
+
defeasible part (or just the prior), and the conclusion becomes a leak-0
|
|
12
|
+
`NoisyOr` over its strict inputs and `D`, each at activation 1.
|
|
13
|
+
- an undercut on the `source -> C` edge -> a spliced `X = source AND NOT U`
|
|
14
|
+
(`tau = 1`) replacing `source` on that edge, with multiple undercutters first
|
|
15
|
+
combined by an upstream leak-0 `NoisyOr` into one `U`.
|
|
16
|
+
"""
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import warnings
|
|
20
|
+
|
|
21
|
+
from ..core import BayesianNetwork, IndependentEvidenceCPD, Node, TabularCPD
|
|
22
|
+
from .argument import ArgumentWarning, _Claim
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def compile_argument(target: "_Claim") -> "BayesianNetwork":
    """Lower the argument rooted at `target` to a core `BayesianNetwork`.

    Collects every node reachable from `target` and validates the declared
    structure: it raises on an undercut of a non-existent edge and warns on soft
    slips. It then lowers each claim to its core CPD and hands the target to
    `BayesianNetwork.from_nodes`. Plain (non-claim) nodes are left untouched.
    """
    reachable = _reachable(target)
    _validate(reachable)
    for claim in (node for node in reachable if isinstance(node, _Claim)):
        _lower(claim)
    return BayesianNetwork.from_nodes([target])
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# --- traversal -------------------------------------------------------------
|
|
35
|
+
|
|
36
|
+
def _upstream(c: "Node") -> list["Node"]:
    """Declared upstream sources of `c`, in order: defeasible-edge sources, then
    strict sources, then undercutters. A node lacking the claim attributes (a
    plain core node) yields none and so acts as a leaf."""
    edge_sources = [src for src, _lr, _kind in getattr(c, "_edges", [])]
    strict_sources = list(getattr(c, "_strict", []))
    undercutters = [by for _source, by in getattr(c, "_undercuts", [])]
    return edge_sources + strict_sources + undercutters
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _reachable(target: "Node") -> list["Node"]:
    """Every node reachable from `target` through declared argument edges, each
    listed once (compared by identity).

    The order is a deterministic depth-first preorder starting with `target`.
    It is not a dependency order; the core compiler does its own ordering."""
    visited: set[int] = set()
    order: list["Node"] = []
    pending = [target]
    while pending:
        node = pending.pop()
        if id(node) not in visited:
            visited.add(id(node))
            order.append(node)  # `order` keeps every visited node alive, so ids stay unique
            pending.extend(_upstream(node))
    return order
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _leaf_premises(c: "Node") -> set["Node"]:
    """The leaves upstream of `c`, meaning nodes with no declared upstream
    source (`c` itself if it has none). Used by the undercut disjointness
    check."""
    leaves: set["Node"] = set()
    visited: set[int] = set()
    frontier = [c]
    while frontier:
        node = frontier.pop()
        if id(node) in visited:
            continue
        visited.add(id(node))
        sources = _upstream(node)
        if sources:
            frontier.extend(sources)
        else:
            leaves.add(node)
    return leaves
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# --- validation ------------------------------------------------------------
|
|
83
|
+
|
|
84
|
+
def _validate(claims: list["Node"]) -> None:
    """Check the declared structure of each claim before lowering. Plain
    (non-claim) nodes are skipped.

    Raises:
        ValueError: if an undercut names a source with no defeasible or strict
            edge into the undercut claim.

    Warns with `ArgumentWarning` for:
        - a conclusion with no defeasible or strict incoming edge;
        - an undercut whose undercutter shares leaf premises with the attacked
          source, because the AND-NOT splice is exact only for disjoint branches.
    """
    for c in (node for node in claims if isinstance(node, _Claim)):
        if c.role == "conclusion" and not (c._edges or c._strict):
            warnings.warn(
                f"{c.name!r} is a conclusion with no incoming argument; a leaf "
                "should be a Premise.",
                ArgumentWarning,
                stacklevel=2,
            )

        inbound = {id(src) for src, _lr, _kind in c._edges}
        inbound.update(id(s) for s in c._strict)
        for source, by in c._undercuts:
            if id(source) not in inbound:
                raise ValueError(
                    f"undercut on {c.name!r}: {source.name!r} is not a source of any "
                    "edge into it; undercut attacks an existing inference"
                )
            shared = _leaf_premises(source) & _leaf_premises(by)
            if not shared:
                continue
            names = ", ".join(sorted(p.name for p in shared))
            warnings.warn(
                f"undercut on the {source.name!r} -> {c.name!r} edge shares "
                f"premise(s) {{{names}}} with its undercutter {by.name!r}; the "
                "splice is exact only for disjoint branches.",
                ArgumentWarning,
                stacklevel=2,
            )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# --- lowering --------------------------------------------------------------
|
|
118
|
+
|
|
119
|
+
def _and_not_cpd(x: "Node", e: "Node", u: "Node") -> "TabularCPD":
    """The `tau = 1` undercut splice: a deterministic table where `X = 1` iff
    `E = 1` and `U = 0`."""
    # table[e][u] is the distribution [P(X=0), P(X=1)]; every cell gets its own
    # fresh list, and only the (E=1, U=0) cell puts its mass on X=1.
    table = [
        [[0.0, 1.0] if (ev, uv) == (1, 0) else [1.0, 0.0] for uv in (0, 1)]
        for ev in (0, 1)
    ]
    return TabularCPD(x, [e, u], table=table)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _splice_for(c: "_Claim") -> dict[int, "Node"]:
    """Build one undercut splice node `source AND NOT U` for each attacked
    source on `c`, keyed by the source's `id`. When an edge has several
    undercutters, a leak-0 noisy-OR helper node combines them into a single `U`
    first."""
    by_source: dict[int, tuple["Node", list["Node"]]] = {}
    for source, u in c._undercuts:
        if id(source) not in by_source:
            by_source[id(source)] = (source, [])
        by_source[id(source)][1].append(u)

    splice: dict[int, "Node"] = {}
    for sid, (source, undercutters) in by_source.items():
        if len(undercutters) > 1:
            u_node = Node(f"{c.name}/U[{source.name}]")
            u_node.noisy_or(leak=0.0)
            for u in undercutters:
                u_node.add_input(u, activation=1.0)
        else:
            (u_node,) = undercutters  # exactly one undercutter: use it directly
        spliced = Node(f"{source.name} (spliced)")
        spliced.set_cpd(_and_not_cpd(spliced, source, u_node))
        splice[sid] = spliced
    return splice
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _lower(c: "_Claim") -> None:
    """Set `c`'s core CPD (and any hidden helper nodes) from its declared edges.

    The CPD is rebuilt from scratch on each call, so re-compiling is safe. Any
    undercut source is first replaced by its splice node. With strict edges,
    `c` is parent-divorced: a hidden `D` node carries the defeasible part (or
    just the prior), and `c` becomes a leak-0 noisy-OR over the strict sources
    and `D`. Otherwise the defeasible edges feed `c` directly."""
    c._cpd = IndependentEvidenceCPD(c)  # reset for a clean, repeatable lowering
    splice = _splice_for(c)
    defeasible = [(splice.get(id(src), src), lr) for src, lr, _kind in c._edges]
    strict = [splice.get(id(src), src) for src in c._strict]

    if not strict:
        for src, lr in defeasible:
            c.add_input(src, lr=lr)
        return

    d = Node(f"{c.name}/defeasible", prior=c.prior)
    for src, lr in defeasible:
        d.add_input(src, lr=lr)
    c.noisy_or(leak=0.0)
    for src in strict:
        c.add_input(src, activation=1.0)
    c.add_input(d, activation=1.0)
|