probability-flow 0.2.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. {probability_flow-0.2.0 → probability_flow-0.4.0}/.gitignore +3 -0
  2. {probability_flow-0.2.0 → probability_flow-0.4.0}/PKG-INFO +9 -1
  3. {probability_flow-0.2.0 → probability_flow-0.4.0}/README.md +7 -0
  4. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/__init__.py +4 -0
  5. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/__init__.py +3 -0
  6. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/argument.py +110 -15
  7. probability_flow-0.4.0/probability_flow/aspic/benchmark.py +254 -0
  8. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/calibrate.py +7 -0
  9. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/compile.py +39 -9
  10. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/generate.py +177 -4
  11. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/handle.py +58 -10
  12. probability_flow-0.4.0/probability_flow/aspic/optimize.py +514 -0
  13. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/aspic/visualization.py +42 -5
  14. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/__init__.py +4 -0
  15. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/bp/engine.py +1 -0
  16. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/__init__.py +4 -1
  17. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/base.py +8 -0
  18. probability_flow-0.4.0/probability_flow/core/cpd/correlated_evidence.py +243 -0
  19. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/independent_evidence.py +13 -9
  20. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/noisy_and.py +2 -2
  21. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/noisy_or.py +2 -2
  22. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/exact.py +1 -0
  23. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/network.py +41 -3
  24. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/node.py +21 -0
  25. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/__init__.py +8 -1
  26. probability_flow-0.4.0/probability_flow/metrics/manipulability.py +474 -0
  27. probability_flow-0.4.0/probability_flow/visualization/animate.py +210 -0
  28. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/visualization/image.py +113 -25
  29. probability_flow-0.4.0/probability_flow/visualization/layout.py +423 -0
  30. {probability_flow-0.2.0 → probability_flow-0.4.0}/pyproject.toml +17 -2
  31. probability_flow-0.2.0/probability_flow/metrics/manipulability.py +0 -207
  32. probability_flow-0.2.0/probability_flow/visualization/layout.py +0 -166
  33. {probability_flow-0.2.0 → probability_flow-0.4.0}/LICENSE +0 -0
  34. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/_logmath.py +0 -0
  35. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/bp/__init__.py +0 -0
  36. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/bp/message.py +0 -0
  37. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/core/cpd/tabular.py +0 -0
  38. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/_util.py +0 -0
  39. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/difficulty.py +0 -0
  40. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/dseparation.py +0 -0
  41. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/loopiness.py +0 -0
  42. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/metrics/structure.py +0 -0
  43. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/py.typed +0 -0
  44. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/visualization/__init__.py +0 -0
  45. {probability_flow-0.2.0 → probability_flow-0.4.0}/probability_flow/visualization/style.py +0 -0
@@ -6,3 +6,6 @@ __pycache__/
6
6
  *.egg-info/
7
7
  .DS_Store
8
8
  _previews/
9
+
10
+ # Local secrets
11
+ .env
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: probability-flow
3
- Version: 0.2.0
3
+ Version: 0.4.0
4
4
  Summary: A from-scratch, modular discrete Bayesian-network library.
5
5
  Project-URL: Homepage, https://github.com/scalable-oversight-benchmarks/probability-flow
6
6
  Project-URL: Repository, https://github.com/scalable-oversight-benchmarks/probability-flow
@@ -28,6 +28,7 @@ Requires-Dist: pytest; extra == 'dev'
28
28
  Requires-Dist: ruff; extra == 'dev'
29
29
  Provides-Extra: jax
30
30
  Requires-Dist: jax; extra == 'jax'
31
+ Requires-Dist: scipy; extra == 'jax'
31
32
  Provides-Extra: viz
32
33
  Requires-Dist: matplotlib; extra == 'viz'
33
34
  Description-Content-Type: text/markdown
@@ -118,6 +119,13 @@ combiner) from the *object* that implements it (the CPD).
118
119
  sources of evidence: `logit P(node=1) = logit(prior) + sum of log(lr)` over the
119
120
  active inputs. Adding weights of evidence is Bayes' rule for independent
120
121
  likelihood ratios. Set per edge with `add_input(x, lr=...)`.
