probability-flow 0.3.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {probability_flow-0.3.0 → probability_flow-0.4.0}/.gitignore +3 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/PKG-INFO +2 -1
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/__init__.py +3 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/argument.py +57 -2
- probability_flow-0.4.0/probability_flow/aspic/benchmark.py +254 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/calibrate.py +7 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/compile.py +29 -11
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/generate.py +177 -4
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/handle.py +25 -2
- probability_flow-0.4.0/probability_flow/aspic/optimize.py +514 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/visualization.py +17 -1
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/correlated_evidence.py +3 -1
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/__init__.py +8 -1
- probability_flow-0.4.0/probability_flow/metrics/manipulability.py +474 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/layout.py +184 -62
- {probability_flow-0.3.0 → probability_flow-0.4.0}/pyproject.toml +12 -2
- probability_flow-0.3.0/probability_flow/metrics/manipulability.py +0 -207
- {probability_flow-0.3.0 → probability_flow-0.4.0}/LICENSE +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/README.md +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/__init__.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/__init__.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/_logmath.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/bp/__init__.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/bp/engine.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/bp/message.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/__init__.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/base.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/independent_evidence.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/noisy_and.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/noisy_or.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/tabular.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/exact.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/network.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/node.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/_util.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/difficulty.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/dseparation.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/loopiness.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/structure.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/py.typed +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/__init__.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/animate.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/image.py +0 -0
- {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/style.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: probability-flow
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: A from-scratch, modular discrete Bayesian-network library.
|
|
5
5
|
Project-URL: Homepage, https://github.com/scalable-oversight-benchmarks/probability-flow
|
|
6
6
|
Project-URL: Repository, https://github.com/scalable-oversight-benchmarks/probability-flow
|
|
@@ -28,6 +28,7 @@ Requires-Dist: pytest; extra == 'dev'
|
|
|
28
28
|
Requires-Dist: ruff; extra == 'dev'
|
|
29
29
|
Provides-Extra: jax
|
|
30
30
|
Requires-Dist: jax; extra == 'jax'
|
|
31
|
+
Requires-Dist: scipy; extra == 'jax'
|
|
31
32
|
Provides-Extra: viz
|
|
32
33
|
Requires-Dist: matplotlib; extra == 'viz'
|
|
33
34
|
Description-Content-Type: text/markdown
|
|
@@ -9,13 +9,16 @@ pure-BN `core`. Build an argument out of premises, conclusions, and attacks, the
|
|
|
9
9
|
from .argument import ArgumentWarning, Axiom, Conclusion, Premise
|
|
10
10
|
from .generate import (
|
|
11
11
|
ArgumentGenerator,
|
|
12
|
+
BatchParams,
|
|
12
13
|
DifficultyTargets,
|
|
13
14
|
StructuralParams,
|
|
14
15
|
generate,
|
|
16
|
+
generate_batch,
|
|
15
17
|
)
|
|
16
18
|
from .handle import Argument
|
|
17
19
|
|
|
18
20
|
__all__ = [
|
|
19
21
|
"Premise", "Axiom", "Conclusion", "ArgumentWarning", "Argument",
|
|
20
22
|
"ArgumentGenerator", "StructuralParams", "DifficultyTargets", "generate",
|
|
23
|
+
"BatchParams", "generate_batch",
|
|
21
24
|
]
|
|
@@ -47,6 +47,11 @@ class _Claim(Node):
|
|
|
47
47
|
# (source_a, source_b, J): residual correlation between two defeasible
|
|
48
48
|
# sources of this conclusion, lowered to a CorrelatedEvidenceCPD coupling.
|
|
49
49
|
self._correlations: list[tuple["Node", "Node", "float | None"]] = []
|
|
50
|
+
# (sources, lr, leak, rule): a conjunctive support group — the sources
|
|
51
|
+
# JOINTLY support this conclusion through one noisy-AND inference (the
|
|
52
|
+
# deduction rule), lowered to a hidden NoisyAnd gate fed in as one lr.
|
|
53
|
+
self._conjunctions: list[
|
|
54
|
+
tuple[list["Node"], "float | None", float, "str | None"]] = []
|
|
50
55
|
self._no_undermine = False
|
|
51
56
|
# opaque, serialized but uninterpreted by the library (see docs/aspic.md):
|
|
52
57
|
self.desc: str | None = None # a longer description
|
|
@@ -56,8 +61,13 @@ class _Claim(Node):
|
|
|
56
61
|
"""At most one argumentative edge joins a pair of claims. This is what lets
|
|
57
62
|
an undercut be addressed by its endpoints and keeps serialized edge ids
|
|
58
63
|
(`source->target`) unique."""
|
|
59
|
-
|
|
60
|
-
s is src for
|
|
64
|
+
in_conj = any(
|
|
65
|
+
s is src for srcs, _lr, _leak, _rule in self._conjunctions for s in srcs
|
|
66
|
+
)
|
|
67
|
+
if (
|
|
68
|
+
any(s is src for s, _lr, _kind in self._edges)
|
|
69
|
+
or any(s is src for s in self._strict)
|
|
70
|
+
or in_conj
|
|
61
71
|
):
|
|
62
72
|
raise ValueError(
|
|
63
73
|
f"{self.name!r} already has an edge from {src.name!r}; at most one "
|
|
@@ -80,6 +90,51 @@ class _Claim(Node):
|
|
|
80
90
|
self._edges.append((src, None if lr is None else float(lr), "support"))
|
|
81
91
|
return src
|
|
82
92
|
|
|
93
|
+
def support_all(
    self,
    srcs: "list[Node]",
    lr: "float | None" = None,
    leak: float = 0.0,
    rule: "str | None" = None,
) -> "list[Node]":
    """Attach a conjunctive (AND) support group to this conclusion.

    The sources in `srcs` support this conclusion *jointly*: the deduction fires
    only when every one of them holds, mirroring a deductive entailment step
    (N antecedents plus one inference rule yield one conclusion). The group is
    recorded on `self._conjunctions`. At compile time it is lowered to one hidden
    `NoisyAnd` gate over `srcs`, each at activation 1. The gate's output feeds
    this conclusion as a single `lr` support, so the conclusion keeps its own
    prior and still composes with ordinary `support` / `rebut` edges.

    Args:
        srcs: Two or more distinct source nodes. Any iterable is accepted and
            copied into a new list.
        lr: Strength of the group as a likelihood ratio, which must be > 1.
            `None` leaves it unset so the group can be scaffolded first.
        leak: The rule's fallibility, a probability in [0, 1]:
            `P(fire | all sources hold) = 1 - leak`.
        rule: Optional label for the inference licence. It is stored but not
            interpreted.

    Returns:
        The copied list of sources, so inline sources can be built further
        upstream (as with `support`).

    Raises:
        ValueError: If there are fewer than two sources, duplicate sources,
            `lr` not > 1, `leak` outside [0, 1], or a source that already has
            an edge into this conclusion.
    """
    members = list(srcs)
    if len(members) < 2:
        raise ValueError(
            "support_all needs >= 2 sources (it is a conjunction); use support "
            "for a single argument"
        )
    # Node identity, not equality, decides distinctness.
    if len(set(map(id, members))) != len(members):
        raise ValueError("support_all sources must be distinct")
    # `not lr > 1` (rather than `lr <= 1`) also rejects NaN.
    if lr is not None and not lr > 1:
        raise ValueError(
            f"support_all lr must be > 1 (got {lr}); it is conjunctive support"
        )
    if not 0.0 <= leak <= 1.0:
        raise ValueError("leak must be a probability in [0, 1]")
    # Check every source before recording the group.
    for member in members:
        self._require_new_edge(member)
    strength = float(lr) if lr is not None else None
    self._conjunctions.append((members, strength, float(leak), rule))
    return members
|
|
137
|
+
|
|
83
138
|
def rebut(self, src: "Node", lr: "float | None" = None) -> "Node":
|
|
84
139
|
"""Add a defeasible argument *against* this conclusion (`0 < lr < 1`). `lr`
|
|
85
140
|
may be omitted (`None`) to declare the edge before its strength is known."""
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""Benchmark helpers: generate → optimize bridge for large-batch posterior targeting.
|
|
2
|
+
|
|
3
|
+
This module isolates the JAX/scipy dependency (via ``optimize``) from the pure-Python
|
|
4
|
+
``generate.py``. The primary entry point is ``generate_with_targets``, which produces
|
|
5
|
+
N argument graphs each tuned to a requested root posterior.
|
|
6
|
+
|
|
7
|
+
Strategy
|
|
8
|
+
--------
|
|
9
|
+
For each (seed, target_posterior) pair:
|
|
10
|
+
|
|
11
|
+
1. Try ``generate(seed, structural, DifficultyTargets(target_posterior=t))`` — this
|
|
12
|
+
bisects the root prior to hit *t* cheaply (no JAX needed).
|
|
13
|
+
2. If the bisection cannot reach *t* for the drawn structure (``RuntimeError`` from
|
|
14
|
+
the rejection loop), fall back to:
|
|
15
|
+
a. Generate an *unconstrained* graph from the same seed.
|
|
16
|
+
b. Check whether *t* is inside ``achievable_interval(arg)``.
|
|
17
|
+
c. If yes, call ``optimize(arg, target_posterior=t)`` to tune LRs.
|
|
18
|
+
3. If neither path succeeds, record the failure (skip / warn / raise per
|
|
19
|
+
``on_failure``).
|
|
20
|
+
|
|
21
|
+
Both paths return an ``Argument`` with identical public interface. The ``optimize``
|
|
22
|
+
fallback is more expensive (JAX JIT + SLSQP) but reaches extreme posteriors that
|
|
23
|
+
bisection cannot.
|
|
24
|
+
|
|
25
|
+
Batch labeling
|
|
26
|
+
--------------
|
|
27
|
+
Each accepted argument is annotated with a ``_benchmark_meta`` attribute (a plain
|
|
28
|
+
dict) carrying the difficulty scalars the benchmark consumer (``debate-eval``) needs:
|
|
29
|
+
|
|
30
|
+
posterior, manipulability ([lo, hi] range), manipulability_width, d_sep_count, max_depth_metric,
|
|
31
|
+
upstream_size, circuit_rank, seed, is_exact_manipulability,
|
|
32
|
+
used_optimize_fallback
|
|
33
|
+
|
|
34
|
+
These come entirely from existing metrics — no new computation.
|
|
35
|
+
|
|
36
|
+
JAX/scipy availability
|
|
37
|
+
----------------------
|
|
38
|
+
Both are required for the ``optimize`` fallback path. When they are absent the
|
|
39
|
+
module still imports cleanly; ``generate_with_targets`` records a failure (per ``on_failure``) whenever an
|
|
40
|
+
optimize fallback is actually needed. Use ``generate_batch`` (no JAX) when you only
|
|
41
|
+
need bisection-reachable posteriors.
|
|
42
|
+
"""
|
|
43
|
+
from __future__ import annotations
|
|
44
|
+
|
|
45
|
+
import warnings
|
|
46
|
+
from typing import TYPE_CHECKING, Optional
|
|
47
|
+
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from .handle import Argument
|
|
50
|
+
|
|
51
|
+
# ---------------------------------------------------------------------------
|
|
52
|
+
# Internal helpers
|
|
53
|
+
# ---------------------------------------------------------------------------
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _label(arg: "Argument", seed: int, used_optimize: bool) -> None:
    """Attach a ``_benchmark_meta`` dict of difficulty scalars to *arg*, in place.

    Keys:

    - ``posterior``: the root posterior ``arg.posterior(arg.target)``.
    - ``manipulability``: ``[lo, hi]`` from ``posterior_range``, each rounded to
      4 dp.
    - ``manipulability_width``: ``hi - lo``, rounded to 4 dp.
    - ``d_sep_count``, ``max_depth_metric``, ``upstream_size``,
      ``circuit_rank``: the corresponding metrics.
    - ``seed``, ``is_exact_manipulability``, ``used_optimize_fallback``.

    The dict is assigned only after every metric has been computed.
    """
    from ..metrics import (
        circuit_rank,
        d_separated_groups,
        max_depth,
        posterior_range,
        upstream_size,
    )

    net = arg.bn
    root = arg.target
    posterior = arg.posterior(root)

    # posterior_range: exact on polytrees, an outer bound on graphs with shared
    # nodes. The [lo, hi] endpoints are reported, not just the width, because
    # they show *where* on [0, 1] the judge can be pushed.
    bounds = posterior_range(net, root, exact=True)

    meta = {}
    meta["posterior"] = float(posterior)
    meta["manipulability"] = [round(float(bounds.lo), 4), round(float(bounds.hi), 4)]
    meta["manipulability_width"] = round(float(bounds.hi - bounds.lo), 4)
    meta["d_sep_count"] = len(d_separated_groups(net, root))
    meta["max_depth_metric"] = max_depth(net, root)
    meta["upstream_size"] = upstream_size(net, root)
    meta["circuit_rank"] = circuit_rank(net)
    meta["seed"] = seed
    meta["is_exact_manipulability"] = bool(bounds.is_exact)
    meta["used_optimize_fallback"] = used_optimize
    arg._benchmark_meta = meta
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _try_optimize(
|
|
90
|
+
arg: "Argument",
|
|
91
|
+
target: float,
|
|
92
|
+
posterior_tol: float,
|
|
93
|
+
param_limits: Optional[dict],
|
|
94
|
+
) -> "Optional[Argument]":
|
|
95
|
+
"""Try to hit *target* via ``optimize``; return the new arg or None on failure."""
|
|
96
|
+
try:
|
|
97
|
+
from .optimize import (
|
|
98
|
+
InfeasibleTargetError,
|
|
99
|
+
OptimizeError,
|
|
100
|
+
achievable_interval,
|
|
101
|
+
optimize,
|
|
102
|
+
)
|
|
103
|
+
except ImportError:
|
|
104
|
+
return None
|
|
105
|
+
|
|
106
|
+
try:
|
|
107
|
+
lo, hi = achievable_interval(arg, param_limits=param_limits)
|
|
108
|
+
except Exception:
|
|
109
|
+
return None
|
|
110
|
+
|
|
111
|
+
if not (lo - 1e-4 <= target <= hi + 1e-4):
|
|
112
|
+
return None # infeasible for this structure
|
|
113
|
+
|
|
114
|
+
try:
|
|
115
|
+
return optimize(arg, target_posterior=target,
|
|
116
|
+
posterior_tol=posterior_tol,
|
|
117
|
+
param_limits=param_limits)
|
|
118
|
+
except (InfeasibleTargetError, OptimizeError):
|
|
119
|
+
return None
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# ---------------------------------------------------------------------------
|
|
123
|
+
# Public API
|
|
124
|
+
# ---------------------------------------------------------------------------
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def generate_with_targets(
    n: int,
    structural,
    posterior_targets: "list[float]",
    posterior_tol: float = 0.02,
    param_limits: "Optional[dict]" = None,
    seeds: "Optional[list[int]]" = None,
    max_attempts_per_graph: int = 200,
    n_jobs: int = 1,
    on_failure: str = "skip",
    verbose: bool = False,
) -> "list[Argument]":
    """Generate *n* graphs, each tuned to a target root posterior.

    Parameters
    ----------
    n : int
        Number of graphs to produce.
    structural : StructuralParams
        Structural shape shared across all graphs.
    posterior_targets : list[float]
        Desired root posterior for each graph. Must have length *n*.
    posterior_tol : float
        Accepted absolute deviation from each target posterior (default 0.02).
    param_limits : dict or None
        Passed to ``optimize`` on the fallback path. None uses default boxes.
    seeds : list[int] or None
        Per-graph seeds. Defaults to ``list(range(n))``.
    max_attempts_per_graph : int
        Rejection budget for each ``generate`` call.
    n_jobs : int
        Reserved for future parallelism. It is currently ignored, and graphs are
        always produced serially in this process.
    on_failure : str
        ``"skip"`` (default) | ``"warn"`` | ``"raise"``.
    verbose : bool
        Print progress per graph.

    Returns
    -------
    list[Argument]
        Accepted and tuned arguments. Each carries ``._benchmark_meta``.
        Length <= *n* when ``on_failure != 'raise'``.

    Raises
    ------
    ValueError
        If ``posterior_targets`` or ``seeds`` does not have length *n*, or
        ``on_failure`` is not one of the three allowed values.
    RuntimeError
        If a seed reaches its target by neither path and
        ``on_failure == 'raise'``.

    Notes
    -----
    The fallback threshold (when to call ``optimize`` instead of resampling) is:
    whenever the bisection-only ``generate`` raises ``RuntimeError`` (budget
    exhausted). ``optimize`` is then tried on an *unconstrained* graph from the
    same seed. If the target is achievable in the parameter box, it succeeds.
    """
    if len(posterior_targets) != n:
        raise ValueError(
            f"posterior_targets has {len(posterior_targets)} entries but n={n}"
        )
    if seeds is None:
        seeds = list(range(n))
    if len(seeds) != n:
        raise ValueError(f"seeds has {len(seeds)} entries but n={n}")
    # Validate up front (as BatchParams does); an unknown value used to fall
    # through silently and behave like "skip".
    if on_failure not in ("skip", "warn", "raise"):
        raise ValueError(
            f"on_failure must be 'skip', 'warn', or 'raise'; got {on_failure!r}"
        )

    # Imported after validation so argument errors surface first.
    from .generate import DifficultyTargets, generate

    accepted: "list[Argument]" = []

    for i, (seed, t) in enumerate(zip(seeds, posterior_targets)):
        if verbose:
            print(f"[{i + 1}/{n}] seed={seed} target={t:.3f}", end=" ... ")

        # --- primary path: bisection inside generate() -----------------------
        arg: "Optional[Argument]" = None
        used_optimize = False
        try:
            arg = generate(
                seed=seed,
                structural=structural,
                targets=DifficultyTargets(
                    target_posterior=t,
                    posterior_tol=posterior_tol,
                ),
                max_attempts=max_attempts_per_graph,
                verbose=False,
            )
        except RuntimeError:
            pass  # bisection budget exhausted; try the optimize fallback

        # --- fallback: generate unconstrained, then optimize -----------------
        if arg is None:
            try:
                base_arg = generate(
                    seed=seed,
                    structural=structural,
                    max_attempts=max_attempts_per_graph,
                    verbose=False,
                )
                opt_arg = _try_optimize(base_arg, t, posterior_tol, param_limits)
                if opt_arg is not None:
                    arg = opt_arg
                    used_optimize = True
            except RuntimeError:
                pass

        # --- failure handling ------------------------------------------------
        if arg is None:
            msg = (
                f"generate_with_targets: seed {seed} could not reach "
                f"target {t:.4f} via bisection or optimize."
            )
            if on_failure == "raise":
                raise RuntimeError(msg)
            if on_failure == "warn":
                warnings.warn(msg, stacklevel=2)
            if verbose:
                print("FAILED")
            continue

        # --- label and collect -----------------------------------------------
        _label(arg, seed=seed, used_optimize=used_optimize)
        accepted.append(arg)
        if verbose:
            m = arg._benchmark_meta
            # _label stores manipulability as a [lo, hi] pair, not a scalar.
            lo, hi = m["manipulability"]
            print(
                f"ok P={m['posterior']:.3f} manip=[{lo:.3f}, {hi:.3f}] "
                f"nodes={m['upstream_size'] + 1} "
                f"{'[opt]' if used_optimize else '[bisect]'}"
            )

    return accepted
|
|
@@ -67,6 +67,7 @@ def _forward(arg: "Argument"):
|
|
|
67
67
|
current values (logit priors, log LRs), and `meta` labels each entry."""
|
|
68
68
|
jax, jnp = _require_jax()
|
|
69
69
|
from ..core import IndependentEvidenceCPD, NoisyOrCPD
|
|
70
|
+
from ..core.cpd.noisy_and import NoisyAndCPD
|
|
70
71
|
|
|
71
72
|
bn = arg.bn
|
|
72
73
|
nodes = list(bn.nodes) # topological (inputs before node)
|
|
@@ -105,6 +106,12 @@ def _forward(arg: "Argument"):
|
|
|
105
106
|
a = jnp.asarray(cpd.activations)
|
|
106
107
|
pin = jnp.stack([m[i] for i in ins])
|
|
107
108
|
m[n] = 1 - (1 - cpd.leak) * jnp.prod(1 - a * pin)
|
|
109
|
+
elif isinstance(cpd, NoisyAndCPD):
|
|
110
|
+
# NoisyAnd activations are fixed structural constants (all 1.0 from
|
|
111
|
+
# the ASPIC compiler); they are NOT free parameters in theta.
|
|
112
|
+
a = jnp.asarray(cpd.activations)
|
|
113
|
+
pin = jnp.stack([m[i] for i in ins])
|
|
114
|
+
m[n] = (1 - cpd.leak) * jnp.prod(a * pin)
|
|
108
115
|
elif _is_and_not_splice(cpd):
|
|
109
116
|
e, u = ins
|
|
110
117
|
m[n] = m[e] * (1 - m[u])
|
|
@@ -40,6 +40,8 @@ def _upstream(c: "Node") -> list["Node"]:
|
|
|
40
40
|
out.extend(src for src, _lr, _kind in getattr(c, "_edges", []))
|
|
41
41
|
out.extend(getattr(c, "_strict", []))
|
|
42
42
|
out.extend(by for _source, by in getattr(c, "_undercuts", []))
|
|
43
|
+
for srcs, _lr, _leak, _rule in getattr(c, "_conjunctions", []):
|
|
44
|
+
out.extend(srcs)
|
|
43
45
|
return out
|
|
44
46
|
|
|
45
47
|
|
|
@@ -86,7 +88,12 @@ def _validate(claims: list["Node"]) -> None:
|
|
|
86
88
|
if not isinstance(c, _Claim):
|
|
87
89
|
continue
|
|
88
90
|
|
|
89
|
-
if
|
|
91
|
+
if (
|
|
92
|
+
c.role == "conclusion"
|
|
93
|
+
and not c._edges
|
|
94
|
+
and not c._strict
|
|
95
|
+
and not getattr(c, "_conjunctions", [])
|
|
96
|
+
):
|
|
90
97
|
warnings.warn(
|
|
91
98
|
f"{c.name!r} is a conclusion with no incoming argument; a leaf "
|
|
92
99
|
"should be a Premise.",
|
|
@@ -160,21 +167,32 @@ def _lower(c: "_Claim") -> None:
|
|
|
160
167
|
|
|
161
168
|
defeasible = [(resolve(src), lr) for src, lr, _kind in c._edges]
|
|
162
169
|
strict = [resolve(src) for src in c._strict]
|
|
170
|
+
conjunctions = getattr(c, "_conjunctions", [])
|
|
171
|
+
|
|
172
|
+
# `host` carries the defeasible inputs: the conclusion itself, or a hidden `D`
|
|
173
|
+
# when strict edges divorce it (parent-divorcing).
|
|
174
|
+
host = Node(f"{c.name}/defeasible", prior=c.prior) if strict else c
|
|
175
|
+
|
|
176
|
+
for src, lr in defeasible:
|
|
177
|
+
host.add_input(src, lr=lr)
|
|
178
|
+
|
|
179
|
+
# Conjunctive support: one hidden NoisyAnd gate per group, fed into `host` as a
|
|
180
|
+
# single lr support. The gate is the logical AND (each source at activation 1,
|
|
181
|
+
# firing iff all hold); `leak` is the rule's fallibility; `lr` is its strength.
|
|
182
|
+
for k, (srcs, lr, leak, _rule) in enumerate(conjunctions):
|
|
183
|
+
gate = Node(f"{c.name}/conj[{k}]")
|
|
184
|
+
gate.noisy_and(leak=leak)
|
|
185
|
+
for src in srcs:
|
|
186
|
+
gate.add_input(resolve(src), activation=1.0)
|
|
187
|
+
host.add_input(gate, lr=lr)
|
|
188
|
+
|
|
189
|
+
_apply_correlations(c, host, resolve)
|
|
163
190
|
|
|
164
191
|
if strict:
|
|
165
|
-
# parent-divorcing: D holds the defeasible part (or just the prior).
|
|
166
|
-
d = Node(f"{c.name}/defeasible", prior=c.prior)
|
|
167
|
-
for src, lr in defeasible:
|
|
168
|
-
d.add_input(src, lr=lr)
|
|
169
|
-
_apply_correlations(c, d, resolve)
|
|
170
192
|
c.noisy_or(leak=0.0)
|
|
171
193
|
for src in strict:
|
|
172
194
|
c.add_input(src, activation=1.0)
|
|
173
|
-
c.add_input(
|
|
174
|
-
else:
|
|
175
|
-
for src, lr in defeasible:
|
|
176
|
-
c.add_input(src, lr=lr)
|
|
177
|
-
_apply_correlations(c, c, resolve)
|
|
195
|
+
c.add_input(host, activation=1.0)
|
|
178
196
|
|
|
179
197
|
|
|
180
198
|
def _apply_correlations(c: "_Claim", host: "Node", resolve) -> None:
|
|
@@ -23,11 +23,16 @@ prior to a *target posterior value* by bisection, grows to a target claim count
|
|
|
23
23
|
constructively, and screens realized depth. The branch-level methods
|
|
24
24
|
(`add_support_branch` / `add_attack_branch`) are usable directly to hand-script a
|
|
25
25
|
template.
|
|
26
|
+
|
|
27
|
+
Phase-1 batch generation: use `StructuralParams.independent_only()` to get the
|
|
28
|
+
baseline preset (no strict/undercut/undermine/axiom edges) and `generate_batch` to
|
|
29
|
+
produce N graphs in parallel with isolated per-graph seeds.
|
|
26
30
|
"""
|
|
27
31
|
from __future__ import annotations
|
|
28
32
|
|
|
29
33
|
import math
|
|
30
34
|
import random
|
|
35
|
+
import warnings
|
|
31
36
|
from dataclasses import dataclass
|
|
32
37
|
from typing import Optional
|
|
33
38
|
|
|
@@ -71,6 +76,28 @@ class StructuralParams:
|
|
|
71
76
|
def n_groups(self) -> int:
|
|
72
77
|
return self.n_support + self.n_attack
|
|
73
78
|
|
|
79
|
+
@classmethod
def independent_only(cls, **overrides) -> "StructuralParams":
    """Build the Phase-1 preset that switches off every non-defeasible edge kind.

    Sets ``strict_prob``, ``undercut_prob``, ``undermine_prob`` and
    ``axiom_prob`` to ``0.0``. Every other field keeps its class default.
    Keyword ``overrides`` are applied last, so they can replace any field,
    including the four pinned probabilities. Example:
    ``StructuralParams.independent_only(n_support=1, n_attack=0)``.

    The preset is intended to yield graphs built only from defeasible support /
    rebut edges, which compile to ``IndependentEvidenceCPD`` nodes (no NoisyOr
    splices, no undercut AND-NOT gates, no axioms). The d-sep guardrail is
    meant to be satisfied by keeping ``n_support + n_attack`` in [1, 3]. This
    method does not check either condition.
    """
    pinned = {
        "strict_prob": 0.0,
        "undercut_prob": 0.0,
        "undermine_prob": 0.0,
        "axiom_prob": 0.0,
    }
    return cls(**{**pinned, **overrides})
|
|
100
|
+
|
|
74
101
|
def __post_init__(self):
|
|
75
102
|
if self.n_support < 0 or self.n_attack < 0:
|
|
76
103
|
raise ValueError("n_support and n_attack must be >= 0")
|
|
@@ -93,6 +120,46 @@ class DifficultyTargets:
|
|
|
93
120
|
min_manipulability: Optional[float] = None
|
|
94
121
|
min_depth: Optional[int] = None # realized longest input-path into the root
|
|
95
122
|
d_sep_groups: Optional[int] = None # realized d-separated group count (for share mode)
|
|
123
|
+
max_d_sep_groups: Optional[int] = None # upper bound on d-sep group count
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@dataclass
class BatchParams:
    """Parameters for batch generation via ``generate_batch``.

    Attributes
    ----------
    n : int
        Number of graphs to generate. Must be >= 0.
    seeds : list[int] or None
        Explicit per-graph seeds (length must equal ``n`` when provided).
        Defaults to ``list(range(n))`` so seed i produces graph i.
    max_attempts_per_graph : int
        Rejection budget passed to each ``generate()`` call.
    on_failure : str
        One of ``"skip"`` (default), ``"warn"``, or ``"raise"``.
        Controls what happens when a seed exhausts its budget.
    n_jobs : int
        Number of parallel worker processes (``multiprocessing.Pool``).
        Default 1 (serial). Must be >= 1. ``None`` is passed through to
        ``Pool``, which then uses ``os.cpu_count()`` workers. See OQ3 in
        graph-generation-spec.md before enabling parallelism with an optimize
        fallback.

    Raises
    ------
    ValueError
        On construction, if any field is out of range. Checking here means a
        bad ``n_jobs`` is reported at construction rather than deep inside
        ``Pool``.
    """

    n: int = 1000
    seeds: Optional[list] = None
    max_attempts_per_graph: int = 200
    on_failure: str = "skip"
    n_jobs: int = 1

    def __post_init__(self):
        if self.on_failure not in ("skip", "warn", "raise"):
            raise ValueError(
                f"on_failure must be 'skip', 'warn', or 'raise'; got {self.on_failure!r}"
            )
        if self.n < 0:
            raise ValueError(f"n must be >= 0; got {self.n}")
        # None is tolerated: Pool(None) means "one worker per CPU".
        if self.n_jobs is not None and self.n_jobs < 1:
            raise ValueError(f"n_jobs must be >= 1; got {self.n_jobs}")
        if self.seeds is not None and len(self.seeds) != self.n:
            raise ValueError(
                f"seeds has {len(self.seeds)} entries but n={self.n}"
            )
|
|
96
163
|
|
|
97
164
|
|
|
98
165
|
class ArgumentGenerator:
|
|
@@ -205,10 +272,16 @@ class ArgumentGenerator:
|
|
|
205
272
|
"""Existing nodes that may legally become a new parent of `conclusion`:
|
|
206
273
|
matching role, not the root, not already a source of it, not inside an
|
|
207
274
|
undercutter subgraph, and not an ancestor of it (which would make a directed
|
|
208
|
-
cycle).
|
|
275
|
+
cycle). Candidates are drawn from `conclusion`'s OWN branch — its already-built
|
|
276
|
+
upstream subtree (`_reachable(conclusion)`) — never from sibling branches, so a
|
|
277
|
+
reused node creates a within-branch reconvergence but never merges two
|
|
278
|
+
d-separated branches at the root. (Bug fix 2026-06-26: this used to iterate
|
|
279
|
+
`_reachable(self.root)`; because a branch is attached to the root only *after*
|
|
280
|
+
it is built, that pool excluded the current branch and pulled from previously
|
|
281
|
+
built sibling branches, collapsing root-level d-separation.)"""
|
|
209
282
|
existing = {s for s, _lr, _k in conclusion._edges} | set(conclusion._strict)
|
|
210
283
|
out = []
|
|
211
|
-
for n in _reachable(
|
|
284
|
+
for n in _reachable(conclusion): # within this branch only
|
|
212
285
|
if not isinstance(n, role) or n is self.root or n is conclusion:
|
|
213
286
|
continue
|
|
214
287
|
if n in self._uc_nodes or n in existing:
|
|
@@ -363,8 +436,12 @@ def generate(seed: Optional[int] = None, *,
|
|
|
363
436
|
# cycle, so fall back to the exact solver (keep share-mode graphs small).
|
|
364
437
|
p_root = (LoopySolver if is_polytree(bn) else ExactSolver)(bn).prob(arg.target, 1)
|
|
365
438
|
|
|
366
|
-
|
|
367
|
-
|
|
439
|
+
n_dsep = len(d_separated_groups(bn, arg.target))
|
|
440
|
+
if tgt.d_sep_groups is not None and n_dsep != tgt.d_sep_groups:
|
|
441
|
+
continue
|
|
442
|
+
if tgt.max_d_sep_groups is not None and n_dsep > tgt.max_d_sep_groups:
|
|
443
|
+
if verbose:
|
|
444
|
+
print(f"attempt {attempt}: d_sep_groups {n_dsep} > max {tgt.max_d_sep_groups}")
|
|
368
445
|
continue
|
|
369
446
|
if tgt.posterior_side == "above" and not p_root > tgt.threshold:
|
|
370
447
|
continue
|
|
@@ -395,3 +472,99 @@ def generate(seed: Optional[int] = None, *,
|
|
|
395
472
|
"(move threshold toward 0.5, widen posterior_tol, lower min_manipulability/"
|
|
396
473
|
"min_depth) or raise max_attempts"
|
|
397
474
|
)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def _generate_one(seed, structural, targets, max_attempts, verbose):
    """Worker: run ``generate`` for one seed, returning any raised exception.

    Defined at module level, not as a closure, so ``multiprocessing`` can pickle
    it (by qualified name) for spawn-context worker processes.
    """
    try:
        return generate(
            seed=seed,
            structural=structural,
            targets=targets,
            max_attempts=max_attempts,
            verbose=verbose,
        )
    except Exception as exc:  # reported back to the caller as a per-seed failure
        return exc


def generate_batch(
    batch: BatchParams,
    structural: Optional[StructuralParams] = None,
    targets: Optional[DifficultyTargets] = None,
    verbose: bool = False,
) -> "tuple[list[Argument], list[tuple[int, Exception]]]":
    """Batch-generate ``batch.n`` ``Argument`` objects.

    Each graph is generated by ``generate(seed=s, structural=structural,
    targets=targets, max_attempts=batch.max_attempts_per_graph)``. Seeds default
    to ``0, 1, …, n-1``; set ``batch.seeds`` for explicit per-graph seeds.

    When ``batch.n_jobs != 1``, seeds are distributed over a ``spawn``-context
    ``multiprocessing.Pool``. Spawned workers re-import the calling module, so
    scripts must guard their entry point with ``if __name__ == "__main__":``.

    Parameters
    ----------
    batch : BatchParams
        Controls count, seeds, failure policy, and parallelism.
    structural : StructuralParams or None
        Structural shape. None uses ``StructuralParams()`` defaults.
    targets : DifficultyTargets or None
        Difficulty filters. None uses unconstrained defaults.
    verbose : bool
        Forwarded to each ``generate`` call for per-attempt logging.

    Returns
    -------
    args : list[Argument]
        Accepted arguments, in seed order. Length <= ``batch.n`` when
        ``on_failure`` is not ``'raise'``.
    failures : list[tuple[int, Exception]]
        (seed, error) pairs for seeds whose ``generate`` call raised any
        ``Exception``, e.g. an exhausted rejection budget.

    Raises
    ------
    Exception
        The first failing seed's exception, e.g. ``RuntimeError`` when its
        budget is exhausted, if ``batch.on_failure == 'raise'``.
    """
    from functools import partial

    seeds = batch.seeds if batch.seeds is not None else list(range(batch.n))

    # A partial over a module-level function is picklable; the previous nested
    # closure made every n_jobs != 1 call fail with a pickling error.
    work = partial(
        _generate_one,
        structural=structural,
        targets=targets,
        max_attempts=batch.max_attempts_per_graph,
        verbose=verbose,
    )

    if batch.n_jobs == 1:
        # Lazy map: seeds are generated one at a time, so on_failure="raise"
        # stops at the first failure without generating the rest.
        results = map(work, seeds)
    else:
        # Spawn-safe multiprocessing (avoids JAX fork issues, OQ3).
        import multiprocessing as mp

        ctx = mp.get_context("spawn")
        with ctx.Pool(batch.n_jobs) as pool:
            results = pool.map(work, seeds)

    accepted: "list[Argument]" = []
    failures: "list[tuple[int, Exception]]" = []
    for s, result in zip(seeds, results):
        if not isinstance(result, Exception):
            accepted.append(result)
            continue
        failures.append((s, result))
        if batch.on_failure == "raise":
            raise result
        if batch.on_failure == "warn":
            warnings.warn(
                f"generate_batch: seed {s} failed: {result}",
                stacklevel=2,
            )

    if verbose and failures:
        print(
            f"generate_batch: {len(failures)} seeds failed, "
            f"{len(accepted)} accepted out of {batch.n} requested."
        )

    return accepted, failures
|