probability-flow 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- probability_flow/__init__.py +41 -0
- probability_flow/aspic/__init__.py +21 -0
- probability_flow/aspic/argument.py +182 -0
- probability_flow/aspic/calibrate.py +178 -0
- probability_flow/aspic/compile.py +175 -0
- probability_flow/aspic/generate.py +397 -0
- probability_flow/aspic/handle.py +281 -0
- probability_flow/aspic/visualization.py +128 -0
- probability_flow/core/__init__.py +29 -0
- probability_flow/core/_logmath.py +40 -0
- probability_flow/core/bp/__init__.py +5 -0
- probability_flow/core/bp/engine.py +198 -0
- probability_flow/core/bp/message.py +30 -0
- probability_flow/core/cpd/__init__.py +13 -0
- probability_flow/core/cpd/base.py +84 -0
- probability_flow/core/cpd/independent_evidence.py +155 -0
- probability_flow/core/cpd/noisy_and.py +113 -0
- probability_flow/core/cpd/noisy_or.py +109 -0
- probability_flow/core/cpd/tabular.py +111 -0
- probability_flow/core/exact.py +67 -0
- probability_flow/core/network.py +111 -0
- probability_flow/core/node.py +125 -0
- probability_flow/metrics/__init__.py +64 -0
- probability_flow/metrics/_util.py +42 -0
- probability_flow/metrics/difficulty.py +87 -0
- probability_flow/metrics/dseparation.py +83 -0
- probability_flow/metrics/loopiness.py +82 -0
- probability_flow/metrics/manipulability.py +207 -0
- probability_flow/metrics/structure.py +49 -0
- probability_flow/py.typed +0 -0
- probability_flow/visualization/__init__.py +11 -0
- probability_flow/visualization/image.py +402 -0
- probability_flow/visualization/style.py +58 -0
- probability_flow-0.1.0.dist-info/METADATA +304 -0
- probability_flow-0.1.0.dist-info/RECORD +37 -0
- probability_flow-0.1.0.dist-info/WHEEL +4 -0
- probability_flow-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,397 @@
|
|
|
1
|
+
"""Random ASPIC argument-graph generator. See docs/generation.md.
|
|
2
|
+
|
|
3
|
+
Builds typed arguments *directly* — `Premise` / `Axiom` / `Conclusion` wired with
|
|
4
|
+
`support` / `rebut` / `strict` / `undermine` / `undercut` — and returns an
|
|
5
|
+
`Argument`, then screens by difficulty through the metrics seam. There is no
|
|
6
|
+
post-hoc type relabelling and no `LR == 9999` sentinel: the type is chosen as each
|
|
7
|
+
node is made.
|
|
8
|
+
|
|
9
|
+
Structural model: a root fed by `n_support` supporting and `n_attack` attacking
|
|
10
|
+
branches. In the default no-share mode every branch is built from fresh nodes, so
|
|
11
|
+
the branches are mutually d-separated and `n_groups` (= `n_support + n_attack`) is
|
|
12
|
+
the *exact* d-separated-group count. A branch is a **tree** (a conclusion can draw
|
|
13
|
+
several sub-arguments — fan-in — not just a chain); a sub-argument can support or
|
|
14
|
+
*attack* its parent conclusion (`internal_attack_prob`), so attack-of-an-attack
|
|
15
|
+
falls out; undercutters and underminers can themselves be small sub-arguments
|
|
16
|
+
(`attacker_depth_range`); a strict internal edge compiles to a noisy-OR. Everything
|
|
17
|
+
stays a polytree (fresh disjoint nodes), so the d-sep dial holds and loopy BP is
|
|
18
|
+
exact.
|
|
19
|
+
|
|
20
|
+
`generate(...)` wraps the builder in a rejection loop (the hybrid strategy:
|
|
21
|
+
rejection now, a constructive pass for hard targets later). It calibrates the root
|
|
22
|
+
prior to a *target posterior value* by bisection, grows to a target claim count
|
|
23
|
+
constructively, and screens realized depth. The branch-level methods
|
|
24
|
+
(`add_support_branch` / `add_attack_branch`) are usable directly to hand-script a
|
|
25
|
+
template.
|
|
26
|
+
"""
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
import math
|
|
30
|
+
import random
|
|
31
|
+
from dataclasses import dataclass
|
|
32
|
+
from typing import Optional
|
|
33
|
+
|
|
34
|
+
from .argument import Axiom, Conclusion, Premise, _Claim
|
|
35
|
+
from .compile import _reachable
|
|
36
|
+
from .handle import Argument
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class StructuralParams:
    """Shape parameters the generator samples from.

    These describe the *structure* of a proposed argument graph, not the
    difficulty it must reach. The d-separated group count is not drawn: it is
    fixed as `n_support + n_attack` (see `n_groups`).
    """

    # --- branches feeding the root ---
    n_support: int = 2                  # how many supporting branches enter the root
    n_attack: int = 1                   # how many attacking branches enter the root
    min_depth: int = 1                  # lower bound of the uniformly drawn branch depth
    max_depth: int = 3                  # upper bound of the uniformly drawn branch depth
    depths: Optional[list] = None       # explicit depth per branch; length must equal n_groups

    # --- tree shape inside a branch ---
    max_fanin: int = 2                  # cap on extra sub-arguments per conclusion
    fanin_prob: float = 0.3             # chance that each extra sub-argument is added
    internal_attack_prob: float = 0.1   # chance a sub-argument rebuts its parent conclusion
    share_prob: float = 0.0             # chance an extra sub-argument reuses an existing
                                        # node (shared parent). 0 keeps a forest of trees
                                        # (polytree); > 0 creates real cycles / merges
                                        # d-sep groups, so loopy BP becomes approximate.

    # --- per-element feature probabilities ---
    undercut_prob: float = 0.15         # per edge
    undermine_prob: float = 0.15        # per ordinary premise
    strict_prob: float = 0.1            # per internal support edge (-> noisy-OR)
    axiom_prob: float = 0.2             # per leaf
    attacker_depth_range: tuple = (1, 1)    # undercutter / underminer depth (1 = bare premise)
    n_claims_range: Optional[tuple] = None  # (lo, hi) band for the total argument-claim count

    # --- numeric ranges ---
    root_prior: Optional[float] = None      # None -> drawn uniformly from [0.3, 0.7]
    leaf_prior_range: tuple = (0.3, 0.75)   # prior of a leaf premise
    mid_prior_range: tuple = (0.1, 0.25)    # prior of an intermediate conclusion
    attacker_prior_range: tuple = (0.05, 0.4)   # prior of a bare undercutter / underminer
    support_lr_range: tuple = (3.0, 15.0)   # root support edges (log-uniform, lr > 1)
    internal_lr_range: tuple = (1.5, 12.0)  # internal support edges (lr > 1)
    attack_lr_range: tuple = (1 / 15, 1 / 3)    # reciprocal of support -> symmetric |log lr|

    @property
    def n_groups(self) -> int:
        """Total number of branches into the root (support plus attack)."""
        return self.n_support + self.n_attack

    def __post_init__(self):
        """Reject negative branch counts, an empty graph, or a mis-sized `depths`.

        Raises:
            ValueError: on any of the three conditions above.
        """
        if any(count < 0 for count in (self.n_support, self.n_attack)):
            raise ValueError("n_support and n_attack must be >= 0")
        if self.n_groups < 1:
            raise ValueError("need at least one branch (n_support + n_attack >= 1)")
        if self.depths is None:
            return
        if len(self.depths) != self.n_groups:
            raise ValueError(
                f"depths has {len(self.depths)} entries but n_groups is {self.n_groups}"
            )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class DifficultyTargets:
    """Acceptance criteria for a generated graph; a `None` field is not checked."""

    # Decision threshold that `posterior_side` is compared against.
    threshold: float = 0.5
    # Coarse side constraint on the root posterior: 'above' | 'below' | None.
    posterior_side: Optional[str] = None
    # Root posterior to reach by calibrating the root prior.
    target_posterior: Optional[float] = None
    # Half-width of the accepted band around `target_posterior`.
    posterior_tol: float = 0.02
    # Lower bound on the manipulability metric.
    min_manipulability: Optional[float] = None
    # Lower bound on the realized longest input path into the root.
    min_depth: Optional[int] = None
    # Required realized d-separated group count (meant for share mode).
    d_sep_groups: Optional[int] = None
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ArgumentGenerator:
    """Builds one argument. Add branches (each its own d-separated group), then
    `build()` for the `Argument`. Construct with a seeded `random.Random` for
    reproducibility."""

    def __init__(self, rng: random.Random, params: Optional[StructuralParams] = None):
        """Create the root conclusion and remember the sampling parameters.

        Args:
            rng: source of all randomness used while building.
            params: structural knobs; defaults to `StructuralParams()`.

        Raises:
            ValueError: if `params` describes zero branches.
        """
        self.rng = rng
        self.p = params or StructuralParams()
        if self.p.n_groups < 1:
            raise ValueError("need at least one branch")
        self._n = 0                     # counter behind the "Node <k>" names
        self._uc_nodes: set = set()     # nodes inside undercutter/underminer subgraphs
        prior = (self.p.root_prior if self.p.root_prior is not None
                 else round(rng.uniform(0.3, 0.7), 4))
        self.root = Conclusion("Root", prior=prior)

    # --- branch-level API (the "add a d-separated group" primitive) ---------

    def add_support_branch(self, depth: Optional[int] = None) -> None:
        """Attach one supporting branch to the root.

        `depth=None` draws the depth from `[min_depth, max_depth]`. Any explicit
        integer is honoured as given (a value <= 1 yields a single leaf), so an
        explicit `0` is no longer silently replaced by a random draw.
        """
        self._branch(self.root, "support", self._depth() if depth is None else depth)

    def add_attack_branch(self, depth: Optional[int] = None) -> None:
        """Attach one attacking branch to the root; `depth` as in `add_support_branch`."""
        self._branch(self.root, "attack", self._depth() if depth is None else depth)

    def build(self) -> Argument:
        """Return the `Argument` handle produced by the root's `assemble()`."""
        return self.root.assemble()

    def claim_count(self) -> int:
        """Number of `_Claim` nodes currently reachable from the root."""
        return sum(1 for c in _reachable(self.root) if isinstance(c, _Claim))

    def sprout(self) -> None:
        """Add one depth-1 supporting premise to a random conclusion — grows the
        claim count by exactly one, for constructive size control."""
        conclusions = [c for c in _reachable(self.root) if isinstance(c, Conclusion)]
        self.rng.choice(conclusions).support(self._leaf_premise(),
                                             self._lr(*self.p.internal_lr_range))

    # --- internals ----------------------------------------------------------

    def _name(self) -> str:
        """Next sequential node label ("Node 1", "Node 2", ...)."""
        self._n += 1
        return f"Node {self._n}"

    def _depth(self) -> int:
        """Uniform integer branch depth in `[min_depth, max_depth]`."""
        return self.rng.randint(self.p.min_depth, self.p.max_depth)

    def _lr(self, lo: float, hi: float) -> float:
        """Log-uniform likelihood ratio in `[lo, hi]`, rounded to 4 decimals."""
        return round(math.exp(self.rng.uniform(math.log(lo), math.log(hi))), 4)

    def _prior(self, rng_range: tuple) -> float:
        """Uniform prior drawn from `rng_range`, rounded to 4 decimals."""
        return round(self.rng.uniform(*rng_range), 4)

    def _leaf_premise(self) -> Premise:
        """A fresh ordinary premise with a prior from `leaf_prior_range`."""
        return Premise(self._name(), prior=self._prior(self.p.leaf_prior_range))

    def _leaf(self):
        """A fresh leaf: an `Axiom` with probability `axiom_prob`, else a premise."""
        if self.rng.random() < self.p.axiom_prob:
            return Axiom(self._name())
        return self._leaf_premise()

    def _branch(self, target, polarity: str, depth: int):
        """Build a sub-argument of `depth`, wire it into `target`, decorate the
        new edge, and return the sub-argument's root."""
        src = self._subtree(depth - 1)
        self._attach(target, src, polarity)
        self._decorate(target, src)
        return src

    def _subtree(self, height: int, *, isolated: bool = False):
        """A fresh sub-argument; returns its root (a premise at height 0, else a
        conclusion fed by `_populate`). `height` is hops above the root.
        `isolated=True` (attacker sub-arguments) disables decoration and sharing, so
        the subgraph stays fully disjoint."""
        if height <= 0:
            return self._leaf()
        root = Conclusion(self._name(), prior=self._prior(self.p.mid_prior_range))
        self._populate(root, height, isolated=isolated)
        return root

    def _populate(self, conclusion, height: int, *, isolated: bool = False) -> None:
        """Give `conclusion` its sub-arguments: the first continues to `height - 1`
        (so the height is realized) and is always fresh; up to `max_fanin` shallower
        extras follow (a tree). With `share_prob > 0` an extra may *reuse* an existing
        node instead of a fresh one — a shared parent, which makes the graph
        non-polytree (a within-branch cycle) or collapses d-sep groups (across
        branches). Each child attaches by support, or by attack with
        `internal_attack_prob`."""
        n_extra = sum(self.rng.random() < self.p.fanin_prob
                      for _ in range(self.p.max_fanin))
        heights = [height - 1] + [self.rng.randint(0, height - 1) for _ in range(n_extra)]
        for j, h in enumerate(heights):
            # j == 0 is the height-realizing child and is never shared.
            child, shared = self._child(conclusion, h, allow_share=not isolated and j > 0)
            polarity = ("attack" if self.rng.random() < self.p.internal_attack_prob
                        else "support")
            self._attach(conclusion, child, polarity)
            if not isolated and not shared:     # a reused node already has structure
                self._decorate(conclusion, child)

    def _child(self, conclusion, height: int, *, allow_share: bool):
        """Either reuse an eligible existing node (probability `share_prob`) or build a
        fresh sub-argument. Returns `(node, was_shared)`."""
        if allow_share and self.p.share_prob > 0 and self.rng.random() < self.p.share_prob:
            role = Premise if height <= 0 else Conclusion
            cands = self._share_candidates(role, conclusion)
            if cands:
                return self.rng.choice(cands), True
        # No sharing drawn, or nothing eligible to share: fall back to a fresh node.
        return self._subtree(height), False

    def _share_candidates(self, role, conclusion):
        """Existing nodes that may legally become a new parent of `conclusion`:
        matching role, not the root, not already a source of it, not inside an
        undercutter subgraph, and not an ancestor of it (which would make a directed
        cycle). Reuse is what turns the forest into a DAG with shared parents."""
        existing = {s for s, _lr, _k in conclusion._edges} | set(conclusion._strict)
        out = []
        for n in _reachable(self.root):
            if not isinstance(n, role) or n is self.root or n is conclusion:
                continue
            if n in self._uc_nodes or n in existing:
                continue
            if conclusion in _reachable(n):     # n is an ancestor -> cycle
                continue
            out.append(n)
        return out

    def _mark_uc(self, node) -> None:
        """Record every claim reachable from `node` as part of an attacker subgraph."""
        for c in _reachable(node):
            if isinstance(c, _Claim):
                self._uc_nodes.add(c)

    def _attach(self, downstream, upstream, polarity: str) -> None:
        """Wire `upstream` into `downstream` with the right typed edge."""
        if polarity == "support":
            # strict only on an internal edge from a non-axiom: a strict root would
            # divorce every defeasible branch into D and break the d-sep count.
            if (downstream is not self.root and not isinstance(upstream, Axiom)
                    and self.rng.random() < self.p.strict_prob):
                downstream.strict(upstream)
            else:
                lr_range = (self.p.support_lr_range if downstream is self.root
                            else self.p.internal_lr_range)
                downstream.support(upstream, self._lr(*lr_range))
        else:  # attack
            lr = self._lr(*self.p.attack_lr_range)
            if isinstance(downstream, Premise):
                downstream.undermine(by=upstream, lr=lr)    # attack on a premise
            else:
                downstream.rebut(upstream, lr=lr)           # attack on a conclusion

    def _attacker(self):
        """A fresh, disjoint undercutter / underminer: a bare premise at depth 1, or
        its own small sub-argument up to `attacker_depth_range`."""
        d = self.rng.randint(*self.p.attacker_depth_range)
        if d <= 1:
            return Premise(self._name(), prior=self._prior(self.p.attacker_prior_range))
        root = Conclusion(self._name(), prior=self._prior(self.p.mid_prior_range))
        self._populate(root, d - 1, isolated=True)   # disjoint: no decorate, no share
        return root

    def _decorate(self, downstream, source) -> None:
        """Optionally undercut the `source -> downstream` edge, and optionally
        undermine `source` if it is an ordinary premise. Attacker subgraphs are
        marked so sharing never reaches into or out of them."""
        if self.rng.random() < self.p.undercut_prob:
            by = self._attacker()
            downstream.undercut(source, by=by)
            self._mark_uc(by)
        if (isinstance(source, Premise) and not isinstance(source, Axiom)
                and self.rng.random() < self.p.undermine_prob):
            by = self._attacker()
            source.undermine(by=by, lr=self._lr(*self.p.attack_lr_range))
            self._mark_uc(by)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _calibrate_root_prior(root, target: float, tol: float,
                          lo: float = 0.02, hi: float = 0.98, iters: int = 60) -> bool:
    """Tune `root.prior` by bisection until the root posterior is within `tol`
    of `target`.

    The search relies on the posterior being monotone increasing in the prior.
    `root.prior` is overwritten on every probe and is left at the last probed
    value whatever the outcome.

    Args:
        root: the root claim whose prior is tuned.
        target: desired posterior of `root`.
        tol: accepted absolute distance from `target`.
        lo, hi: prior interval that is searched.
        iters: maximum number of bisection steps.

    Returns:
        True when some prior in `[lo, hi]` lands within `tol`; False when the
        target lies outside what the interval can reach (the caller resamples).

    This is the dependency-free calibration fallback: it adjusts a single
    parameter and may move it outside the nominal structural range. The planned
    JAX path (docs/generation.md) does a projected multi-parameter solve instead.
    """
    from ..core import ExactSolver, LoopySolver
    from ..metrics import is_polytree

    def posterior(p: float) -> float:
        # Recompile at the probed prior; loopy BP on a polytree, exact otherwise.
        root.prior = float(p)
        network = root.compile()
        if is_polytree(network):
            engine = LoopySolver(network)
        else:
            engine = ExactSolver(network)
        return engine.prob(root, 1)

    # Reachability screen. The upper end is only probed if the lower end passes.
    lower_reach = posterior(lo) - tol
    if not lower_reach <= target:
        return False
    upper_reach = posterior(hi) + tol
    if not target <= upper_reach:
        return False

    left, right = lo, hi
    for _ in range(iters):
        centre = 0.5 * (left + right)
        value = posterior(centre)
        if abs(value - target) <= tol:
            return True
        if value < target:
            left = centre
        else:
            right = centre
    # Budget exhausted: accept or reject on the final midpoint.
    return abs(posterior(0.5 * (left + right)) - target) <= tol
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def _build_one(rng, params) -> "ArgumentGenerator":
    """Propose a single graph.

    Adds the support branches and then the attack branches, using the exact
    `params.depths` when given and drawn depths otherwise. If `n_claims_range`
    is set, premises are then sprouted until its lower bound is reached, with at
    most 1000 sprouts.
    """
    builder = ArgumentGenerator(rng, params)

    if params.depths is None:
        for _ in range(params.n_support):
            builder.add_support_branch()
        for _ in range(params.n_attack):
            builder.add_attack_branch()
    else:
        # The first n_support entries are support depths, the rest attack depths.
        support_depths = params.depths[:params.n_support]
        attack_depths = params.depths[params.n_support:]
        for depth in support_depths:
            builder.add_support_branch(depth=depth)
        for depth in attack_depths:
            builder.add_attack_branch(depth=depth)

    if params.n_claims_range is not None:
        floor, _ceiling = params.n_claims_range
        for _ in range(1000):   # hard cap on sprouts
            if not builder.claim_count() < floor:
                break
            builder.sprout()
    return builder
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def generate(seed: Optional[int] = None, *,
             structural: Optional[StructuralParams] = None,
             targets: Optional[DifficultyTargets] = None,
             max_attempts: int = 200, verbose: bool = False) -> Argument:
    """Rejection-sample an `Argument` meeting `targets`. The structural shape
    (`n_support`/`n_attack`, depth, features) is set directly; the difficulty
    targets are screened, and `target_posterior` is reached by calibrating the root
    prior.

    Args:
        seed: seed of the base RNG; each attempt gets its own derived RNG.
        structural: shape parameters; defaults to `StructuralParams()`.
        targets: acceptance criteria; defaults to `DifficultyTargets()`.
        max_attempts: number of proposals tried before giving up.
        verbose: print one line per attempt saying why it was rejected or that
            it was accepted.

    Returns:
        The first proposed `Argument` that passes every configured target.

    Raises:
        RuntimeError: if no graph passes within `max_attempts`.
    """
    from ..core import ExactSolver, LoopySolver
    from ..metrics import d_separated_groups, is_polytree, manipulability, max_depth

    params = structural or StructuralParams()
    tgt = targets or DifficultyTargets()
    base = random.Random(seed)

    for attempt in range(1, max_attempts + 1):
        rng = random.Random(base.randint(0, 2**31))
        g = _build_one(rng, params)

        # Size screen: growth only guarantees the lower bound of the band.
        if params.n_claims_range is not None:
            _lo, hi = params.n_claims_range
            n_claims = g.claim_count()
            if n_claims > hi:
                if verbose:
                    print(f"attempt {attempt}: {n_claims} claims > {hi}")
                continue

        if (tgt.target_posterior is not None
                and not _calibrate_root_prior(g.root, tgt.target_posterior,
                                              tgt.posterior_tol)):
            if verbose:
                print(f"attempt {attempt}: target {tgt.target_posterior} unreachable")
            continue

        arg = g.build()
        bn = arg.bn
        # a polytree (no sharing) -> loopy BP is exact and linear; sharing makes a
        # cycle, so fall back to the exact solver (keep share-mode graphs small).
        p_root = (LoopySolver if is_polytree(bn) else ExactSolver)(bn).prob(arg.target, 1)

        if tgt.d_sep_groups is not None:
            n_groups = len(d_separated_groups(bn, arg.target))
            if n_groups != tgt.d_sep_groups:
                if verbose:
                    print(f"attempt {attempt}: {n_groups} d-separated groups "
                          f"!= {tgt.d_sep_groups}")
                continue
        if tgt.posterior_side == "above" and not p_root > tgt.threshold:
            if verbose:
                print(f"attempt {attempt}: P={p_root:.3f} not above {tgt.threshold}")
            continue
        if tgt.posterior_side == "below" and not p_root < tgt.threshold:
            if verbose:
                print(f"attempt {attempt}: P={p_root:.3f} not below {tgt.threshold}")
            continue
        if (tgt.target_posterior is not None
                and abs(p_root - tgt.target_posterior) > tgt.posterior_tol):
            if verbose:
                print(f"attempt {attempt}: P={p_root:.3f} outside "
                      f"{tgt.target_posterior} +/- {tgt.posterior_tol}")
            continue
        if tgt.min_depth is not None:
            depth = max_depth(bn, arg.target)
            if depth < tgt.min_depth:
                if verbose:
                    print(f"attempt {attempt}: depth {depth} "
                          f"< {tgt.min_depth}")
                continue
        if tgt.min_manipulability is not None:
            m = manipulability(bn, arg.target)
            if m < tgt.min_manipulability:
                if verbose:
                    print(f"attempt {attempt}: manipulability {m:.3f} < "
                          f"{tgt.min_manipulability}")
                continue

        if verbose:
            print(f"attempt {attempt}: accepted (P={p_root:.3f})")
        return arg

    raise RuntimeError(
        f"no graph met the targets in {max_attempts} attempts; relax the targets "
        "(move threshold toward 0.5, widen posterior_tol, lower min_manipulability/"
        "min_depth) or raise max_attempts"
    )
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
"""The `Argument` handle: what `target.assemble()` returns.
|
|
2
|
+
|
|
3
|
+
Authoring stays in `argument.py` (free constructors, fluent edges on the
|
|
4
|
+
downstream node). Once an argument is built, `assemble()` wraps its target in an
|
|
5
|
+
`Argument`, the home for everything you do *with* a finished argument: serialize
|
|
6
|
+
it (`to_json` / `from_json` / `save` / `load`), reach the compiled network
|
|
7
|
+
(`.bn`), query a posterior, and measure it (the `metrics` seam).
|
|
8
|
+
|
|
9
|
+
Serialization is at the *argument* level, not the compiled BN: the JSON carries
|
|
10
|
+
the typed edges (support / rebut / strict / undermine) and the undercut as an
|
|
11
|
+
`undercutter_ids` field on the edge it attacks, and `from_json` recompiles the BN
|
|
12
|
+
deterministically. The lowered splice / divorced-D helper nodes never appear.
|
|
13
|
+
|
|
14
|
+
The schema is the team-agreed one in `docs/aspic.md`. The metric fields
|
|
15
|
+
(`input_groups_sizes` from d-separation, `min_posterior` / `max_posterior` from
|
|
16
|
+
the manipulability range) are filled by the `metrics` seam; `revealable_to_judge`
|
|
17
|
+
and `desc` are carried verbatim but the library does not interpret them.
|
|
18
|
+
"""
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import json
|
|
22
|
+
import re
|
|
23
|
+
import warnings
|
|
24
|
+
from typing import TYPE_CHECKING, Optional
|
|
25
|
+
|
|
26
|
+
from ..core import ExactSolver
|
|
27
|
+
from .compile import _reachable
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from ..core import BayesianNetwork, Node
|
|
31
|
+
from .argument import _Claim
|
|
32
|
+
|
|
33
|
+
SCHEMA_VERSION = "0.1"
|
|
34
|
+
|
|
35
|
+
# Above this many nodes the brute-force ExactSolver fallback (2**n joint
|
|
36
|
+
# enumeration, used only on non-polytree arguments) gets painfully slow. See the
|
|
37
|
+
# "Known issues" note in docs/ROADMAP.md.
|
|
38
|
+
_EXACT_FALLBACK_WARN_NODES = 22
|
|
39
|
+
|
|
40
|
+
# argument-edge kind (stored on the node) <-> serialized "type"
|
|
41
|
+
_KIND_TO_TYPE = {"support": "support/defeas", "rebut": "rebut", "undermine": "undermine"}
|
|
42
|
+
_TYPE_TO_KIND = {v: k for k, v in _KIND_TO_TYPE.items()}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _slugify(label: str) -> str:
|
|
46
|
+
return re.sub(r"[^a-z0-9]+", "-", label.lower()).strip("-") or "node"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _assign_ids(claims: list["_Claim"]) -> dict["_Claim", str]:
    """A stable readable id per claim: the slugified label, with a numeric suffix
    on collision. Assignment is sorted by slug so the same multiset of labels maps
    to the same ids regardless of graph-traversal order (round-trip stable).

    Ids are unique. A suffixed id such as `node-2` can coincide with the natural
    slug of another label (e.g. "Node 2"), so every candidate is checked against
    the ids already issued and the suffix is bumped until it is free. When no such
    clash occurs the ids are the plain `base`, `base-2`, `base-3`, ... sequence.
    """
    used: dict[str, int] = {}       # slug -> how many suffix slots it has consumed
    taken: set[str] = set()         # every id issued so far
    ids: dict["_Claim", str] = {}
    for c in sorted(claims, key=lambda n: _slugify(n.name)):
        base = _slugify(c.name)
        n = used.get(base, 0)
        candidate = base if n == 0 else f"{base}-{n + 1}"
        # Skip past ids already issued, e.g. to a claim literally labelled "Node 2".
        while candidate in taken:
            n += 1
            candidate = f"{base}-{n + 1}"
        used[base] = n + 1
        taken.add(candidate)
        ids[c] = candidate
    return ids
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class Argument:
    """A finished argument, rooted at its target. Construct via `target.assemble()`."""

    def __init__(self, target: "_Claim"):
        """Wrap `target`; the network and solver are built lazily on first use."""
        self.target = target
        self._bn: Optional["BayesianNetwork"] = None
        self._solver = None     # cached; see the `solver` property
        self._edge_meta: dict[str, dict] = {}   # edge id -> {label, revealable_to_judge}

    # --- compiled network and queries -------------------------------------

    @property
    def bn(self) -> "BayesianNetwork":
        """The compiled `BayesianNetwork` (built once, then cached)."""
        if self._bn is None:
            self._bn = self.target.compile()
        return self._bn

    @property
    def solver(self):
        """The ground-truth solver for this argument, chosen by topology:
        `LoopySolver` on a polytree (provably exact there, linear in the graph,
        and one propagation answers every node), `ExactSolver` otherwise (loops;
        exponential, so only viable on small graphs). Built once, then cached —
        `LoopySolver` reuses its message passing across queries."""
        if self._solver is None:
            from ..core import LoopySolver
            from ..metrics import is_polytree

            if is_polytree(self.bn):
                self._solver = LoopySolver(self.bn)
            else:
                n = len(self.bn.nodes)
                if n > _EXACT_FALLBACK_WARN_NODES:
                    warnings.warn(
                        f"this argument is not a polytree, so it is solved by "
                        f"brute-force enumeration of all 2**{n} joint states, which "
                        "is slow at this size. Keep shared / loopy graphs small "
                        "until the loopy-BP robustness work lands (docs/ROADMAP.md).",
                        stacklevel=2,
                    )
                self._solver = ExactSolver(self.bn)
        return self._solver

    def posterior(self, node: "Node", evidence=None) -> float:
        """`P(node = 1 | evidence)` under the compiled network (see `solver` for
        how the solver is chosen)."""
        return self.solver.prob(node, 1, evidence=evidence)

    def render(self, **kwargs):
        """Draw the argument view (see `aspic.visualization`); `self.bn.render()`
        draws the compiled network instead."""
        return self.target.render(**kwargs)

    # --- argument-level metrics (delegating to the metrics seam) ----------

    def input_groups_sizes(self, node: "Node") -> Optional[list[int]]:
        """Sorted evidence-mass sizes of `node`'s d-separated input groups — one
        entry per independent argument branch, counting **only argument claims**.
        The lowered helper nodes (the undercut splice `X`, the divorced `D`) are
        hidden here as everywhere in the argument view (decision 8), so a branch
        whose cone is pure machinery (a prior-only divorced `D`) drops out. `None`
        for a node with no inputs (a premise). Sorted for round-trip stability.

        The underlying `metrics.d_separated_groups` stays BN-level (its `g.nodes`
        is the full compiled cone); the argument-level filtering happens here, the
        seam where the abstracted view is the handle's job."""
        from ..metrics import d_separated_groups
        from .argument import _Claim

        if not self.bn.compiled_cpd(node).inputs:
            return None
        sizes = []
        for g in d_separated_groups(self.bn, node):
            claims = sum(1 for x in g.nodes if isinstance(x, _Claim))
            if claims:   # drop pure-helper branches (prior-only D)
                sizes.append(claims)
        return sorted(sizes)

    def posterior_range(self, node: "Node") -> tuple[float, float]:
        """(min, max) posterior `node` could be pushed to over partial reveals of
        its upstream evidence — the manipulability range. The cheap polynomial outer
        bound; see `metrics.posterior_range` for the reveal model."""
        from ..metrics import posterior_range

        return posterior_range(self.bn, node)

    # --- serialization ----------------------------------------------------

    def to_dict(self) -> dict:
        """The argument as a plain dict matching the team schema (`docs/aspic.md`).

        Returns:
            `{"version", "target", "nodes", "edges"}`: one node entry per claim
            reachable from the target, one edge entry per typed edge.

        Raises:
            TypeError: if a node reachable from the target is not a claim.
        """
        from .argument import Axiom, Premise, _Claim

        bn = self.bn
        # Traverse once, and validate before paying for the solver.
        reachable = list(_reachable(self.target))
        claims = [c for c in reachable if isinstance(c, _Claim)]
        non_claim = [c for c in reachable if not isinstance(c, _Claim)]
        if non_claim:
            raise TypeError(
                "Argument.to_dict supports arguments built from Premise / Axiom / "
                f"Conclusion only; found plain node(s): {[n.name for n in non_claim]}"
            )
        solver = self.solver
        ids = _assign_ids(claims)

        def type_of(c: "_Claim") -> str:
            # The target is always written as "root", whatever its class.
            if c is self.target:
                return "root"
            if isinstance(c, Axiom):
                return "axiomatic premise"
            if isinstance(c, Premise):
                return "premise"
            return "conclusion"

        def cpd_type(c: "_Claim") -> str:
            # CPD class name without its "CPD" suffix.
            name = type(bn.compiled_cpd(c).cpd).__name__
            return name[:-3] if name.endswith("CPD") else name

        def node_dict(c: "_Claim") -> dict:
            lo, hi = self.posterior_range(c)
            return {
                "id": ids[c],
                "type": type_of(c),
                "label": c.name,
                "desc": c.desc,
                "prior": float(c.prior),
                "posterior": float(solver.prob(c, 1)),
                "min_posterior": float(lo),
                "max_posterior": float(hi),
                "input_groups_sizes": self.input_groups_sizes(c),
                "cpd_type": cpd_type(c),
                "revealable_to_judge": c.revealable_to_judge,
            }

        nodes = [node_dict(c) for c in claims]

        edges = []
        for c in claims:
            # Group this claim's undercutters by the identity of the attacked source.
            undercutters: dict[int, list["_Claim"]] = {}
            for src, by in c._undercuts:
                undercutters.setdefault(id(src), []).append(by)

            def edge(src, etype, lr):
                eid = f"{ids[src]}->{ids[c]}"
                meta = self._edge_meta.get(eid, {})
                return {
                    "id": eid,
                    "type": etype,
                    "label": meta.get("label"),
                    "source": ids[src],
                    "target": ids[c],
                    "undercutter_ids": [ids[u] for u in undercutters.get(id(src), [])],
                    "lr": None if lr is None else float(lr),
                    "revealable_to_judge": meta.get("revealable_to_judge"),
                }

            for src, lr, kind in c._edges:
                edges.append(edge(src, _KIND_TO_TYPE[kind], lr))
            for src in c._strict:
                edges.append(edge(src, "strict", None))     # strict edges carry no lr

        return {"version": SCHEMA_VERSION, "target": ids[self.target],
                "nodes": nodes, "edges": edges}

    def to_json(self, *, indent: int = 2) -> str:
        """The `to_dict` payload serialized as a JSON string."""
        return json.dumps(self.to_dict(), indent=indent)

    def save(self, path: str) -> None:
        """Write the JSON form to `path` as UTF-8, independent of the locale."""
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.to_json())

    @classmethod
    def from_dict(cls, data: dict) -> "Argument":
        """Rebuild an `Argument` from a `to_dict` payload. The compiled BN is
        recomputed lazily on first `.bn` access, so a round trip reproduces it.

        Raises:
            ValueError: if an edge has an unknown `type`.
            KeyError: if a required key is missing or an edge references an
                unknown node id.
        """
        from .argument import Axiom, Conclusion, Premise

        by_id: dict[str, "_Claim"] = {}
        for nd in data["nodes"]:
            t = nd["type"]
            if t == "axiomatic premise":
                node: "_Claim" = Axiom(nd["label"])
            elif t == "premise":
                node = Premise(nd["label"], prior=nd.get("prior", 0.5))
            else:  # "conclusion" or "root"
                node = Conclusion(nd["label"], prior=nd.get("prior", 0.5))
            node.desc = nd.get("desc")
            node.revealable_to_judge = nd.get("revealable_to_judge")
            by_id[nd["id"]] = node

        edge_meta: dict[str, dict] = {}
        for ed in data["edges"]:
            src, tgt = by_id[ed["source"]], by_id[ed["target"]]
            etype, lr = ed["type"], ed.get("lr")
            if etype == "strict":
                tgt.strict(src)
            elif etype == "undermine":
                tgt.undermine(by=src, lr=lr)
            elif etype in _TYPE_TO_KIND:  # support/defeas or rebut
                (tgt.support if etype == "support/defeas" else tgt.rebut)(src, lr)
            else:
                raise ValueError(f"unknown edge type {etype!r}")
            for uid in ed.get("undercutter_ids", []):
                tgt.undercut(src, by=by_id[uid])
            # Keep edge metadata only when at least one field is set.
            if ed.get("label") is not None or ed.get("revealable_to_judge") is not None:
                edge_meta[ed["id"]] = {"label": ed.get("label"),
                                       "revealable_to_judge": ed.get("revealable_to_judge")}

        arg = cls(by_id[data["target"]])
        arg._edge_meta = edge_meta
        return arg

    @classmethod
    def from_json(cls, s: str) -> "Argument":
        """Rebuild an `Argument` from a JSON string produced by `to_json`."""
        return cls.from_dict(json.loads(s))

    @classmethod
    def load(cls, path: str) -> "Argument":
        """Read a UTF-8 JSON file at `path` and rebuild the `Argument`."""
        with open(path, encoding="utf-8") as f:
            return cls.from_json(f.read())
|