122
+ - **CorrelatedEvidenceCPD**. Independent evidence plus pairwise couplings, for
123
+ *redundant* inputs the additive rule would otherwise double-count (two reports of
124
+ one fact, two clues from a shared cause): it adds a pairwise term
125
+ `+ sum J_ij s_i s_j` (a negative `J` makes two inputs sub-additive when both fire,
126
+ a positive one synergistic). The coupling lives inside the single CPD factor, so
127
+ it adds no edge to the graph and no loop to the solver, and the CPD remains a
127
+ valid distribution for any real `J`.
121
129
  - **NoisyOrCPD**. "Any one cause can fire the effect":
122
130
  `P(node=0) = (1 - leak) * product of (1 - activation)` over present causes.
123
131
  Declared with `node.noisy_or(leak=...)` and `add_input(cause, activation=...)`.
@@ -84,6 +84,13 @@ combiner) from the *object* that implements it (the CPD).
84
84
  sources of evidence: `logit P(node=1) = logit(prior) + sum of log(lr)` over the
85
85
  active inputs. Adding weights of evidence is Bayes' rule for independent
86
86
  likelihood ratios. Set per edge with `add_input(x, lr=...)`.
87
+ - **CorrelatedEvidenceCPD**. Independent evidence plus pairwise couplings, for
88
+ *redundant* inputs the additive rule would otherwise double-count (two reports of
89
+ one fact, two clues from a shared cause): it adds a pairwise term
90
+ `+ sum J_ij s_i s_j` (a negative `J` makes two inputs sub-additive when both fire,
91
+ a positive one synergistic). The coupling lives inside the single CPD factor, so
92
+ it adds no edge to the graph and no loop to the solver, and the CPD remains a
92
+ valid distribution for any real `J`.
87
94
  - **NoisyOrCPD**. "Any one cause can fire the effect":
88
95
  `P(node=0) = (1 - leak) * product of (1 - activation)` over present causes.
89
96
  Declared with `node.noisy_or(leak=...)` and `add_input(cause, activation=...)`.
@@ -12,12 +12,14 @@ from .core import (
12
12
  CPD,
13
13
  BayesianNetwork,
14
14
  CompiledCPD,
15
+ CorrelatedEvidenceCPD,
15
16
  ExactSolver,
16
17
  IndependentEvidenceCPD,
17
18
  LoopySolver,
18
19
  Node,
19
20
  NoisyAndCPD,
20
21
  NoisyOrCPD,
22
+ PendingWeightError,
21
23
  TabularCPD,
22
24
  )
23
25
 
@@ -36,6 +38,8 @@ __all__ = [
36
38
  "CPD",
37
39
  "TabularCPD",
38
40
  "IndependentEvidenceCPD",
41
+ "CorrelatedEvidenceCPD",
39
42
  "NoisyOrCPD",
40
43
  "NoisyAndCPD",
44
+ "PendingWeightError",
41
45
  ]
@@ -9,13 +9,16 @@ pure-BN `core`. Build an argument out of premises, conclusions, and attacks, the
9
9
  from .argument import ArgumentWarning, Axiom, Conclusion, Premise
10
10
  from .generate import (
11
11
  ArgumentGenerator,
12
+ BatchParams,
12
13
  DifficultyTargets,
13
14
  StructuralParams,
14
15
  generate,
16
+ generate_batch,
15
17
  )
16
18
  from .handle import Argument
17
19
 
18
20
  __all__ = [
19
21
  "Premise", "Axiom", "Conclusion", "ArgumentWarning", "Argument",
20
22
  "ArgumentGenerator", "StructuralParams", "DifficultyTargets", "generate",
23
+ "BatchParams", "generate_batch",
21
24
  ]
@@ -41,9 +41,17 @@ class _Claim(Node):
41
41
 
42
42
  def __init__(self, name: str, prior: float = 0.5):
43
43
  super().__init__(name, prior=prior)
44
- self._edges: list[tuple["Node", float, str]] = [] # (src, lr, kind)
44
+ self._edges: list[tuple["Node", "float | None", str]] = [] # (src, lr, kind)
45
45
  self._strict: list["Node"] = [] # strict sources
46
46
  self._undercuts: list[tuple["Node", "Node"]] = [] # (attacked source, undercutter)
