probability-flow 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. {probability_flow-0.3.0 → probability_flow-0.4.0}/.gitignore +3 -0
  2. {probability_flow-0.3.0 → probability_flow-0.4.0}/PKG-INFO +2 -1
  3. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/__init__.py +3 -0
  4. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/argument.py +57 -2
  5. probability_flow-0.4.0/probability_flow/aspic/benchmark.py +254 -0
  6. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/calibrate.py +7 -0
  7. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/compile.py +29 -11
  8. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/generate.py +177 -4
  9. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/handle.py +25 -2
  10. probability_flow-0.4.0/probability_flow/aspic/optimize.py +514 -0
  11. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/aspic/visualization.py +17 -1
  12. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/correlated_evidence.py +3 -1
  13. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/__init__.py +8 -1
  14. probability_flow-0.4.0/probability_flow/metrics/manipulability.py +474 -0
  15. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/layout.py +184 -62
  16. {probability_flow-0.3.0 → probability_flow-0.4.0}/pyproject.toml +12 -2
  17. probability_flow-0.3.0/probability_flow/metrics/manipulability.py +0 -207
  18. {probability_flow-0.3.0 → probability_flow-0.4.0}/LICENSE +0 -0
  19. {probability_flow-0.3.0 → probability_flow-0.4.0}/README.md +0 -0
  20. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/__init__.py +0 -0
  21. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/__init__.py +0 -0
  22. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/_logmath.py +0 -0
  23. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/bp/__init__.py +0 -0
  24. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/bp/engine.py +0 -0
  25. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/bp/message.py +0 -0
  26. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/__init__.py +0 -0
  27. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/base.py +0 -0
  28. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/independent_evidence.py +0 -0
  29. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/noisy_and.py +0 -0
  30. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/noisy_or.py +0 -0
  31. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/cpd/tabular.py +0 -0
  32. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/exact.py +0 -0
  33. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/network.py +0 -0
  34. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/core/node.py +0 -0
  35. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/_util.py +0 -0
  36. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/difficulty.py +0 -0
  37. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/dseparation.py +0 -0
  38. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/loopiness.py +0 -0
  39. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/metrics/structure.py +0 -0
  40. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/py.typed +0 -0
  41. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/__init__.py +0 -0
  42. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/animate.py +0 -0
  43. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/image.py +0 -0
  44. {probability_flow-0.3.0 → probability_flow-0.4.0}/probability_flow/visualization/style.py +0 -0
@@ -6,3 +6,6 @@ __pycache__/
6
6
  *.egg-info/
7
7
  .DS_Store
8
8
  _previews/
9
+
10
+ # Local secrets
11
+ .env
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: probability-flow
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: A from-scratch, modular discrete Bayesian-network library.
5
5
  Project-URL: Homepage, https://github.com/scalable-oversight-benchmarks/probability-flow
6
6
  Project-URL: Repository, https://github.com/scalable-oversight-benchmarks/probability-flow
@@ -28,6 +28,7 @@ Requires-Dist: pytest; extra == 'dev'
28
28
  Requires-Dist: ruff; extra == 'dev'
29
29
  Provides-Extra: jax
30
30
  Requires-Dist: jax; extra == 'jax'
31
+ Requires-Dist: scipy; extra == 'jax'
31
32
  Provides-Extra: viz
32
33
  Requires-Dist: matplotlib; extra == 'viz'
33
34
  Description-Content-Type: text/markdown
@@ -9,13 +9,16 @@ pure-BN `core`. Build an argument out of premises, conclusions, and attacks, the
9
9
  from .argument import ArgumentWarning, Axiom, Conclusion, Premise
10
10
  from .generate import (
11
11
  ArgumentGenerator,
12
+ BatchParams,
12
13
  DifficultyTargets,
13
14
  StructuralParams,
14
15
  generate,
16
+ generate_batch,
15
17
  )
16
18
  from .handle import Argument
17
19
 
18
20
  __all__ = [
19
21
  "Premise", "Axiom", "Conclusion", "ArgumentWarning", "Argument",
20
22
  "ArgumentGenerator", "StructuralParams", "DifficultyTargets", "generate",
23
+ "BatchParams", "generate_batch",
21
24
  ]
@@ -47,6 +47,11 @@ class _Claim(Node):
47
47
  # (source_a, source_b, J): residual correlation between two defeasible
