probability-flow 0.2.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {probability_flow-0.2.0 → probability_flow-0.4.0}/.gitignore +3 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/PKG-INFO +9 -1
- {probability_flow-0.2.0 → probability_flow-0.4.0}/README.md +7 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/__init__.py +4 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/__init__.py +3 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/argument.py +110 -15
- probability_flow-0.4.0/probability_flow/aspic/benchmark.py +254 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/calibrate.py +7 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/compile.py +39 -9
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/generate.py +177 -4
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/handle.py +58 -10
- probability_flow-0.4.0/probability_flow/aspic/optimize.py +514 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/visualization.py +42 -5
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/__init__.py +4 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/bp/engine.py +1 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/__init__.py +4 -1
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/base.py +8 -0
- probability_flow-0.4.0/probability_flow/core/cpd/correlated_evidence.py +243 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/independent_evidence.py +13 -9
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/noisy_and.py +2 -2
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/noisy_or.py +2 -2
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/exact.py +1 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/network.py +41 -3
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/node.py +21 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/__init__.py +8 -1
- probability_flow-0.4.0/probability_flow/metrics/manipulability.py +474 -0
- probability_flow-0.4.0/probability_flow/visualization/animate.py +210 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/visualization/image.py +113 -25
- probability_flow-0.4.0/probability_flow/visualization/layout.py +423 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/pyproject.toml +17 -2
- probability_flow-0.2.0/probability_flow/metrics/manipulability.py +0 -207
- probability_flow-0.2.0/probability_flow/visualization/layout.py +0 -166
- {probability_flow-0.2.0 → probability_flow-0.4.0}/LICENSE +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/_logmath.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/bp/__init__.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/bp/message.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/tabular.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/_util.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/difficulty.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/dseparation.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/loopiness.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/structure.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/py.typed +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/visualization/__init__.py +0 -0
- {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/visualization/style.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: probability-flow
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: A from-scratch, modular discrete Bayesian-network library.
|
|
5
5
|
Project-URL: Homepage, https://github.com/scalable-oversight-benchmarks/probability-flow
|
|
6
6
|
Project-URL: Repository, https://github.com/scalable-oversight-benchmarks/probability-flow
|
|
@@ -28,6 +28,7 @@ Requires-Dist: pytest; extra == 'dev'
|
|
|
28
28
|
Requires-Dist: ruff; extra == 'dev'
|
|
29
29
|
Provides-Extra: jax
|
|
30
30
|
Requires-Dist: jax; extra == 'jax'
|
|
31
|
+
Requires-Dist: scipy; extra == 'jax'
|
|
31
32
|
Provides-Extra: viz
|
|
32
33
|
Requires-Dist: matplotlib; extra == 'viz'
|
|
33
34
|
Description-Content-Type: text/markdown
|
|
@@ -118,6 +119,13 @@ combiner) from the *object* that implements it (the CPD).
|
|
|
118
119
|
sources of evidence: `logit P(node=1) = logit(prior) + sum of log(lr)` over the
|
|
119
120
|
active inputs. Adding weights of evidence is Bayes' rule for independent
|
|
120
121
|
likelihood ratios. Set per edge with `add_input(x, lr=...)`.
|
|
122
|
+
- **CorrelatedEvidenceCPD**. Independent evidence plus pairwise couplings, for
|
|
123
|
+
*redundant* inputs the additive rule would otherwise double-count (two reports of
|
|
124
|
+
one fact, two clues from a shared cause): it adds a pairwise term
|
|
125
|
+
`+ sum J_ij s_i s_j` (a negative `J` makes two inputs sub-additive when both fire,
|
|
126
|
+
a positive one synergistic). The coupling lives inside the single CPD factor, so
|
|
127
|
+
it adds no edge to the graph and no loop to the solver, and stays a valid
|
|
128
|
+
distribution for any real `J`.
|
|
121
129
|
- **NoisyOrCPD**. "Any one cause can fire the effect":
|
|
122
130
|
`P(node=0) = (1 - leak) * product of (1 - activation)` over present causes.
|
|
123
131
|
Declared with `node.noisy_or(leak=...)` and `add_input(cause, activation=...)`.
|
|
@@ -84,6 +84,13 @@ combiner) from the *object* that implements it (the CPD).
|
|
|
84
84
|
sources of evidence: `logit P(node=1) = logit(prior) + sum of log(lr)` over the
|
|
85
85
|
active inputs. Adding weights of evidence is Bayes' rule for independent
|
|
86
86
|
likelihood ratios. Set per edge with `add_input(x, lr=...)`.
|
|
87
|
+
- **CorrelatedEvidenceCPD**. Independent evidence plus pairwise couplings, for
|
|
88
|
+
*redundant* inputs the additive rule would otherwise double-count (two reports of
|
|
89
|
+
one fact, two clues from a shared cause): it adds a pairwise term
|
|
90
|
+
`+ sum J_ij s_i s_j` (a negative `J` makes two inputs sub-additive when both fire,
|
|
91
|
+
a positive one synergistic). The coupling lives inside the single CPD factor, so
|
|
92
|
+
it adds no edge to the graph and no loop to the solver, and stays a valid
|
|
93
|
+
distribution for any real `J`.
|
|
87
94
|
- **NoisyOrCPD**. "Any one cause can fire the effect":
|
|
88
95
|
`P(node=0) = (1 - leak) * product of (1 - activation)` over present causes.
|
|
89
96
|
Declared with `node.noisy_or(leak=...)` and `add_input(cause, activation=...)`.
|
|
@@ -12,12 +12,14 @@ from .core import (
|
|
|
12
12
|
CPD,
|
|
13
13
|
BayesianNetwork,
|
|
14
14
|
CompiledCPD,
|
|
15
|
+
CorrelatedEvidenceCPD,
|
|
15
16
|
ExactSolver,
|
|
16
17
|
IndependentEvidenceCPD,
|
|
17
18
|
LoopySolver,
|
|
18
19
|
Node,
|
|
19
20
|
NoisyAndCPD,
|
|
20
21
|
NoisyOrCPD,
|
|
22
|
+
PendingWeightError,
|
|
21
23
|
TabularCPD,
|
|
22
24
|
)
|
|
23
25
|
|
|
@@ -36,6 +38,8 @@ __all__ = [
|
|
|
36
38
|
"CPD",
|
|
37
39
|
"TabularCPD",
|
|
38
40
|
"IndependentEvidenceCPD",
|
|
41
|
+
"CorrelatedEvidenceCPD",
|
|
39
42
|
"NoisyOrCPD",
|
|
40
43
|
"NoisyAndCPD",
|
|
44
|
+
"PendingWeightError",
|
|
41
45
|
]
|
|
@@ -9,13 +9,16 @@ pure-BN `core`. Build an argument out of premises, conclusions, and attacks, the
|
|
|
9
9
|
from .argument import ArgumentWarning, Axiom, Conclusion, Premise
|
|
10
10
|
from .generate import (
|
|
11
11
|
ArgumentGenerator,
|
|
12
|
+
BatchParams,
|
|
12
13
|
DifficultyTargets,
|
|
13
14
|
StructuralParams,
|
|
14
15
|
generate,
|
|
16
|
+
generate_batch,
|
|
15
17
|
)
|
|
16
18
|
from .handle import Argument
|
|
17
19
|
|
|
18
20
|
__all__ = [
|
|
19
21
|
"Premise", "Axiom", "Conclusion", "ArgumentWarning", "Argument",
|
|
20
22
|
"ArgumentGenerator", "StructuralParams", "DifficultyTargets", "generate",
|
|
23
|
+
"BatchParams", "generate_batch",
|
|
21
24
|
]
|
|
@@ -41,9 +41,17 @@ class _Claim(Node):
|
|
|
41
41
|
|
|
42
42
|
def __init__(self, name: str, prior: float = 0.5):
|
|
43
43
|
super().__init__(name, prior=prior)
|
|
44
|
-
self._edges: list[tuple["Node", float, str]] = [] # (src, lr, kind)
|
|
44
|
+
self._edges: list[tuple["Node", "float | None", str]] = [] # (src, lr, kind)
|
|
45
45
|
self._strict: list["Node"] = [] # strict sources
|
|
46
46
|
self._undercuts: list[tuple["Node", "Node"]] = [] # (attacked source, undercutter)
|
|
47
|
+
# (source_a, source_b, J): residual correlation between two defeasible
|
|
48
|
+
# sources of this conclusion, lowered to a CorrelatedEvidenceCPD coupling.
|
|
49
|
+
self._correlations: list[tuple["Node", "Node", "float | None"]] = []
|
|
50
|
+
# (sources, lr, leak, rule): a conjunctive support group — the sources
|
|
51
|
+
# JOINTLY support this conclusion through one noisy-AND inference (the
|
|
52
|
+
# deduction rule), lowered to a hidden NoisyAnd gate fed in as one lr.
|
|
53
|
+
self._conjunctions: list[
|
|
54
|
+
tuple[list["Node"], "float | None", float, "str | None"]] = []
|
|
47
55
|
self._no_undermine = False
|
|
48
56
|
# opaque, serialized but uninterpreted by the library (see docs/aspic.md):
|
|
49
57
|
self.desc: str | None = None # a longer description
|
|
@@ -53,8 +61,13 @@ class _Claim(Node):
|
|
|
53
61
|
"""At most one argumentative edge joins a pair of claims. This is what lets
|
|
54
62
|
an undercut be addressed by its endpoints and keeps serialized edge ids
|
|
55
63
|
(`source->target`) unique."""
|
|
56
|
-
|
|
57
|
-
s is src for
|
|
64
|
+
in_conj = any(
|
|
65
|
+
s is src for srcs, _lr, _leak, _rule in self._conjunctions for s in srcs
|
|
66
|
+
)
|
|
67
|
+
if (
|
|
68
|
+
any(s is src for s, _lr, _kind in self._edges)
|
|
69
|
+
or any(s is src for s in self._strict)
|
|
70
|
+
or in_conj
|
|
58
71
|
):
|
|
59
72
|
raise ValueError(
|
|
60
73
|
f"{self.name!r} already has an edge from {src.name!r}; at most one "
|
|
@@ -63,21 +76,69 @@ class _Claim(Node):
|
|
|
63
76
|
|
|
64
77
|
# --- defeasible and strict support (methods on the downstream conclusion) --
|
|
65
78
|
|
|
66
|
-
def support(self, src: "Node", lr: float) -> "Node":
|
|
79
|
+
def support(self, src: "Node", lr: "float | None" = None) -> "Node":
|
|
67
80
|
"""Add a defeasible argument *for* this conclusion (`lr > 1`). Returns
|
|
68
81
|
`src`, so an inline source can be built further upstream, as with core
|
|
69
|
-
`add_input`.
|
|
70
|
-
|
|
82
|
+
`add_input`. `lr` may be omitted (`None`) to declare the edge before its
|
|
83
|
+
strength is known — the topology compiles and renders (dashed), but the
|
|
84
|
+
network cannot be solved until every weight is assigned."""
|
|
85
|
+
if lr is not None and not lr > 1:
|
|
71
86
|
raise ValueError(
|
|
72
87
|
f"support lr must be > 1 (got {lr}); use rebut for an argument against"
|
|
73
88
|
)
|
|
74
89
|
self._require_new_edge(src)
|
|
75
|
-
self._edges.append((src, float(lr), "support"))
|
|
90
|
+
self._edges.append((src, None if lr is None else float(lr), "support"))
|
|
76
91
|
return src
|
|
77
92
|
|
|
78
|
-
def
|
|
79
|
-
|
|
80
|
-
|
|
93
|
+
def support_all(
    self,
    srcs: "list[Node]",
    lr: "float | None" = None,
    leak: float = 0.0,
    rule: "str | None" = None,
) -> "list[Node]":
    """Register `srcs` as one conjunctive support group for this conclusion.

    The sources back the conclusion *jointly*: a single noisy-AND inference
    that fires only when every source holds (a deductive entailment step with
    N antecedents and one rule). `compile.py` lowers the group to a hidden
    `NoisyAnd` gate over the sources, each at activation 1, and feeds that
    gate into the conclusion as one `lr` support, so the group composes with
    ordinary `support` / `rebut` edges.

    Args:
        srcs: The antecedent nodes; at least two, all distinct objects.
        lr: Strength of the joint support, `> 1`. `None` scaffolds the group
            before its strength is known.
        leak: The rule's fallibility, a probability in `[0, 1]`
            (`P(fire | all sources hold) = 1 - leak`).
        rule: Optional label for the inference licence; stored, not
            interpreted here.

    Returns:
        A new list holding the sources, so inline sources can be extended
        further upstream (as with `support`).

    Raises:
        ValueError: on fewer than two sources, repeated sources, an `lr` that
            is not `> 1`, a `leak` outside `[0, 1]`, or a source that already
            has an edge into this conclusion (via `_require_new_edge`).
    """
    group = list(srcs)
    if len(group) < 2:
        raise ValueError(
            "support_all needs >= 2 sources (it is a conjunction); use support "
            "for a single argument"
        )
    # Distinctness is by object identity, not equality.
    if len(set(map(id, group))) != len(group):
        raise ValueError("support_all sources must be distinct")
    if not (lr is None or lr > 1):
        raise ValueError(
            f"support_all lr must be > 1 (got {lr}); it is conjunctive support"
        )
    if not (0.0 <= leak <= 1.0):
        raise ValueError("leak must be a probability in [0, 1]")

    # Every member is checked before anything is recorded, so a rejected
    # source leaves `_conjunctions` untouched.
    for member in group:
        self._require_new_edge(member)

    strength = float(lr) if lr is not None else None
    self._conjunctions.append((group, strength, float(leak), rule))
    return group
|
|
137
|
+
|
|
138
|
+
def rebut(self, src: "Node", lr: "float | None" = None) -> "Node":
|
|
139
|
+
"""Add a defeasible argument *against* this conclusion (`0 < lr < 1`). `lr`
|
|
140
|
+
may be omitted (`None`) to declare the edge before its strength is known."""
|
|
141
|
+
if lr is not None and not 0 < lr < 1:
|
|
81
142
|
raise ValueError(
|
|
82
143
|
f"rebut lr must be in (0, 1) (got {lr}); use support for an argument for"
|
|
83
144
|
)
|
|
@@ -89,7 +150,7 @@ class _Claim(Node):
|
|
|
89
150
|
stacklevel=2,
|
|
90
151
|
)
|
|
91
152
|
self._require_new_edge(src)
|
|
92
|
-
self._edges.append((src, float(lr), "rebut"))
|
|
153
|
+
self._edges.append((src, None if lr is None else float(lr), "rebut"))
|
|
93
154
|
return src
|
|
94
155
|
|
|
95
156
|
def strict(self, src: "Node") -> "Node":
|
|
@@ -101,18 +162,19 @@ class _Claim(Node):
|
|
|
101
162
|
|
|
102
163
|
# --- attacks (also methods on the downstream node) ------------------------
|
|
103
164
|
|
|
104
|
-
def undermine(self, by: "Node", lr: float) -> "Node":
|
|
165
|
+
def undermine(self, by: "Node", lr: "float | None" = None) -> "Node":
|
|
105
166
|
"""Attack this premise with `by` (`0 < lr < 1`). Mechanically a rebut into
|
|
106
167
|
the attacked node, named distinctly because it is the ASPIC-correct verb
|
|
107
|
-
for attacking a premise. An axiom cannot be undermined.
|
|
168
|
+
for attacking a premise. An axiom cannot be undermined. `lr` may be omitted
|
|
169
|
+
(`None`) to declare the edge before its strength is known."""
|
|
108
170
|
if self._no_undermine:
|
|
109
171
|
raise ValueError(
|
|
110
172
|
f"cannot undermine {self.name!r}: an axiom has no defeasible premise to attack"
|
|
111
173
|
)
|
|
112
|
-
if not 0 < lr < 1:
|
|
174
|
+
if lr is not None and not 0 < lr < 1:
|
|
113
175
|
raise ValueError(f"undermine lr must be in (0, 1) (got {lr})")
|
|
114
176
|
self._require_new_edge(by)
|
|
115
|
-
self._edges.append((by, float(lr), "undermine"))
|
|
177
|
+
self._edges.append((by, None if lr is None else float(lr), "undermine"))
|
|
116
178
|
return by
|
|
117
179
|
|
|
118
180
|
def undercut(self, source: "Node", by: "Node") -> "Node":
|
|
@@ -123,6 +185,39 @@ class _Claim(Node):
|
|
|
123
185
|
self._undercuts.append((source, by))
|
|
124
186
|
return by
|
|
125
187
|
|
|
188
|
+
def correlate(self, source_a: "Node", source_b: "Node",
              coupling: "float | None" = None) -> "_Claim":
    """Record residual correlation between two defeasible sources of this
    conclusion.

    The pair is stored as `(source_a, source_b, J)` in `_correlations`;
    `compile` later lowers it to a pairwise coupling on whichever core node
    carries the defeasible inputs. A negative `J` discounts the pair
    (redundant / sub-additive), a positive `J` boosts it (synergy). `J` is
    the engine-native quantity: a caller holding a correlation or a pair of
    conditional probabilities converts to it first.

    This overrides the core `Node.correlate`; here the pairing is declared
    against argument sources rather than applied to a CPD directly.

    Args:
        source_a: A node already attached to this conclusion via `support`
            or `rebut`.
        source_b: A different such node.
        coupling: The coupling `J`, or `None` to scaffold the pair before a
            value is chosen.

    Returns:
        `self`, for chaining.

    Raises:
        ValueError: if either node is not a support/rebut source of this
            conclusion, if both arguments are the same node, or if the pair
            (in either order) is already correlated here.
    """
    # Only support/rebut edges qualify; membership is by object identity.
    defeasible_ids = set()
    for edge_src, _lr, edge_kind in self._edges:
        if edge_kind in ("support", "rebut"):
            defeasible_ids.add(id(edge_src))

    for candidate in (source_a, source_b):
        if id(candidate) in defeasible_ids:
            continue
        raise ValueError(
            f"{getattr(candidate, 'name', candidate)!r} is not a support/rebut source of "
            f"{self.name!r}; correlate couples two of its defeasible sources"
        )

    if source_a is source_b:
        raise ValueError("cannot correlate a source with itself")

    # The pair is unordered: (a, b) and (b, a) are the same correlation.
    wanted = {id(source_a), id(source_b)}
    for first, second, _J in self._correlations:
        if {id(first), id(second)} == wanted:
            raise ValueError(
                f"{source_a.name!r} and {source_b.name!r} are already correlated on "
                f"{self.name!r}"
            )

    value = float(coupling) if coupling is not None else None
    self._correlations.append((source_a, source_b, value))
    return self