47
+ # (source_a, source_b, J): residual correlation between two defeasible
48
+ # sources of this conclusion, lowered to a CorrelatedEvidenceCPD coupling.
49
+ self._correlations: list[tuple["Node", "Node", "float | None"]] = []
50
+ # (sources, lr, leak, rule): a conjunctive support group — the sources
51
+ # JOINTLY support this conclusion through one noisy-AND inference (the
52
+ # deduction rule), lowered to a hidden NoisyAnd gate fed in as one lr.
53
+ self._conjunctions: list[
54
+ tuple[list["Node"], "float | None", float, "str | None"]] = []
47
55
  self._no_undermine = False
48
56
  # opaque, serialized but uninterpreted by the library (see docs/aspic.md):
49
57
  self.desc: str | None = None # a longer description
@@ -53,8 +61,13 @@ class _Claim(Node):
53
61
  """At most one argumentative edge joins a pair of claims. This is what lets
54
62
  an undercut be addressed by its endpoints and keeps serialized edge ids
55
63
  (`source->target`) unique."""
56
- if any(s is src for s, _lr, _kind in self._edges) or any(
57
- s is src for s in self._strict
64
+ in_conj = any(
65
+ s is src for srcs, _lr, _leak, _rule in self._conjunctions for s in srcs
66
+ )
67
+ if (
68
+ any(s is src for s, _lr, _kind in self._edges)
69
+ or any(s is src for s in self._strict)
70
+ or in_conj
58
71
  ):
59
72
  raise ValueError(
60
73
  f"{self.name!r} already has an edge from {src.name!r}; at most one "
@@ -63,21 +76,69 @@ class _Claim(Node):
63
76
 
64
77
  # --- defeasible and strict support (methods on the downstream conclusion) --
65
78
 
66
- def support(self, src: "Node", lr: float) -> "Node":
79
+ def support(self, src: "Node", lr: "float | None" = None) -> "Node":
67
80
  """Add a defeasible argument *for* this conclusion (`lr > 1`). Returns
68
81
  `src`, so an inline source can be built further upstream, as with core
69
- `add_input`."""
70
- if not lr > 1:
82
+ `add_input`. `lr` may be omitted (`None`) to declare the edge before its
83
+ strength is known the topology compiles and renders (dashed), but the
84
+ network cannot be solved until every weight is assigned."""
85
+ if lr is not None and not lr > 1:
71
86
  raise ValueError(
72
87
  f"support lr must be > 1 (got {lr}); use rebut for an argument against"
73
88
  )
74
89
  self._require_new_edge(src)
75
- self._edges.append((src, float(lr), "support"))
90
+ self._edges.append((src, None if lr is None else float(lr), "support"))
76
91
  return src
77
92
 
78
- def rebut(self, src: "Node", lr: float) -> "Node":
79
- """Add a defeasible argument *against* this conclusion (`0 < lr < 1`)."""
80
- if not 0 < lr < 1:
93
def support_all(
    self,
    srcs: "list[Node]",
    lr: "float | None" = None,
    leak: float = 0.0,
    rule: "str | None" = None,
) -> "list[Node]":
    """Add conjunctive support: `srcs` back this conclusion only *together*.

    The group stands for one deductive entailment step. N antecedent siblings
    plus one inference rule yield one deduced conclusion, which is the shape of
    a node in a MuSR-style reasoning tree. The deduction fires only when every
    source holds.

    At compile time the group becomes a hidden `NoisyAnd` gate over `srcs`, each
    at activation 1 (a logical AND). That gate feeds this conclusion as a single
    `lr` support. All the evidential strength therefore sits in that one `lr`,
    and the conclusion keeps its own prior and composes with ordinary
    `support` / `rebut` edges.

    `srcs` may be any iterable and is copied to a list. It must hold at least
    two distinct nodes.

    `lr` must be > 1, or `None` to scaffold the group before its strength is
    known.

    `leak` is the rule's fallibility, a probability in [0, 1]:
    `P(fire | all sources hold) = 1 - leak`.

    `rule` is an optional label for the inference licence. It is stored for
    audit or rendering but not interpreted.

    Returns the recorded list of sources, so inline sources can be built further
    upstream (as with `support`). Raises `ValueError` in these cases:
    - fewer than two sources, or repeated sources;
    - `lr <= 1`, or `leak` outside [0, 1];
    - a source that `_require_new_edge` rejects.
    """
    group = list(srcs)

    if len(group) < 2:
        raise ValueError(
            "support_all needs >= 2 sources (it is a conjunction); use support "
            "for a single argument"
        )
    unique_ids = {id(node) for node in group}
    if len(unique_ids) != len(group):
        raise ValueError("support_all sources must be distinct")
    if lr is not None and not lr > 1:
        raise ValueError(
            f"support_all lr must be > 1 (got {lr}); it is conjunctive support"
        )
    if not 0.0 <= leak <= 1.0:
        raise ValueError("leak must be a probability in [0, 1]")

    # Vet every source before recording the group, so a rejected source leaves
    # `_conjunctions` untouched.
    for node in group:
        self._require_new_edge(node)

    strength = None if lr is None else float(lr)
    self._conjunctions.append((group, strength, float(leak), rule))
    return group
137
+
138
+ def rebut(self, src: "Node", lr: "float | None" = None) -> "Node":
139
+ """Add a defeasible argument *against* this conclusion (`0 < lr < 1`). `lr`
140
+ may be omitted (`None`) to declare the edge before its strength is known."""
141
+ if lr is not None and not 0 < lr < 1:
81
142
  raise ValueError(
82
143
  f"rebut lr must be in (0, 1) (got {lr}); use support for an argument for"
83
144
  )
@@ -89,7 +150,7 @@ class _Claim(Node):
89
150
  stacklevel=2,
90
151
  )
91
152
  self._require_new_edge(src)
92
- self._edges.append((src, float(lr), "rebut"))
153
+ self._edges.append((src, None if lr is None else float(lr), "rebut"))
93
154
  return src
94
155
 
95
156
  def strict(self, src: "Node") -> "Node":
@@ -101,18 +162,19 @@ class _Claim(Node):
101
162
 
102
163
  # --- attacks (also methods on the downstream node) ------------------------
103
164
 
104
- def undermine(self, by: "Node", lr: float) -> "Node":
165
+ def undermine(self, by: "Node", lr: "float | None" = None) -> "Node":
105
166
  """Attack this premise with `by` (`0 < lr < 1`). Mechanically a rebut into
106
167
  the attacked node, named distinctly because it is the ASPIC-correct verb
107
- for attacking a premise. An axiom cannot be undermined."""
168
+ for attacking a premise. An axiom cannot be undermined. `lr` may be omitted
169
+ (`None`) to declare the edge before its strength is known."""
108
170
  if self._no_undermine:
109
171
  raise ValueError(
110
172
  f"cannot undermine {self.name!r}: an axiom has no defeasible premise to attack"
111
173
  )
112
- if not 0 < lr < 1:
174
+ if lr is not None and not 0 < lr < 1:
113
175
  raise ValueError(f"undermine lr must be in (0, 1) (got {lr})")
114
176
  self._require_new_edge(by)
115
- self._edges.append((by, float(lr), "undermine"))
177
+ self._edges.append((by, None if lr is None else float(lr), "undermine"))
116
178
  return by
117
179
 
118
180
  def undercut(self, source: "Node", by: "Node") -> "Node":
@@ -123,6 +185,39 @@ class _Claim(Node):
123
185
  self._undercuts.append((source, by))
124
186
  return by
125
187
 
188
def correlate(self, source_a: "Node", source_b: "Node",
              coupling: "float | None" = None) -> "_Claim":
    """Couple two **defeasible sources** (support or rebut edges) of this conclusion.

    The pairing expresses residual correlation between them. One case is
    redundancy: two reports of one fact, or a shared cause. The other is
    synergy. `compile` lowers it to a pairwise `CorrelatedEvidenceCPD`
    coupling `J`: a negative `J` discounts the pair (sub-additive), and a
    positive `J` boosts it.

    `J` is the engine-native quantity; a caller holding a correlation or
    conditional probabilities must do the inverse map. `coupling=None`
    scaffolds the pair before its value is chosen. Numeric couplings are
    stored as `float`.

    This overrides the core `Node.correlate`. A `_Claim` records the pairing
    against its argument sources here, and `compile` later applies it to
    whichever core node ends up carrying the defeasible inputs.

    Returns `self` for chaining. Raises `ValueError` if either node is not a
    support/rebut source of this claim, if both arguments are the same node,
    or if the pair (in either order) is already correlated here.
    """
    defeasible_ids = set()
    for src, _lr, kind in self._edges:
        if kind in ("support", "rebut"):
            defeasible_ids.add(id(src))

    for candidate in (source_a, source_b):
        if id(candidate) in defeasible_ids:
            continue
        raise ValueError(
            f"{getattr(candidate, 'name', candidate)!r} is not a support/rebut source of "
            f"{self.name!r}; correlate couples two of its defeasible sources"
        )

    if source_a is source_b:
        raise ValueError("cannot correlate a source with itself")

    # Pairs are unordered: (a, b) and (b, a) denote the same coupling.
    wanted = {id(source_a), id(source_b)}
    for first, second, _J in self._correlations:
        if {id(first), id(second)} == wanted:
            raise ValueError(
                f"{source_a.name!r} and {source_b.name!r} are already correlated on "
                f"{self.name!r}"
            )

    value = None if coupling is None else float(coupling)
    self._correlations.append((source_a, source_b, value))
    return self
220
+
126
221
  def compile(self) -> "BayesianNetwork":
127
222
  """Lower this argument (as the target) to a core `BayesianNetwork`."""
128
223
  from .compile import compile_argument
@@ -0,0 +1,254 @@
1
+ """Benchmark helpers: generate → optimize bridge for large-batch posterior targeting.
2
+
3
+ This module isolates the JAX/scipy dependency (via ``optimize``) from the pure-Python
4
+ ``generate.py``. The primary entry point is ``generate_with_targets``, which produces
5
+ N argument graphs each tuned to a requested root posterior.
6
+
7
+ Strategy
8
+ --------
9
+ For each (seed, target_posterior) pair:
10
+
11
+ 1. Try ``generate(seed, structural, DifficultyTargets(target_posterior=t))`` — this
12
+ bisects the root prior to hit *t* cheaply (no JAX needed).
13
+ 2. If the bisection cannot reach *t* for the drawn structure (``RuntimeError`` from
14
+ the rejection loop), fall back to:
15
+ a. Generate an *unconstrained* graph from the same seed.
16
+ b. Check whether *t* is inside ``achievable_interval(arg)``.
17
+ c. If yes, call ``optimize(arg, target_posterior=t)`` to tune LRs.
18
+ 3. If neither path succeeds, record the failure (skip / warn / raise per
19
+ ``on_failure``).
20
+
21
+ Both paths return an ``Argument`` with identical public interface. The ``optimize``
22
+ fallback is more expensive (JAX JIT + SLSQP) but reaches extreme posteriors that
23
+ bisection cannot.
24
+
25
+ Batch labeling
26
+ --------------
27
+ Each accepted argument is annotated with a ``_benchmark_meta`` attribute (a plain
28
+ dict) carrying the difficulty scalars the benchmark consumer (``debate-eval``) needs:
29
+
30
+ posterior, manipulability ([lo, hi]), manipulability_width, d_sep_count,
31
+ max_depth_metric, upstream_size, circuit_rank, seed,
32
+ is_exact_manipulability, used_optimize_fallback
33
+
34
+ These come entirely from existing metrics — no new computation.
35
+
36
+ JAX/scipy availability
37
+ ----------------------
38
+ Both are required for the ``optimize`` fallback path. When they are absent the
39
+ module still imports cleanly; ``generate_with_targets`` will error only if an
40
+ optimize fallback is actually needed. Use ``generate_batch`` (no JAX) when you only
41
+ need bisection-reachable posteriors.
42
+ """
43
+ from __future__ import annotations
44
+
45
+ import warnings
46
+ from typing import TYPE_CHECKING, Optional
47
+
48
+ if TYPE_CHECKING:
49
+ from .handle import Argument
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # Internal helpers
53
+ # ---------------------------------------------------------------------------
54
+
55
+
56
def _label(arg: "Argument", seed: int, used_optimize: bool) -> None:
    """Store the benchmark difficulty scalars for *arg* on ``arg._benchmark_meta``.

    Mutates *arg* in place and returns ``None``. The attribute is assigned only
    after every metric has been computed, so a failing metric leaves *arg*
    unlabelled. Keys, in order:

    ``posterior``
        Root posterior under the current weights (float).
    ``manipulability``
        ``[lo, hi]`` achievable root posterior from ``posterior_range``, each
        rounded to 4 decimals. This shows *where* on [0, 1] the judge can be
        pushed, which the width alone does not: a one-sided graph is pinned to
        one side of 0.5.
    ``manipulability_width``
        ``hi - lo``, rounded to 4 decimals.
    ``d_sep_count``, ``max_depth_metric``, ``upstream_size``, ``circuit_rank``
        Structural metrics of the compiled network.
    ``seed``, ``used_optimize_fallback``
        Provenance of the graph.
    ``is_exact_manipulability``
        Whether ``posterior_range`` reports the range as exact.
    """
    from ..metrics import (
        circuit_rank,
        d_separated_groups,
        max_depth,
        posterior_range,
        upstream_size,
    )

    net = arg.bn
    root = arg.target
    posterior = arg.posterior(root)

    # exact=True: exact on polytrees, an outer bound when upstream structure is
    # shared. The returned `is_exact` flag says which case applied.
    span = posterior_range(net, root, exact=True)

    meta = {}
    meta["posterior"] = float(posterior)
    meta["manipulability"] = [round(float(span.lo), 4), round(float(span.hi), 4)]
    meta["manipulability_width"] = round(float(span.hi - span.lo), 4)
    meta["d_sep_count"] = len(d_separated_groups(net, root))
    meta["max_depth_metric"] = max_depth(net, root)
    meta["upstream_size"] = upstream_size(net, root)
    meta["circuit_rank"] = circuit_rank(net)
    meta["seed"] = seed
    meta["is_exact_manipulability"] = bool(span.is_exact)
    meta["used_optimize_fallback"] = used_optimize
    arg._benchmark_meta = meta
87
+
88
+
89
+ def _try_optimize(
90
+ arg: "Argument",
91
+ target: float,
92
+ posterior_tol: float,
93
+ param_limits: Optional[dict],
94
+ ) -> "Optional[Argument]":
95
+ """Try to hit *target* via ``optimize``; return the new arg or None on failure."""
96
+ try:
97
+ from .optimize import (
98
+ InfeasibleTargetError,
99
+ OptimizeError,
100
+ achievable_interval,
101
+ optimize,
102
+ )
103
+ except ImportError:
104
+ return None
105
+
106
+ try:
107
+ lo, hi = achievable_interval(arg, param_limits=param_limits)
108
+ except Exception:
109
+ return None
110
+
111
+ if not (lo - 1e-4 <= target <= hi + 1e-4):
112
+ return None # infeasible for this structure
113
+
114
+ try:
115
+ return optimize(arg, target_posterior=target,
116
+ posterior_tol=posterior_tol,
117
+ param_limits=param_limits)
118
+ except (InfeasibleTargetError, OptimizeError):
119
+ return None
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Public API
124
+ # ---------------------------------------------------------------------------
125
+
126
+
127
def generate_with_targets(
    n: int,
    structural,
    posterior_targets: "list[float]",
    posterior_tol: float = 0.02,
    param_limits: "Optional[dict]" = None,
    seeds: "Optional[list[int]]" = None,
    max_attempts_per_graph: int = 200,
    n_jobs: int = 1,
    on_failure: str = "skip",
    verbose: bool = False,
) -> "list[Argument]":
    """Generate *n* graphs, each tuned to a target root posterior.

    Parameters
    ----------
    n : int
        Number of graphs to produce.
    structural : StructuralParams
        Structural shape shared across all graphs.
    posterior_targets : list[float]
        Desired root posterior for each graph. Must have length *n*.
    posterior_tol : float
        Accepted absolute deviation from each target posterior (default 0.02).
    param_limits : dict or None
        Passed to ``optimize`` on the fallback path. None uses default boxes.
    seeds : list[int] or None
        Per-graph seeds. Defaults to ``list(range(n))``; must have length *n*.
    max_attempts_per_graph : int
        Rejection budget for each ``generate`` call, both the primary
        (bisection) call and the unconstrained fallback call.
    n_jobs : int
        Reserved for parallel generation and currently ignored: graphs are
        always produced serially in the calling process.
    on_failure : str
        What to do when a seed reaches its target by neither path:
        - ``"skip"`` (default): drop it silently;
        - ``"warn"``: drop it and emit a ``UserWarning``;
        - ``"raise"``: raise ``RuntimeError``.
    verbose : bool
        Print one progress line per graph.

    Returns
    -------
    list[Argument]
        Accepted and tuned arguments. Each carries ``._benchmark_meta``.
        Length <= *n* when ``on_failure != 'raise'``.

    Raises
    ------
    ValueError
        If ``posterior_targets`` or ``seeds`` does not have *n* entries, or if
        ``on_failure`` is not one of ``"skip"``, ``"warn"``, ``"raise"``.
    RuntimeError
        If ``on_failure == "raise"`` and a seed cannot reach its target.

    Notes
    -----
    The fallback (call ``optimize`` instead of resampling) triggers whenever
    the bisection-only ``generate`` raises ``RuntimeError`` (budget
    exhausted). ``optimize`` is then tried on an *unconstrained* graph from the
    same seed, and succeeds if the target is achievable in the parameter box.
    """
    if len(posterior_targets) != n:
        raise ValueError(
            f"posterior_targets has {len(posterior_targets)} entries but n={n}"
        )
    if seeds is None:
        seeds = list(range(n))
    if len(seeds) != n:
        raise ValueError(f"seeds has {len(seeds)} entries but n={n}")
    # Validate up front: any other value would otherwise silently behave like
    # "skip" and hide failures the caller asked to hear about.
    if on_failure not in ("skip", "warn", "raise"):
        raise ValueError(
            f"on_failure must be 'skip', 'warn' or 'raise' (got {on_failure!r})"
        )

    # Imported after validation so bad arguments fail fast with a clear error.
    from .generate import DifficultyTargets, generate

    accepted: list[Argument] = []

    for i, (seed, t) in enumerate(zip(seeds, posterior_targets)):
        if verbose:
            print(f"[{i + 1}/{n}] seed={seed} target={t:.3f}", end=" ... ")

        # --- primary path: bisection inside generate() -----------------------
        arg: Optional[Argument] = None
        used_optimize = False
        try:
            arg = generate(
                seed=seed,
                structural=structural,
                targets=DifficultyTargets(
                    target_posterior=t,
                    posterior_tol=posterior_tol,
                ),
                max_attempts=max_attempts_per_graph,
                verbose=False,
            )
        except RuntimeError:
            pass  # bisection budget exhausted — try the optimize fallback

        # --- fallback: generate unconstrained then optimize ------------------
        if arg is None:
            try:
                base_arg = generate(
                    seed=seed,
                    structural=structural,
                    max_attempts=max_attempts_per_graph,
                    verbose=False,
                )
                opt_arg = _try_optimize(base_arg, t, posterior_tol, param_limits)
                if opt_arg is not None:
                    arg = opt_arg
                    used_optimize = True
            except RuntimeError:
                pass  # even the unconstrained draw (or optimize) gave up

        # --- failure handling ------------------------------------------------
        if arg is None:
            msg = (
                f"generate_with_targets: seed {seed} could not reach "
                f"target {t:.4f} via bisection or optimize."
            )
            # Terminate the open progress line before raising or warning.
            if verbose:
                print("FAILED")
            if on_failure == "raise":
                raise RuntimeError(msg)
            if on_failure == "warn":
                warnings.warn(msg, stacklevel=2)
            continue

        # --- label and collect -----------------------------------------------
        _label(arg, seed=seed, used_optimize=used_optimize)
        accepted.append(arg)
        if verbose:
            m = arg._benchmark_meta
            # "manipulability" is a [lo, hi] pair; format each end separately
            # (a list does not accept a ".3f" format spec).
            lo, hi = m["manipulability"]
            print(
                f"ok P={m['posterior']:.3f} manip=[{lo:.3f}, {hi:.3f}] "
                f"nodes={m['upstream_size'] + 1} "
                f"{'[opt]' if used_optimize else '[bisect]'}"
            )

    return accepted
@@ -67,6 +67,7 @@ def _forward(arg: "Argument"):
67
67
  current values (logit priors, log LRs), and `meta` labels each entry."""
68
68
  jax, jnp = _require_jax()
69
69
  from ..core import IndependentEvidenceCPD, NoisyOrCPD
70
+ from ..core.cpd.noisy_and import NoisyAndCPD
70
71
 
71
72
  bn = arg.bn
72
73
  nodes = list(bn.nodes) # topological (inputs before node)
@@ -105,6 +106,12 @@ def _forward(arg: "Argument"):
105
106
  a = jnp.asarray(cpd.activations)
106
107
  pin = jnp.stack([m[i] for i in ins])
107
108
  m[n] = 1 - (1 - cpd.leak) * jnp.prod(1 - a * pin)
109
+ elif isinstance(cpd, NoisyAndCPD):
110
+ # NoisyAnd activations are fixed structural constants (all 1.0 from
111
+ # the ASPIC compiler); they are NOT free parameters in theta.
112
+ a = jnp.asarray(cpd.activations)
113
+ pin = jnp.stack([m[i] for i in ins])
114
+ m[n] = (1 - cpd.leak) * jnp.prod(a * pin)
108
115
  elif _is_and_not_splice(cpd):
109
116
  e, u = ins
110
117
  m[n] = m[e] * (1 - m[u])
@@ -40,6 +40,8 @@ def _upstream(c: "Node") -> list["Node"]:
40
40
  out.extend(src for src, _lr, _kind in getattr(c, "_edges", []))
41
41
  out.extend(getattr(c, "_strict", []))
42
42
  out.extend(by for _source, by in getattr(c, "_undercuts", []))
43
+ for srcs, _lr, _leak, _rule in getattr(c, "_conjunctions", []):
44
+ out.extend(srcs)
43
45
  return out
44
46
 
45
47
 
@@ -86,7 +88,12 @@ def _validate(claims: list["Node"]) -> None:
86
88
  if not isinstance(c, _Claim):
87
89
  continue
88
90
 
89
- if c.role == "conclusion" and not c._edges and not c._strict:
91
+ if (
92
+ c.role == "conclusion"
93
+ and not c._edges
94
+ and not c._strict
95
+ and not getattr(c, "_conjunctions", [])
96
+ ):
90
97
  warnings.warn(
91
98
  f"{c.name!r} is a conclusion with no incoming argument; a leaf "
92
99
  "should be a Premise.",
@@ -160,16 +167,39 @@ def _lower(c: "_Claim") -> None:
160
167
 
161
168
  defeasible = [(resolve(src), lr) for src, lr, _kind in c._edges]
162
169
  strict = [resolve(src) for src in c._strict]
170
+ conjunctions = getattr(c, "_conjunctions", [])
171
+
172
+ # `host` carries the defeasible inputs: the conclusion itself, or a hidden `D`
173
+ # when strict edges divorce it (parent-divorcing).
174
+ host = Node(f"{c.name}/defeasible", prior=c.prior) if strict else c
175
+
176
+ for src, lr in defeasible:
177
+ host.add_input(src, lr=lr)
178
+
179
+ # Conjunctive support: one hidden NoisyAnd gate per group, fed into `host` as a
180
+ # single lr support. The gate is the logical AND (each source at activation 1,
181
+ # firing iff all hold); `leak` is the rule's fallibility; `lr` is its strength.
182
+ for k, (srcs, lr, leak, _rule) in enumerate(conjunctions):
183
+ gate = Node(f"{c.name}/conj[{k}]")
184
+ gate.noisy_and(leak=leak)
185
+ for src in srcs:
186
+ gate.add_input(resolve(src), activation=1.0)
187
+ host.add_input(gate, lr=lr)
188
+
189
+ _apply_correlations(c, host, resolve)
163
190
 
164
191
  if strict:
165
- # parent-divorcing: D holds the defeasible part (or just the prior).
166
- d = Node(f"{c.name}/defeasible", prior=c.prior)
167
- for src, lr in defeasible:
168
- d.add_input(src, lr=lr)
169
192
  c.noisy_or(leak=0.0)
170
193
  for src in strict:
171
194
  c.add_input(src, activation=1.0)
172
- c.add_input(d, activation=1.0)
173
- else:
174
- for src, lr in defeasible:
175
- c.add_input(src, lr=lr)
195
+ c.add_input(host, activation=1.0)
196
+
197
+
198
def _apply_correlations(c: "_Claim", host: "Node", resolve) -> None:
    """Install the source couplings declared on the argument claim `c`.

    Each `(source_a, source_b, coupling)` recorded on `c._correlations` is
    mapped to core nodes with `resolve` and coupled on `host`. `host` is the
    core node that actually carries the defeasible inputs: the conclusion
    itself, or its hidden `/defeasible` node when strict edges divorced it.

    The call goes through the base class, `Node.correlate`, on purpose.
    `_Claim` overrides `correlate` with the argument-level declaration, so
    calling `host.correlate(...)` would hit the wrong method whenever `host`
    is `c` itself. A claim without a `_correlations` attribute is a no-op.
    """
    declared = getattr(c, "_correlations", [])
    for left, right, coupling in declared:
        Node.correlate(host, resolve(left), resolve(right), coupling=coupling)