48
48
  # sources of this conclusion, lowered to a CorrelatedEvidenceCPD coupling.
49
49
  self._correlations: list[tuple["Node", "Node", "float | None"]] = []
50
+ # (sources, lr, leak, rule): a conjunctive support group — the sources
51
+ # JOINTLY support this conclusion through one noisy-AND inference (the
52
+ # deduction rule), lowered to a hidden NoisyAnd gate fed in as one lr.
53
+ self._conjunctions: list[
54
+ tuple[list["Node"], "float | None", float, "str | None"]] = []
50
55
  self._no_undermine = False
51
56
  # opaque, serialized but uninterpreted by the library (see docs/aspic.md):
52
57
  self.desc: str | None = None # a longer description
@@ -56,8 +61,13 @@ class _Claim(Node):
56
61
  """At most one argumentative edge joins a pair of claims. This is what lets
57
62
  an undercut be addressed by its endpoints and keeps serialized edge ids
58
63
  (`source->target`) unique."""
59
- if any(s is src for s, _lr, _kind in self._edges) or any(
60
- s is src for s in self._strict
64
+ in_conj = any(
65
+ s is src for srcs, _lr, _leak, _rule in self._conjunctions for s in srcs
66
+ )
67
+ if (
68
+ any(s is src for s, _lr, _kind in self._edges)
69
+ or any(s is src for s in self._strict)
70
+ or in_conj
61
71
  ):
62
72
  raise ValueError(
63
73
  f"{self.name!r} already has an edge from {src.name!r}; at most one "
@@ -80,6 +90,51 @@ class _Claim(Node):
80
90
  self._edges.append((src, None if lr is None else float(lr), "support"))
81
91
  return src
82
92
 
93
def support_all(
    self,
    srcs: "list[Node]",
    lr: "float | None" = None,
    leak: float = 0.0,
    rule: "str | None" = None,
) -> "list[Node]":
    """Register a conjunctive support group on this conclusion.

    The sources in `srcs` support the conclusion *jointly*: they are recorded
    as one group in `self._conjunctions`, which `compile.py` lowers to a hidden
    noisy-AND gate (every source at activation 1) that feeds this conclusion as
    a single `lr` support.

    Parameters
    ----------
    srcs : the antecedent nodes; at least two, all distinct objects, and none
        may already have an edge into this conclusion.
    lr : strength of the joint inference; must be `> 1` when given. `None`
        scaffolds the group before its strength is known.
    leak : fallibility of the rule, a probability in `[0, 1]`.
    rule : optional label for the inference licence; stored, not interpreted.

    Returns
    -------
    The sources as a fresh list, so inline sources can be extended upstream.

    Raises
    ------
    ValueError
        On fewer than two sources, repeated sources, `lr <= 1`, an
        out-of-range `leak`, or a source that already has an edge here.
    """
    members = list(srcs)

    if len(members) < 2:
        raise ValueError(
            "support_all needs >= 2 sources (it is a conjunction); use support "
            "for a single argument"
        )

    # Identity, not equality: the same node object may not appear twice.
    unique_ids = set(map(id, members))
    if len(unique_ids) != len(members):
        raise ValueError("support_all sources must be distinct")

    # `not lr > 1` (rather than `lr <= 1`) also rejects NaN.
    if lr is not None and not lr > 1:
        raise ValueError(
            f"support_all lr must be > 1 (got {lr}); it is conjunctive support"
        )

    leak_is_probability = 0.0 <= leak <= 1.0
    if not leak_is_probability:
        raise ValueError("leak must be a probability in [0, 1]")

    # Every member must be a brand-new edge; checked before anything is stored.
    for member in members:
        self._require_new_edge(member)

    strength = float(lr) if lr is not None else None
    self._conjunctions.append((members, strength, float(leak), rule))
    return members
137
+
83
138
  def rebut(self, src: "Node", lr: "float | None" = None) -> "Node":
84
139
  """Add a defeasible argument *against* this conclusion (`0 < lr < 1`). `lr`
85
140
  may be omitted (`None`) to declare the edge before its strength is known."""
@@ -0,0 +1,254 @@
1
+ """Benchmark helpers: generate → optimize bridge for large-batch posterior targeting.
2
+
3
+ This module isolates the JAX/scipy dependency (via ``optimize``) from the pure-Python
4
+ ``generate.py``. The primary entry point is ``generate_with_targets``, which produces
5
+ N argument graphs each tuned to a requested root posterior.
6
+
7
+ Strategy
8
+ --------
9
+ For each (seed, target_posterior) pair:
10
+
11
+ 1. Try ``generate(seed, structural, DifficultyTargets(target_posterior=t))`` — this
12
+ bisects the root prior to hit *t* cheaply (no JAX needed).
13
+ 2. If the bisection cannot reach *t* for the drawn structure (``RuntimeError`` from
14
+ the rejection loop), fall back to:
15
+ a. Generate an *unconstrained* graph from the same seed.
16
+ b. Check whether *t* is inside ``achievable_interval(arg)``.
17
+ c. If yes, call ``optimize(arg, target_posterior=t)`` to tune LRs.
18
+ 3. If neither path succeeds, record the failure (skip / warn / raise per
19
+ ``on_failure``).
20
+
21
+ Both paths return an ``Argument`` with identical public interface. The ``optimize``
22
+ fallback is more expensive (JAX JIT + SLSQP) but reaches extreme posteriors that
23
+ bisection cannot.
24
+
25
+ Batch labeling
26
+ --------------
27
+ Each accepted argument is annotated with a ``_benchmark_meta`` attribute (a plain
28
+ dict) carrying the difficulty scalars the benchmark consumer (``debate-eval``) needs:
29
+
30
+ posterior, manipulability ([lo, hi] pair), manipulability_width,
31
+ d_sep_count, max_depth_metric, upstream_size, circuit_rank, seed,
32
+ is_exact_manipulability, used_optimize_fallback
33
+
34
+ These come entirely from existing metrics — no new computation.
35
+
36
+ JAX/scipy availability
37
+ ----------------------
38
+ Both are required for the ``optimize`` fallback path. When they are absent the
39
+ module still imports cleanly; a graph that needs the optimize fallback is then
40
+ reported as a failure and handled per ``on_failure``. Use ``generate_batch`` (no JAX)
41
+ when you only need bisection-reachable posteriors.
42
+ """
43
+ from __future__ import annotations
44
+
45
+ import warnings
46
+ from typing import TYPE_CHECKING, Optional
47
+
48
+ if TYPE_CHECKING:
49
+ from .handle import Argument
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Internal helpers
53
+ # ---------------------------------------------------------------------------
54
+
55
+
56
def _label(arg: "Argument", seed: int, used_optimize: bool) -> None:
    """Attach a ``_benchmark_meta`` dict to *arg* in place.

    The dict records the root posterior, the achievable posterior range (both
    as a rounded ``[lo, hi]`` pair and as its rounded width), several
    structural metrics of ``arg.bn`` relative to ``arg.target``, the
    generating ``seed``, and whether the optimize fallback produced the graph.
    Returns nothing; the only effect is the new attribute on *arg*.
    """
    from ..metrics import (
        circuit_rank,
        d_separated_groups,
        max_depth,
        posterior_range,
        upstream_size,
    )

    network = arg.bn
    root_posterior = arg.posterior(arg.target)

    # The range is reported as [min, max] rather than only its width because
    # it shows *where* on [0, 1] the root posterior can be pushed, not just by
    # how much. `is_exact` on the result says whether the bounds are exact.
    span = posterior_range(network, arg.target, exact=True)

    meta = {}
    meta["posterior"] = float(root_posterior)
    meta["manipulability"] = [round(float(span.lo), 4), round(float(span.hi), 4)]
    meta["manipulability_width"] = round(float(span.hi - span.lo), 4)
    meta["d_sep_count"] = len(d_separated_groups(network, arg.target))
    meta["max_depth_metric"] = max_depth(network, arg.target)
    meta["upstream_size"] = upstream_size(network, arg.target)
    meta["circuit_rank"] = circuit_rank(network)
    meta["seed"] = seed
    meta["is_exact_manipulability"] = bool(span.is_exact)
    meta["used_optimize_fallback"] = used_optimize

    arg._benchmark_meta = meta
87
+
88
+
89
+ def _try_optimize(
90
+ arg: "Argument",
91
+ target: float,
92
+ posterior_tol: float,
93
+ param_limits: Optional[dict],
94
+ ) -> "Optional[Argument]":
95
+ """Try to hit *target* via ``optimize``; return the new arg or None on failure."""
96
+ try:
97
+ from .optimize import (
98
+ InfeasibleTargetError,
99
+ OptimizeError,
100
+ achievable_interval,
101
+ optimize,
102
+ )
103
+ except ImportError:
104
+ return None
105
+
106
+ try:
107
+ lo, hi = achievable_interval(arg, param_limits=param_limits)
108
+ except Exception:
109
+ return None
110
+
111
+ if not (lo - 1e-4 <= target <= hi + 1e-4):
112
+ return None # infeasible for this structure
113
+
114
+ try:
115
+ return optimize(arg, target_posterior=target,
116
+ posterior_tol=posterior_tol,
117
+ param_limits=param_limits)
118
+ except (InfeasibleTargetError, OptimizeError):
119
+ return None
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Public API
124
+ # ---------------------------------------------------------------------------
125
+
126
+
127
def generate_with_targets(
    n: int,
    structural,
    posterior_targets: "list[float]",
    posterior_tol: float = 0.02,
    param_limits: "Optional[dict]" = None,
    seeds: "Optional[list[int]]" = None,
    max_attempts_per_graph: int = 200,
    n_jobs: int = 1,
    on_failure: str = "skip",
    verbose: bool = False,
) -> "list[Argument]":
    """Generate *n* graphs, each tuned to a target root posterior.

    For every ``(seed, target)`` pair the primary path calls ``generate`` with
    ``DifficultyTargets(target_posterior=..., posterior_tol=...)``. If that
    raises ``RuntimeError``, an unconstrained graph is generated from the same
    seed and handed to ``_try_optimize``. Accepted graphs are labelled with
    ``_label`` before being collected.

    Parameters
    ----------
    n : int
        Number of graphs to produce.
    structural : StructuralParams
        Structural shape shared across all graphs.
    posterior_targets : list[float]
        Desired root posterior for each graph. Must have length *n*.
    posterior_tol : float
        Accepted absolute deviation from each target posterior (default 0.02).
    param_limits : dict or None
        Passed through to the optimize fallback.
    seeds : list[int] or None
        Per-graph seeds. Defaults to ``list(range(n))``.
    max_attempts_per_graph : int
        ``max_attempts`` budget for each ``generate`` call.
    n_jobs : int
        Accepted for interface compatibility; this function does not read it
        and always processes graphs serially.
    on_failure : str
        ``"raise"`` raises ``RuntimeError`` on the first unreachable target;
        ``"warn"`` emits a warning and skips the graph; any other value
        (including the default ``"skip"``) skips silently.
    verbose : bool
        Print one progress line per graph.

    Returns
    -------
    list[Argument]
        Accepted arguments, each carrying ``._benchmark_meta``. May be shorter
        than *n* unless ``on_failure == "raise"``.

    Raises
    ------
    ValueError
        If ``posterior_targets`` or ``seeds`` does not have length *n*.
    RuntimeError
        If a target cannot be reached and ``on_failure == "raise"``.
    """
    # Validate the arguments first so malformed input fails fast, before any
    # generator machinery is imported.
    if len(posterior_targets) != n:
        raise ValueError(
            f"posterior_targets has {len(posterior_targets)} entries but n={n}"
        )
    if seeds is None:
        seeds = list(range(n))
    if len(seeds) != n:
        raise ValueError(f"seeds has {len(seeds)} entries but n={n}")

    from .generate import DifficultyTargets, generate

    accepted: list[Argument] = []

    for i, (seed, t) in enumerate(zip(seeds, posterior_targets)):
        if verbose:
            print(f"[{i + 1}/{n}] seed={seed} target={t:.3f}", end=" ... ")

        # --- primary path: bisection inside generate() -----------------------
        arg: Optional[Argument] = None
        used_optimize = False
        try:
            arg = generate(
                seed=seed,
                structural=structural,
                targets=DifficultyTargets(
                    target_posterior=t,
                    posterior_tol=posterior_tol,
                ),
                max_attempts=max_attempts_per_graph,
                verbose=False,
            )
        except RuntimeError:
            pass  # bisection budget exhausted — try the optimize fallback

        # --- fallback: generate unconstrained then optimize ------------------
        if arg is None:
            try:
                base_arg = generate(
                    seed=seed,
                    structural=structural,
                    max_attempts=max_attempts_per_graph,
                    verbose=False,
                )
                opt_arg = _try_optimize(base_arg, t, posterior_tol, param_limits)
                if opt_arg is not None:
                    arg = opt_arg
                    used_optimize = True
            except RuntimeError:
                pass  # even the unconstrained draw failed; handled below

        # --- failure handling ------------------------------------------------
        if arg is None:
            msg = (
                f"generate_with_targets: seed {seed} could not reach "
                f"target {t:.4f} via bisection or optimize."
            )
            if on_failure == "raise":
                raise RuntimeError(msg)
            if on_failure == "warn":
                warnings.warn(msg, stacklevel=2)
            if verbose:
                print("FAILED")
            continue

        # --- label and collect -----------------------------------------------
        _label(arg, seed=seed, used_optimize=used_optimize)
        accepted.append(arg)
        if verbose:
            m = arg._benchmark_meta
            # `manipulability` is a [lo, hi] pair, not a scalar: format each
            # bound separately (a list does not accept a float format spec).
            manip_lo, manip_hi = m["manipulability"]
            print(
                f"ok P={m['posterior']:.3f} "
                f"manip=[{manip_lo:.3f}, {manip_hi:.3f}] "
                f"nodes={m['upstream_size'] + 1} "
                f"{'[opt]' if used_optimize else '[bisect]'}"
            )

    return accepted
@@ -67,6 +67,7 @@ def _forward(arg: "Argument"):
67
67
  current values (logit priors, log LRs), and `meta` labels each entry."""
68
68
  jax, jnp = _require_jax()
69
69
  from ..core import IndependentEvidenceCPD, NoisyOrCPD
70
+ from ..core.cpd.noisy_and import NoisyAndCPD
70
71
 
71
72
  bn = arg.bn
72
73
  nodes = list(bn.nodes) # topological (inputs before node)
@@ -105,6 +106,12 @@ def _forward(arg: "Argument"):
105
106
  a = jnp.asarray(cpd.activations)
106
107
  pin = jnp.stack([m[i] for i in ins])
107
108
  m[n] = 1 - (1 - cpd.leak) * jnp.prod(1 - a * pin)
109
+ elif isinstance(cpd, NoisyAndCPD):
110
+ # NoisyAnd activations are fixed structural constants (all 1.0 from
111
+ # the ASPIC compiler); they are NOT free parameters in theta.
112
+ a = jnp.asarray(cpd.activations)
113
+ pin = jnp.stack([m[i] for i in ins])
114
+ m[n] = (1 - cpd.leak) * jnp.prod(a * pin)
108
115
  elif _is_and_not_splice(cpd):
109
116
  e, u = ins
110
117
  m[n] = m[e] * (1 - m[u])
@@ -40,6 +40,8 @@ def _upstream(c: "Node") -> list["Node"]:
40
40
  out.extend(src for src, _lr, _kind in getattr(c, "_edges", []))
41
41
  out.extend(getattr(c, "_strict", []))
42
42
  out.extend(by for _source, by in getattr(c, "_undercuts", []))
43
+ for srcs, _lr, _leak, _rule in getattr(c, "_conjunctions", []):
44
+ out.extend(srcs)
43
45
  return out
44
46
 
45
47
 
@@ -86,7 +88,12 @@ def _validate(claims: list["Node"]) -> None:
86
88
  if not isinstance(c, _Claim):
87
89
  continue
88
90
 
89
- if c.role == "conclusion" and not c._edges and not c._strict:
91
+ if (
92
+ c.role == "conclusion"
93
+ and not c._edges
94
+ and not c._strict
95
+ and not getattr(c, "_conjunctions", [])
96
+ ):
90
97
  warnings.warn(
91
98
  f"{c.name!r} is a conclusion with no incoming argument; a leaf "
92
99
  "should be a Premise.",
@@ -160,21 +167,32 @@ def _lower(c: "_Claim") -> None:
160
167
 
161
168
  defeasible = [(resolve(src), lr) for src, lr, _kind in c._edges]
162
169
  strict = [resolve(src) for src in c._strict]
170
+ conjunctions = getattr(c, "_conjunctions", [])
171
+
172
+ # `host` carries the defeasible inputs: the conclusion itself, or a hidden `D`
173
+ # when strict edges divorce it (parent-divorcing).
174
+ host = Node(f"{c.name}/defeasible", prior=c.prior) if strict else c
175
+
176
+ for src, lr in defeasible:
177
+ host.add_input(src, lr=lr)
178
+
179
+ # Conjunctive support: one hidden NoisyAnd gate per group, fed into `host` as a
180
+ # single lr support. The gate is the logical AND (each source at activation 1,
181
+ # firing iff all hold); `leak` is the rule's fallibility; `lr` is its strength.
182
+ for k, (srcs, lr, leak, _rule) in enumerate(conjunctions):
183
+ gate = Node(f"{c.name}/conj[{k}]")
184
+ gate.noisy_and(leak=leak)
185
+ for src in srcs:
186
+ gate.add_input(resolve(src), activation=1.0)
187
+ host.add_input(gate, lr=lr)
188
+
189
+ _apply_correlations(c, host, resolve)
163
190
 
164
191
  if strict:
165
- # parent-divorcing: D holds the defeasible part (or just the prior).
166
- d = Node(f"{c.name}/defeasible", prior=c.prior)
167
- for src, lr in defeasible:
168
- d.add_input(src, lr=lr)
169
- _apply_correlations(c, d, resolve)
170
192
  c.noisy_or(leak=0.0)
171
193
  for src in strict:
172
194
  c.add_input(src, activation=1.0)
173
- c.add_input(d, activation=1.0)
174
- else:
175
- for src, lr in defeasible:
176
- c.add_input(src, lr=lr)
177
- _apply_correlations(c, c, resolve)
195
+ c.add_input(host, activation=1.0)
178
196
 
179
197
 
180
198
  def _apply_correlations(c: "_Claim", host: "Node", resolve) -> None:
@@ -23,11 +23,16 @@ prior to a *target posterior value* by bisection, grows to a target claim count
23
23
  constructively, and screens realized depth. The branch-level methods
24
24
  (`add_support_branch` / `add_attack_branch`) are usable directly to hand-script a
25
25
  template.
26
+
27
+ Phase-1 batch generation: use `StructuralParams.independent_only()` to get the
28
+ baseline preset (no strict/undercut/undermine/axiom edges) and `generate_batch` to
29
+ produce N graphs in parallel with isolated per-graph seeds.
26
30
  """
27
31
  from __future__ import annotations
28
32
 
29
33
  import math
30
34
  import random
35
+ import warnings
31
36
  from dataclasses import dataclass
32
37
  from typing import Optional
33
38
 
@@ -71,6 +76,28 @@ class StructuralParams:
71
76
  def n_groups(self) -> int:
72
77
  return self.n_support + self.n_attack
73
78
 
79
@classmethod
def independent_only(cls, **overrides) -> "StructuralParams":
    """Alternate constructor for the Phase-1 baseline preset.

    Builds an instance with ``strict_prob``, ``undercut_prob``,
    ``undermine_prob`` and ``axiom_prob`` all pinned to ``0.0``. Any keyword
    in ``overrides`` is applied on top and wins over the pinned value, e.g.
    ``StructuralParams.independent_only(n_support=1, n_attack=0)``.

    NOTE(review): the original author describes the result as a polytree of
    defeasible support / rebut edges compiled purely to independent-evidence
    CPDs, with the d-sep guardrail met while ``n_support + n_attack`` stays in
    [1, 3]. That depends on the generator and is not checked here — confirm
    before relying on it with overrides.
    """
    pinned = {
        "strict_prob": 0.0,
        "undercut_prob": 0.0,
        "undermine_prob": 0.0,
        "axiom_prob": 0.0,
    }
    # Later entries win, so caller overrides replace the pinned zeros.
    return cls(**{**pinned, **overrides})
100
+
74
101
  def __post_init__(self):
75
102
  if self.n_support < 0 or self.n_attack < 0:
76
103
  raise ValueError("n_support and n_attack must be >= 0")
@@ -93,6 +120,46 @@ class DifficultyTargets:
93
120
  min_manipulability: Optional[float] = None
94
121
  min_depth: Optional[int] = None # realized longest input-path into the root
95
122
  d_sep_groups: Optional[int] = None # realized d-separated group count (for share mode)
123
+ max_d_sep_groups: Optional[int] = None # upper bound on d-sep group count
124
+
125
+
126
@dataclass
class BatchParams:
    """Parameters for batch generation via ``generate_batch``.

    Attributes
    ----------
    n : int
        Number of graphs to generate; must be >= 0.
    seeds : list[int] or None
        Explicit per-graph seeds (length must equal ``n`` when provided).
        When ``None``, ``generate_batch`` uses ``list(range(n))``.
    max_attempts_per_graph : int
        Rejection budget passed to each ``generate()`` call.
    on_failure : str
        One of ``"skip"`` (default), ``"warn"``, or ``"raise"``.
        Controls what happens when a seed fails.
    n_jobs : int
        Number of worker processes; must be >= 1. ``1`` (default) runs
        serially, anything larger uses a ``multiprocessing`` pool.

    Raises
    ------
    ValueError
        If ``on_failure`` is not a recognised policy, ``n`` is negative,
        ``seeds`` does not have ``n`` entries, or ``n_jobs`` is below 1.
    """

    n: int = 1000
    seeds: Optional[list[int]] = None
    max_attempts_per_graph: int = 200
    on_failure: str = "skip"
    n_jobs: int = 1

    def __post_init__(self):
        if self.on_failure not in ("skip", "warn", "raise"):
            raise ValueError(
                f"on_failure must be 'skip', 'warn', or 'raise'; got {self.on_failure!r}"
            )
        # A negative count would otherwise silently produce an empty batch.
        if self.n < 0:
            raise ValueError(f"n must be >= 0; got {self.n}")
        if self.seeds is not None and len(self.seeds) != self.n:
            raise ValueError(
                f"seeds has {len(self.seeds)} entries but n={self.n}"
            )
        # Reject here rather than letting the process pool fail later.
        if self.n_jobs < 1:
            raise ValueError(f"n_jobs must be >= 1; got {self.n_jobs}")
96
163
 
97
164
 
98
165
  class ArgumentGenerator:
@@ -205,10 +272,16 @@ class ArgumentGenerator:
205
272
  """Existing nodes that may legally become a new parent of `conclusion`:
206
273
  matching role, not the root, not already a source of it, not inside an
207
274
  undercutter subgraph, and not an ancestor of it (which would make a directed
208
- cycle). Reuse is what turns the forest into a DAG with shared parents."""
275
+ cycle). Candidates are drawn from `conclusion`'s OWN branch — its already-built
276
+ upstream subtree (`_reachable(conclusion)`) — never from sibling branches, so a
277
+ reused node creates a within-branch reconvergence but never merges two
278
+ d-separated branches at the root. (Bug fix 2026-06-26: this used to iterate
279
+ `_reachable(self.root)`; because a branch is attached to the root only *after*
280
+ it is built, that pool excluded the current branch and pulled from previously
281
+ built sibling branches, collapsing root-level d-separation.)"""
209
282
  existing = {s for s, _lr, _k in conclusion._edges} | set(conclusion._strict)
210
283
  out = []
211
- for n in _reachable(self.root):
284
+ for n in _reachable(conclusion): # within this branch only
212
285
  if not isinstance(n, role) or n is self.root or n is conclusion:
213
286
  continue
214
287
  if n in self._uc_nodes or n in existing:
@@ -363,8 +436,12 @@ def generate(seed: Optional[int] = None, *,
363
436
  # cycle, so fall back to the exact solver (keep share-mode graphs small).
364
437
  p_root = (LoopySolver if is_polytree(bn) else ExactSolver)(bn).prob(arg.target, 1)
365
438
 
366
- if tgt.d_sep_groups is not None and \
367
- len(d_separated_groups(bn, arg.target)) != tgt.d_sep_groups:
439
+ n_dsep = len(d_separated_groups(bn, arg.target))
440
+ if tgt.d_sep_groups is not None and n_dsep != tgt.d_sep_groups:
441
+ continue
442
+ if tgt.max_d_sep_groups is not None and n_dsep > tgt.max_d_sep_groups:
443
+ if verbose:
444
+ print(f"attempt {attempt}: d_sep_groups {n_dsep} > max {tgt.max_d_sep_groups}")
368
445
  continue
369
446
  if tgt.posterior_side == "above" and not p_root > tgt.threshold:
370
447
  continue
@@ -395,3 +472,99 @@ def generate(seed: Optional[int] = None, *,
395
472
  "(move threshold toward 0.5, widen posterior_tol, lower min_manipulability/"
396
473
  "min_depth) or raise max_attempts"
397
474
  )
475
+
476
+
477
def _generate_one(seed, structural, targets, max_attempts, verbose):
    """Run one ``generate`` call and return its result, or the exception raised.

    Defined at module level (not as a closure inside ``generate_batch``) so
    that ``multiprocessing`` can pickle it by reference for worker processes.
    """
    try:
        return generate(
            seed=seed,
            structural=structural,
            targets=targets,
            max_attempts=max_attempts,
            verbose=verbose,
        )
    except Exception as exc:
        # Deliberate best-effort: the failure is handed back as a value so the
        # caller can apply its on_failure policy per seed.
        return exc


def generate_batch(
    batch: "BatchParams",
    structural: "Optional[StructuralParams]" = None,
    targets: "Optional[DifficultyTargets]" = None,
    verbose: bool = False,
) -> "tuple[list[Argument], list[tuple[int, Exception]]]":
    """Batch-generate ``batch.n`` ``Argument`` objects.

    Each graph comes from ``generate(seed=s, structural=structural,
    targets=targets, max_attempts=batch.max_attempts_per_graph,
    verbose=verbose)``. Seeds are ``batch.seeds`` when given, otherwise
    ``0, 1, …, n-1``.

    Parameters
    ----------
    batch : BatchParams
        Controls count, seeds, failure policy, and parallelism.
    structural : StructuralParams or None
        Forwarded unchanged to ``generate``.
    targets : DifficultyTargets or None
        Forwarded unchanged to ``generate``.
    verbose : bool
        Forwarded to each ``generate`` call; also prints a one-line summary
        at the end when any seed failed.

    Returns
    -------
    args : list[Argument]
        Accepted arguments, in seed order; shorter than ``batch.n`` when
        seeds failed.
    failures : list[tuple[int, Exception]]
        ``(seed, error)`` pairs for every seed whose ``generate`` call raised.

    Raises
    ------
    Exception
        With ``batch.on_failure == 'raise'``, the exception from the first
        failing seed is re-raised as-is (a ``RuntimeError`` when the rejection
        budget was exhausted).

    Notes
    -----
    With ``batch.n_jobs == 1`` seeds are processed lazily, so ``'raise'``
    stops at the first failure without generating the remaining graphs. With
    more jobs, all seeds are generated in a spawn-context process pool first.
    NOTE(review): the parallel path requires ``structural``, ``targets`` and
    the returned ``Argument`` objects to be picklable — confirm for
    ``Argument``.
    """
    import functools

    seeds = batch.seeds if batch.seeds is not None else list(range(batch.n))

    # A partial over a module-level function is picklable; a nested closure is
    # not, and would make every pool.map call fail.
    worker = functools.partial(
        _generate_one,
        structural=structural,
        targets=targets,
        max_attempts=batch.max_attempts_per_graph,
        verbose=verbose,
    )

    accepted: "list[Argument]" = []
    failures: "list[tuple[int, Exception]]" = []

    def _collect(outcomes):
        """Sort per-seed outcomes into accepted/failures per the failure policy."""
        for s, result in zip(seeds, outcomes):
            if not isinstance(result, Exception):
                accepted.append(result)
                continue
            failures.append((s, result))
            if batch.on_failure == "raise":
                raise result
            if batch.on_failure == "warn":
                warnings.warn(
                    f"generate_batch: seed {s} failed: {result}",
                    stacklevel=3,
                )

    if batch.n_jobs == 1:
        # Lazy map: each seed is generated only when the collector reaches it.
        _collect(map(worker, seeds))
    else:
        # Parallelise with spawn-safe multiprocessing (avoid JAX fork issues, OQ3).
        import multiprocessing as mp

        ctx = mp.get_context("spawn")
        with ctx.Pool(batch.n_jobs) as pool:
            _collect(pool.map(worker, seeds))

    if verbose and failures:
        print(
            f"generate_batch: {len(failures)} seeds failed, "
            f"{len(accepted)} accepted out of {batch.n} requested."
        )

    return accepted, failures