|
|
220
|
+
|
|
126
221
|
def compile(self) -> "BayesianNetwork":
|
|
127
222
|
"""Lower this argument (as the target) to a core `BayesianNetwork`."""
|
|
128
223
|
from .compile import compile_argument
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""Benchmark helpers: generate → optimize bridge for large-batch posterior targeting.
|
|
2
|
+
|
|
3
|
+
This module isolates the JAX/scipy dependency (via ``optimize``) from the pure-Python
|
|
4
|
+
``generate.py``. The primary entry point is ``generate_with_targets``, which produces
|
|
5
|
+
N argument graphs each tuned to a requested root posterior.
|
|
6
|
+
|
|
7
|
+
Strategy
|
|
8
|
+
--------
|
|
9
|
+
For each (seed, target_posterior) pair:
|
|
10
|
+
|
|
11
|
+
1. Try ``generate(seed, structural, DifficultyTargets(target_posterior=t))`` — this
|
|
12
|
+
bisects the root prior to hit *t* cheaply (no JAX needed).
|
|
13
|
+
2. If the bisection cannot reach *t* for the drawn structure (``RuntimeError`` from
|
|
14
|
+
the rejection loop), fall back to:
|
|
15
|
+
a. Generate an *unconstrained* graph from the same seed.
|
|
16
|
+
b. Check whether *t* is inside ``achievable_interval(arg)``.
|
|
17
|
+
c. If yes, call ``optimize(arg, target_posterior=t)`` to tune LRs.
|
|
18
|
+
3. If neither path succeeds, record the failure (skip / warn / raise per
|
|
19
|
+
``on_failure``).
|
|
20
|
+
|
|
21
|
+
Both paths return an ``Argument`` with identical public interface. The ``optimize``
|
|
22
|
+
fallback is more expensive (JAX JIT + SLSQP) but reaches extreme posteriors that
|
|
23
|
+
bisection cannot.
|
|
24
|
+
|
|
25
|
+
Batch labeling
|
|
26
|
+
--------------
|
|
27
|
+
Each accepted argument is annotated with a ``_benchmark_meta`` attribute (a plain
|
|
28
|
+
dict) carrying the difficulty scalars the benchmark consumer (``debate-eval``) needs:
|
|
29
|
+
|
|
30
|
+
posterior, manipulability, d_sep_count, max_depth_metric,
|
|
31
|
+
upstream_size, circuit_rank, seed, is_exact_manipulability,
|
|
32
|
+
used_optimize_fallback
|
|
33
|
+
|
|
34
|
+
These come entirely from existing metrics — no new computation.
|
|
35
|
+
|
|
36
|
+
JAX/scipy availability
|
|
37
|
+
----------------------
|
|
38
|
+
Both are required for the ``optimize`` fallback path. When they are absent the
|
|
39
|
+
module still imports cleanly; ``generate_with_targets`` will error only if an
|
|
40
|
+
optimize fallback is actually needed. Use ``generate_batch`` (no JAX) when you only
|
|
41
|
+
need bisection-reachable posteriors.
|
|
42
|
+
"""
|
|
43
|
+
from __future__ import annotations
|
|
44
|
+
|
|
45
|
+
import warnings
|
|
46
|
+
from typing import TYPE_CHECKING, Optional
|
|
47
|
+
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from .handle import Argument
|
|
50
|
+
|
|
51
|
+
# ---------------------------------------------------------------------------
|
|
52
|
+
# Internal helpers
|
|
53
|
+
# ---------------------------------------------------------------------------
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _label(arg: "Argument", seed: int, used_optimize: bool) -> None:
    """Compute the benchmark difficulty scalars for *arg* and store them on it.

    The result is written to ``arg._benchmark_meta`` as a plain dict; nothing
    is returned. ``seed`` and ``used_optimize`` are recorded verbatim under
    ``"seed"`` and ``"used_optimize_fallback"``.
    """
    from ..metrics import (
        circuit_rank,
        d_separated_groups,
        max_depth,
        posterior_range,
        upstream_size,
    )

    network = arg.bn
    root = arg.target
    root_posterior = arg.posterior(root)

    # Manipulability is reported as the [min, max] reachable root posterior
    # (plus its width), which shows *where* on [0, 1] the root can be pushed,
    # not only how far. `is_exact` on the result says whether the range is
    # exact or only a bound.
    reach = posterior_range(network, root, exact=True)
    low = float(reach.lo)
    high = float(reach.hi)

    meta = {
        "posterior": float(root_posterior),
        "manipulability": [round(low, 4), round(high, 4)],
        "manipulability_width": round(float(reach.hi - reach.lo), 4),
        "d_sep_count": len(d_separated_groups(network, root)),
        "max_depth_metric": max_depth(network, root),
        "upstream_size": upstream_size(network, root),
        "circuit_rank": circuit_rank(network),
        "seed": seed,
        "is_exact_manipulability": bool(reach.is_exact),
        "used_optimize_fallback": used_optimize,
    }
    arg._benchmark_meta = meta
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _try_optimize(
    arg: "Argument",
    target: float,
    posterior_tol: float,
    param_limits: Optional[dict],
) -> "Optional[Argument]":
    """Best-effort attempt to tune *arg* to *target* with ``optimize``.

    Returns the argument produced by ``optimize`` on success, and ``None``
    in every failure case:

    * the ``.optimize`` module cannot be imported,
    * ``achievable_interval`` raises anything,
    * *target* lies outside the achievable interval (with 1e-4 slack on each
      side),
    * ``optimize`` raises ``InfeasibleTargetError`` or ``OptimizeError``.
    """
    try:
        from .optimize import (
            InfeasibleTargetError,
            OptimizeError,
            achievable_interval,
            optimize,
        )
    except ImportError:
        return None

    # Deliberately broad: any failure while bounding the reachable posterior
    # just means this structure cannot be used for the fallback.
    try:
        lo, hi = achievable_interval(arg, param_limits=param_limits)
    except Exception:
        return None

    reachable = lo - 1e-4 <= target <= hi + 1e-4
    if not reachable:
        return None  # infeasible for this structure

    try:
        tuned = optimize(
            arg,
            target_posterior=target,
            posterior_tol=posterior_tol,
            param_limits=param_limits,
        )
    except (InfeasibleTargetError, OptimizeError):
        return None
    return tuned
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# ---------------------------------------------------------------------------
|
|
123
|
+
# Public API
|
|
124
|
+
# ---------------------------------------------------------------------------
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def generate_with_targets(
    n: int,
    structural,
    posterior_targets: "list[float]",
    posterior_tol: float = 0.02,
    param_limits: "Optional[dict]" = None,
    seeds: "Optional[list[int]]" = None,
    max_attempts_per_graph: int = 200,
    n_jobs: int = 1,
    on_failure: str = "skip",
    verbose: bool = False,
) -> "list[Argument]":
    """Generate *n* graphs, each tuned to a target root posterior.

    Parameters
    ----------
    n : int
        Number of graphs to produce.
    structural : StructuralParams
        Structural shape shared across all graphs.
    posterior_targets : list[float]
        Desired root posterior for each graph. Must have length *n*.
    posterior_tol : float
        Accepted absolute deviation from each target posterior (default 0.02).
    param_limits : dict or None
        Passed to ``optimize`` on the fallback path. None uses default boxes.
    seeds : list[int] or None
        Per-graph seeds. Defaults to ``list(range(n))``.
    max_attempts_per_graph : int
        Rejection budget for the primary ``generate`` path.
    n_jobs : int
        Worker count. Accepted for interface compatibility; this function
        does not read it and processes the graphs sequentially.
    on_failure : str
        ``"skip"`` (default) | ``"warn"`` | ``"raise"``.
    verbose : bool
        Print progress per graph.

    Returns
    -------
    list[Argument]
        Accepted and tuned arguments. Each carries ``._benchmark_meta``.
        Length <= *n* when ``on_failure != 'raise'``.

    Raises
    ------
    ValueError
        If ``posterior_targets`` or ``seeds`` does not have length *n*, or
        ``on_failure`` is not one of the three documented values.
    RuntimeError
        If a graph cannot reach its target and ``on_failure == "raise"``.

    Notes
    -----
    The fallback threshold (when to call ``optimize`` instead of resampling) is:
    whenever the bisection-only ``generate`` raises ``RuntimeError`` (budget
    exhausted). ``optimize`` is then tried on an *unconstrained* graph from the
    same seed; if the target is achievable in the parameter box, it succeeds.
    """
    # Validate arguments first so bad input fails fast, before any work.
    if on_failure not in ("skip", "warn", "raise"):
        raise ValueError(
            f"on_failure must be 'skip', 'warn' or 'raise' (got {on_failure!r})"
        )
    if len(posterior_targets) != n:
        raise ValueError(
            f"posterior_targets has {len(posterior_targets)} entries but n={n}"
        )
    if seeds is None:
        seeds = list(range(n))
    if len(seeds) != n:
        raise ValueError(f"seeds has {len(seeds)} entries but n={n}")

    from .generate import DifficultyTargets, generate

    accepted: "list[Argument]" = []

    for i, (seed, t) in enumerate(zip(seeds, posterior_targets)):
        if verbose:
            print(f"[{i + 1}/{n}] seed={seed} target={t:.3f}", end=" ... ")

        # --- primary path: bisection inside generate() -----------------------
        arg: "Optional[Argument]" = None
        used_optimize = False
        try:
            arg = generate(
                seed=seed,
                structural=structural,
                targets=DifficultyTargets(
                    target_posterior=t,
                    posterior_tol=posterior_tol,
                ),
                max_attempts=max_attempts_per_graph,
                verbose=False,
            )
        except RuntimeError:
            pass  # bisection budget exhausted — try the optimize fallback

        # --- fallback: generate unconstrained then optimize ------------------
        if arg is None:
            try:
                base_arg = generate(
                    seed=seed,
                    structural=structural,
                    max_attempts=max_attempts_per_graph,
                    verbose=False,
                )
                opt_arg = _try_optimize(base_arg, t, posterior_tol, param_limits)
                if opt_arg is not None:
                    arg = opt_arg
                    used_optimize = True
            except RuntimeError:
                pass  # even the unconstrained draw failed; handled below

        # --- failure handling ------------------------------------------------
        if arg is None:
            msg = (
                f"generate_with_targets: seed {seed} could not reach "
                f"target {t:.4f} via bisection or optimize."
            )
            if on_failure == "raise":
                raise RuntimeError(msg)
            if on_failure == "warn":
                warnings.warn(msg, stacklevel=2)
            if verbose:
                print("FAILED")
            continue

        # --- label and collect -----------------------------------------------
        _label(arg, seed=seed, used_optimize=used_optimize)
        accepted.append(arg)
        if verbose:
            m = arg._benchmark_meta
            # `manipulability` is a [lo, hi] pair, not a scalar: applying a
            # float format spec to the list itself raises TypeError.
            manip_lo, manip_hi = m["manipulability"]
            print(
                f"ok P={m['posterior']:.3f} "
                f"manip=[{manip_lo:.3f}, {manip_hi:.3f}] "
                f"nodes={m['upstream_size'] + 1} "
                f"{'[opt]' if used_optimize else '[bisect]'}"
            )

    return accepted
|
|
@@ -67,6 +67,7 @@ def _forward(arg: "Argument"):
|
|
|
67
67
|
current values (logit priors, log LRs), and `meta` labels each entry."""
|
|
68
68
|
jax, jnp = _require_jax()
|
|
69
69
|
from ..core import IndependentEvidenceCPD, NoisyOrCPD
|
|
70
|
+
from ..core.cpd.noisy_and import NoisyAndCPD
|
|
70
71
|
|
|
71
72
|
bn = arg.bn
|
|
72
73
|
nodes = list(bn.nodes) # topological (inputs before node)
|
|
@@ -105,6 +106,12 @@ def _forward(arg: "Argument"):
|
|
|
105
106
|
a = jnp.asarray(cpd.activations)
|
|
106
107
|
pin = jnp.stack([m[i] for i in ins])
|
|
107
108
|
m[n] = 1 - (1 - cpd.leak) * jnp.prod(1 - a * pin)
|
|
109
|
+
elif isinstance(cpd, NoisyAndCPD):
|
|
110
|
+
# NoisyAnd activations are fixed structural constants (all 1.0 from
|
|
111
|
+
# the ASPIC compiler); they are NOT free parameters in theta.
|
|
112
|
+
a = jnp.asarray(cpd.activations)
|
|
113
|
+
pin = jnp.stack([m[i] for i in ins])
|
|
114
|
+
m[n] = (1 - cpd.leak) * jnp.prod(a * pin)
|
|
108
115
|
elif _is_and_not_splice(cpd):
|
|
109
116
|
e, u = ins
|
|
110
117
|
m[n] = m[e] * (1 - m[u])
|
|
@@ -40,6 +40,8 @@ def _upstream(c: "Node") -> list["Node"]:
|
|
|
40
40
|
out.extend(src for src, _lr, _kind in getattr(c, "_edges", []))
|
|
41
41
|
out.extend(getattr(c, "_strict", []))
|
|
42
42
|
out.extend(by for _source, by in getattr(c, "_undercuts", []))
|
|
43
|
+
for srcs, _lr, _leak, _rule in getattr(c, "_conjunctions", []):
|
|
44
|
+
out.extend(srcs)
|
|
43
45
|
return out
|
|
44
46
|
|
|
45
47
|
|
|
@@ -86,7 +88,12 @@ def _validate(claims: list["Node"]) -> None:
|
|
|
86
88
|
if not isinstance(c, _Claim):
|
|
87
89
|
continue
|
|
88
90
|
|
|
89
|
-
if
|
|
91
|
+
if (
|
|
92
|
+
c.role == "conclusion"
|
|
93
|
+
and not c._edges
|
|
94
|
+
and not c._strict
|
|
95
|
+
and not getattr(c, "_conjunctions", [])
|
|
96
|
+
):
|
|
90
97
|
warnings.warn(
|
|
91
98
|
f"{c.name!r} is a conclusion with no incoming argument; a leaf "
|
|
92
99
|
"should be a Premise.",
|
|
@@ -160,16 +167,39 @@ def _lower(c: "_Claim") -> None:
|
|
|
160
167
|
|
|
161
168
|
defeasible = [(resolve(src), lr) for src, lr, _kind in c._edges]
|
|
162
169
|
strict = [resolve(src) for src in c._strict]
|
|
170
|
+
conjunctions = getattr(c, "_conjunctions", [])
|
|
171
|
+
|
|
172
|
+
# `host` carries the defeasible inputs: the conclusion itself, or a hidden `D`
|
|
173
|
+
# when strict edges divorce it (parent-divorcing).
|
|
174
|
+
host = Node(f"{c.name}/defeasible", prior=c.prior) if strict else c
|
|
175
|
+
|
|
176
|
+
for src, lr in defeasible:
|
|
177
|
+
host.add_input(src, lr=lr)
|
|
178
|
+
|
|
179
|
+
# Conjunctive support: one hidden NoisyAnd gate per group, fed into `host` as a
|
|
180
|
+
# single lr support. The gate is the logical AND (each source at activation 1,
|
|
181
|
+
# firing iff all hold); `leak` is the rule's fallibility; `lr` is its strength.
|
|
182
|
+
for k, (srcs, lr, leak, _rule) in enumerate(conjunctions):
|
|
183
|
+
gate = Node(f"{c.name}/conj[{k}]")
|
|
184
|
+
gate.noisy_and(leak=leak)
|
|
185
|
+
for src in srcs:
|
|
186
|
+
gate.add_input(resolve(src), activation=1.0)
|
|
187
|
+
host.add_input(gate, lr=lr)
|
|
188
|
+
|
|
189
|
+
_apply_correlations(c, host, resolve)
|
|
163
190
|
|
|
164
191
|
if strict:
|
|
165
|
-
# parent-divorcing: D holds the defeasible part (or just the prior).
|
|
166
|
-
d = Node(f"{c.name}/defeasible", prior=c.prior)
|
|
167
|
-
for src, lr in defeasible:
|
|
168
|
-
d.add_input(src, lr=lr)
|
|
169
192
|
c.noisy_or(leak=0.0)
|
|
170
193
|
for src in strict:
|
|
171
194
|
c.add_input(src, activation=1.0)
|
|
172
|
-
c.add_input(
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
195
|
+
c.add_input(host, activation=1.0)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _apply_correlations(c: "_Claim", host: "Node", resolve) -> None:
    """Apply the source correlations declared on `c` to `host`.

    `host` is the core node that actually holds the defeasible inputs: the
    conclusion itself, or its hidden `/defeasible` node when strict edges
    divorced it. Each stored `(a, b, coupling)` triple has both endpoints
    passed through `resolve` before being coupled.

    The base-class `Node.correlate` is called explicitly because `c`
    overrides `correlate` with the argument-level declaration; the core
    coupling is only reachable through the base class. A claim without a
    `_correlations` attribute is treated as having none.
    """
    declared = getattr(c, "_correlations", [])
    for left, right, strength in declared:
        Node.correlate(host, resolve(left), resolve(right), coupling=strength)
|