probability-flow 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {probability_flow-0.1.0 → probability_flow-0.2.0}/PKG-INFO +14 -9
- {probability_flow-0.1.0 → probability_flow-0.2.0}/README.md +13 -8
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/visualization.py +7 -1
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/__init__.py +15 -1
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/visualization/image.py +194 -44
- probability_flow-0.2.0/probability_flow/visualization/layout.py +166 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/pyproject.toml +1 -1
- {probability_flow-0.1.0 → probability_flow-0.2.0}/.gitignore +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/LICENSE +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/__init__.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/__init__.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/argument.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/calibrate.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/compile.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/generate.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/handle.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/__init__.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/_logmath.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/bp/__init__.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/bp/engine.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/bp/message.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/__init__.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/base.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/independent_evidence.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/noisy_and.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/noisy_or.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/tabular.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/exact.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/network.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/node.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/_util.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/difficulty.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/dseparation.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/loopiness.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/manipulability.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/structure.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/py.typed +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/visualization/__init__.py +0 -0
- {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/visualization/style.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: probability-flow
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A from-scratch, modular discrete Bayesian-network library.
|
|
5
5
|
Project-URL: Homepage, https://github.com/scalable-oversight-benchmarks/probability-flow
|
|
6
6
|
Project-URL: Repository, https://github.com/scalable-oversight-benchmarks/probability-flow
|
|
@@ -237,12 +237,15 @@ target posterior is hit by calibrating the root prior. See `docs/generation.md`.
|
|
|
237
237
|
## Visualization (optional, `[viz]`)
|
|
238
238
|
|
|
239
239
|
With the `[viz]` extra installed, a compiled network and an argument both render to
|
|
240
|
-
a matplotlib figure with an in-house
|
|
241
|
-
|
|
240
|
+
a matplotlib figure with an in-house layout — `radial` by default (the root at the
|
|
241
|
+
centre, the d-separated evidence branches fanning out under a force model), or
|
|
242
|
+
`layout="layered"` for the root-at-the-bottom tree. Likelihood-ratio edges are
|
|
243
|
+
coloured red→blue by their LR, and a long node name shrinks its font to fit:
|
|
242
244
|
|
|
243
245
|
```python
|
|
244
|
-
bn.render()
|
|
245
|
-
|
|
246
|
+
bn.render() # radial by default
|
|
247
|
+
bn.render(layout="layered") # root-at-bottom tree (orientation="horizontal" too)
|
|
248
|
+
guilty.assemble().render() # the argument view (same layouts)
|
|
246
249
|
```
|
|
247
250
|
|
|
248
251
|
matplotlib is imported lazily, only when you draw, so importing `probability_flow`
|
|
@@ -286,10 +289,12 @@ Working today: the build/compile flow, all four distributions, both solvers with
|
|
|
286
289
|
linear-time structured messages, evidence, the ASPIC argument-compilation layer,
|
|
287
290
|
argument serialization, the metrics seam (d-separation grouping, depth/size,
|
|
288
291
|
loopiness, difficulty, manipulability), a random argument generator with structural
|
|
289
|
-
and difficulty controls, optional matplotlib renderers
|
|
290
|
-
posterior calibration and parameter sensitivities.
|
|
291
|
-
|
|
292
|
-
|
|
292
|
+
and difficulty controls, optional matplotlib renderers (radial and layered
|
|
293
|
+
layouts), and optional JAX-based posterior calibration and parameter sensitivities.
|
|
294
|
+
Loopy BP, the metrics seam, and the renderer are validated against the exact solver
|
|
295
|
+
on non-tree (shared-parent) graphs by the "topology zoo" harness
|
|
296
|
+
(`tools/topology_zoo.py`, `docs/topology_zoo.md`). Planned (see `docs/ROADMAP.md`):
|
|
297
|
+
a core-network serializer and the exact manipulability range.
|
|
293
298
|
|
|
294
299
|
## Learning more
|
|
295
300
|
|
|
@@ -203,12 +203,15 @@ target posterior is hit by calibrating the root prior. See `docs/generation.md`.
|
|
|
203
203
|
## Visualization (optional, `[viz]`)
|
|
204
204
|
|
|
205
205
|
With the `[viz]` extra installed, a compiled network and an argument both render to
|
|
206
|
-
a matplotlib figure with an in-house
|
|
207
|
-
|
|
206
|
+
a matplotlib figure with an in-house layout — `radial` by default (the root at the
|
|
207
|
+
centre, the d-separated evidence branches fanning out under a force model), or
|
|
208
|
+
`layout="layered"` for the root-at-the-bottom tree. Likelihood-ratio edges are
|
|
209
|
+
coloured red→blue by their LR, and a long node name shrinks its font to fit:
|
|
208
210
|
|
|
209
211
|
```python
|
|
210
|
-
bn.render()
|
|
211
|
-
|
|
212
|
+
bn.render() # radial by default
|
|
213
|
+
bn.render(layout="layered") # root-at-bottom tree (orientation="horizontal" too)
|
|
214
|
+
guilty.assemble().render() # the argument view (same layouts)
|
|
212
215
|
```
|
|
213
216
|
|
|
214
217
|
matplotlib is imported lazily, only when you draw, so importing `probability_flow`
|
|
@@ -252,10 +255,12 @@ Working today: the build/compile flow, all four distributions, both solvers with
|
|
|
252
255
|
linear-time structured messages, evidence, the ASPIC argument-compilation layer,
|
|
253
256
|
argument serialization, the metrics seam (d-separation grouping, depth/size,
|
|
254
257
|
loopiness, difficulty, manipulability), a random argument generator with structural
|
|
255
|
-
and difficulty controls, optional matplotlib renderers
|
|
256
|
-
posterior calibration and parameter sensitivities.
|
|
257
|
-
|
|
258
|
-
|
|
258
|
+
and difficulty controls, optional matplotlib renderers (radial and layered
|
|
259
|
+
layouts), and optional JAX-based posterior calibration and parameter sensitivities.
|
|
260
|
+
Loopy BP, the metrics seam, and the renderer are validated against the exact solver
|
|
261
|
+
on non-tree (shared-parent) graphs by the "topology zoo" harness
|
|
262
|
+
(`tools/topology_zoo.py`, `docs/topology_zoo.md`). Planned (see `docs/ROADMAP.md`):
|
|
263
|
+
a core-network serializer and the exact manipulability range.
|
|
259
264
|
|
|
260
265
|
## Learning more
|
|
261
266
|
|
|
@@ -37,15 +37,21 @@ def render_argument(
|
|
|
37
37
|
orientation: str = "vertical",
|
|
38
38
|
title: Optional[str] = None,
|
|
39
39
|
path: Optional[str] = None,
|
|
40
|
+
layout: str = "radial",
|
|
40
41
|
):
|
|
41
42
|
"""Draw the argument rooted at `target`. Support / rebut / undermine edges are
|
|
42
43
|
coloured by their likelihood ratio on the same confirming/disconfirming scale
|
|
43
44
|
as the Bayesian-network view (strict is bold black, undercut purple-dotted onto
|
|
44
45
|
the edge it attacks). Each claim is annotated with its compiled belief
|
|
45
46
|
`P(=1 | evidence)`; pass `beliefs=False` to skip. `colorbar` / `legend` add the
|
|
46
|
-
LR scale and the colour key (default off).""
|
|
47
|
+
LR scale and the colour key (default off). `layout="radial"` puts the target at
|
|
48
|
+
the centre with the d-separated branches fanning outward (the default; pass `layout="layered"` for the root-at-the-bottom tree)."""
|
|
47
49
|
_require_matplotlib()
|
|
48
50
|
claims = [c for c in _reachable(target) if isinstance(c, _Claim)]
|
|
51
|
+
if layout == "radial":
|
|
52
|
+
from ..visualization.layout import radial_positions
|
|
53
|
+
positions = {**radial_positions(target.compile(), target),
|
|
54
|
+
**(positions or {})}
|
|
49
55
|
all_lrs = [lr for c in claims for _src, lr, _kind in c._edges]
|
|
50
56
|
color_for_lr, _ = _lr_colormap(all_lrs)
|
|
51
57
|
|
|
@@ -20,6 +20,7 @@ and docs/metrics.md.
|
|
|
20
20
|
"""
|
|
21
21
|
from __future__ import annotations
|
|
22
22
|
|
|
23
|
+
import warnings
|
|
23
24
|
from typing import TYPE_CHECKING, Optional
|
|
24
25
|
|
|
25
26
|
from ..core import ExactSolver
|
|
@@ -43,7 +44,20 @@ def posterior(bn_or_target, node: Optional["Node"] = None, evidence=None, *,
|
|
|
43
44
|
target = node if node is not None else default
|
|
44
45
|
if target is None:
|
|
45
46
|
raise ValueError("posterior needs a target node")
|
|
46
|
-
|
|
47
|
+
s = solver(bn)
|
|
48
|
+
p = s.prob(target, 1, evidence=evidence)
|
|
49
|
+
# `ExactSolver` is always exact; an iterative solver (`LoopySolver`, the "for
|
|
50
|
+
# scale" path) is exact only on a polytree and merely a tight approximation on
|
|
51
|
+
# a loopy graph — and on a graph it cannot settle, a silently non-converged
|
|
52
|
+
# value is worse than that. Surface it; the flag was otherwise unread.
|
|
53
|
+
if getattr(s, "last_converged", True) is False:
|
|
54
|
+
warnings.warn(
|
|
55
|
+
f"the iterative solver did not converge in {getattr(s, 'last_iters', '?')} "
|
|
56
|
+
"iterations, so this posterior is an unconverged approximation; raise "
|
|
57
|
+
"max_iters or add damping, or use the exact solver on a smaller graph.",
|
|
58
|
+
stacklevel=2,
|
|
59
|
+
)
|
|
60
|
+
return p
|
|
47
61
|
|
|
48
62
|
|
|
49
63
|
__all__ = [
|
|
@@ -1,12 +1,18 @@
|
|
|
1
|
-
"""matplotlib rendering:
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
drawn as data-coordinate patches (circle, or box for a
|
|
9
|
-
size relative to the spacing is fixed and arrowheads land
|
|
1
|
+
"""matplotlib rendering: the layered (DAG) layout and the network renderer.
|
|
2
|
+
matplotlib is imported lazily inside the drawing functions, so importing this
|
|
3
|
+
module (or `probability_flow`) never pulls it in.
|
|
4
|
+
|
|
5
|
+
Two layouts: the default `radial` (root at the centre, evidence radiating outward,
|
|
6
|
+
see `layout.py`) and the `layered` one here (`_positions`: root at the bottom, the
|
|
7
|
+
graph growing upward; a node's x is the mean of its inputs' x, so its inputs sit
|
|
8
|
+
centred over it). Nodes are drawn as data-coordinate patches (circle, or box for a
|
|
9
|
+
leaf premise), so their size relative to the spacing is fixed and arrowheads land
|
|
10
|
+
exactly on the border.
|
|
11
|
+
|
|
12
|
+
A node label is fit to the node by shrinking its font down to `MIN_FONT` (see
|
|
13
|
+
`_fit_label`); only a name too long to fit even then is truncated with an ellipsis.
|
|
14
|
+
Names still read best kept short — the full claim text belongs in `desc`, not the
|
|
15
|
+
name.
|
|
10
16
|
"""
|
|
11
17
|
from __future__ import annotations
|
|
12
18
|
|
|
@@ -21,12 +27,12 @@ if TYPE_CHECKING:
|
|
|
21
27
|
from ..core.node import Node
|
|
22
28
|
|
|
23
29
|
# geometry, in data units (the axes use equal aspect)
|
|
24
|
-
R_NODE = 0.
|
|
25
|
-
X_SPACING = 2.
|
|
26
|
-
Y_SPACING = 2.
|
|
30
|
+
R_NODE = 0.80 # node radius / box half-extent
|
|
31
|
+
X_SPACING = 2.2 # horizontal gap between adjacent leaves
|
|
32
|
+
Y_SPACING = 2.6 # vertical gap between layers
|
|
27
33
|
IN_PER_UNIT = 0.62 # figure inches per data unit
|
|
28
|
-
FONT = 6.5 #
|
|
29
|
-
|
|
34
|
+
FONT = 6.5 # preferred label point size; shrinks (to MIN_FONT) to fit
|
|
35
|
+
MIN_FONT = 3.6 # smallest label font before a name is truncated instead
|
|
30
36
|
DPI = 200 # saved-image resolution
|
|
31
37
|
_NEUTRAL_EDGE = "#9aa0a6"
|
|
32
38
|
_OBSERVED_EDGE = "#c77d00" # border for an observed (evidence) node
|
|
@@ -48,10 +54,18 @@ def _require_matplotlib():
|
|
|
48
54
|
|
|
49
55
|
|
|
50
56
|
def marginals(bn, evidence=None) -> dict:
    """Return `{node: P(node = 1 | evidence)}` for every node of `bn`, used to
    annotate the node labels.

    The solver is chosen by topology, as the `Argument` handle does:

    - On a polytree, a single `LoopySolver` propagation (exact there) yields
      every marginal in one linear pass.
    - On a graph with loops, the brute-force `ExactSolver` is queried node by
      node. It re-enumerates `2**n` states per query, so it is only fit for
      small graphs; using it everywhere is what made rendering a large graph
      hang.

    With `evidence=None` the values are the prior-propagated marginals.
    """
    from ..core import ExactSolver, LoopySolver
    from ..metrics import is_polytree

    if not is_polytree(bn):
        # loops present: exact enumeration, one query per node
        exact = ExactSolver(bn)
        return {node: exact.prob(node, 1, evidence=evidence) for node in bn.nodes}

    # polytree: one propagation gives all marginals; keep the P(=1) entry
    table = LoopySolver(bn).marginals(evidence)
    return {node: float(table[node][1]) for node in bn.nodes}
|
|
57
71
|
|
|
@@ -107,6 +121,28 @@ def _positions(nodes: Sequence, edges: Sequence[dict],
|
|
|
107
121
|
if n not in x:
|
|
108
122
|
x[n] = counter[0] * X_SPACING
|
|
109
123
|
counter[0] += 1
|
|
124
|
+
|
|
125
|
+
# De-collide within each layer. The mean-of-inputs rule places a node centred
|
|
126
|
+
# over its parents, which is tidy for a tree but collapses distinct nodes onto
|
|
127
|
+
# the same coordinate once parents are *shared* (a non-polytree): two siblings
|
|
128
|
+
# that share an input get the same mean and draw on top of each other. Sweep
|
|
129
|
+
# each layer left to right enforcing a minimum gap, then recentre so the tidy
|
|
130
|
+
# shape is preserved. A graph with no overlaps (every tree) is unchanged.
|
|
131
|
+
by_layer: dict = defaultdict(list)
|
|
132
|
+
for n in nodes:
|
|
133
|
+
by_layer[layer[n]].append(n)
|
|
134
|
+
for ns in by_layer.values():
|
|
135
|
+
if len(ns) < 2:
|
|
136
|
+
continue
|
|
137
|
+
ns.sort(key=lambda n: (x[n], getattr(n, "id", id(n))))
|
|
138
|
+
before = sum(x[n] for n in ns) / len(ns)
|
|
139
|
+
for i in range(1, len(ns)):
|
|
140
|
+
if x[ns[i]] - x[ns[i - 1]] < X_SPACING:
|
|
141
|
+
x[ns[i]] = x[ns[i - 1]] + X_SPACING
|
|
142
|
+
shift = before - sum(x[n] for n in ns) / len(ns)
|
|
143
|
+
for n in ns:
|
|
144
|
+
x[n] += shift
|
|
145
|
+
|
|
110
146
|
if orientation == "horizontal": # root (layer 0) on the right, leaves left
|
|
111
147
|
return {n: (-layer[n] * Y_SPACING, x[n]) for n in nodes}
|
|
112
148
|
return {n: (x[n], layer[n] * Y_SPACING) for n in nodes}
|
|
@@ -121,24 +157,125 @@ def _boundary(center, toward, marker, r=R_NODE):
|
|
|
121
157
|
return (center[0] + ux * t, center[1] + uy * t)
|
|
122
158
|
|
|
123
159
|
|
|
124
|
-
def
|
|
125
|
-
"""
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
160
|
+
def _node_pt_radius(ax, fig, x, y) -> float:
    """Return how large `R_NODE` (a data-unit length) appears on screen, in
    typographic points, measured horizontally from the data point `(x, y)`.

    Label fitting measures against this value, so it stays correct whatever
    scale the figure is finally drawn at: the auto-sized canvas or a larger one
    the caller supplied.
    """
    centre, rim = ax.transData.transform([(x, y), (x + R_NODE, y)])
    pixels = abs(rim[0] - centre[0])
    # 72 points per inch; fig.dpi is pixels per inch
    return pixels * 72.0 / float(fig.dpi)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _split_label(text: str):
|
|
168
|
+
"""Separate a node label into its name and a trailing `P=…` / `prior=…` belief
|
|
169
|
+
line (kept intact, drawn below the name)."""
|
|
170
|
+
lines = text.split("\n")
|
|
171
|
+
if len(lines) >= 2 and "=" in lines[-1]:
|
|
172
|
+
return "\n".join(lines[:-1]), lines[-1]
|
|
173
|
+
return text, ""
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _wrap_to(name: str, cols: int) -> list[str]:
|
|
177
|
+
name = name.replace("->", " → ").replace("/", " / ")
|
|
178
|
+
return textwrap.fill(name, max(4, cols), break_long_words=False).split("\n")
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# Heuristic text metrics, as multiples of the font size: an average glyph is taken
# to be ~0.52 of the point size wide, and lines ~1.2 point sizes apart. Label
# fitting uses these to estimate text extents from character counts instead of
# measuring rendered text.
# NOTE(review): an earlier note called this a "serif" glyph width. Confirm which
# font family the renderer actually uses (matplotlib's default is sans-serif).
_CHAR_W, _LINE_H = 0.52, 1.2
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _fit_label(ax, fig, x, y, text: str, is_box: bool):
    """Fit a node label inside the node by shrinking the font (down to `MIN_FONT`)
    so a long name stays readable rather than overflowing; only a name too long even
    at `MIN_FONT` is truncated with an ellipsis. Returns `(text, fontsize)`.

    `text` may end in a belief line (`P=…` / `prior=…`, split off by
    `_split_label`). That line is kept whole and re-appended below the name, which
    may be wrapped or truncated. Whether a size fits is estimated from character
    counts (`_CHAR_W` / `_LINE_H`) against the node's on-screen radius, not from
    rendered text extents.
    """
    name, belief = _split_label(text)
    # Half the usable label extent, in points. A box keeps 0.92 of the radius and
    # a circle 0.74 (presumably because text in a circle must clear its curved
    # edge).
    half = _node_pt_radius(ax, fig, x, y) * (0.92 if is_box else 0.74)  # usable radius
    extra = 1 if belief else 0  # rows reserved for the belief line

    # Step the font down from FONT in 0.5 pt increments. The first size at which
    # the wrapped name (plus belief line) fits both the height and the width wins.
    f = FONT
    while f >= MIN_FONT - 1e-9:
        cols = max(4, int(2 * half / (_CHAR_W * f)))  # estimated characters per line
        lines = _wrap_to(name, cols)
        widest = max((len(s) for s in lines + ([belief] if belief else [])), default=0)
        if ((len(lines) + extra) * _LINE_H * f <= 2 * half
                and widest * _CHAR_W * f <= 2 * half):
            break
        f -= 0.5
    else:  # doesn't fit even at MIN_FONT: truncate
        # NOTE(review): with the current FONT = 6.5, MIN_FONT = 3.6 and 0.5 step,
        # the loop's last trial is 4.0, so MIN_FONT itself is first used here.
        # This branch limits only the number of rows. A single word wider than
        # `cols` (`_wrap_to` does not break long words) can still overflow
        # horizontally.
        f = MIN_FONT
        cols = max(4, int(2 * half / (_CHAR_W * f)))
        rows = max(1, int(2 * half / (_LINE_H * f)) - extra)  # rows left for the name
        lines = _wrap_to(name, cols)
        if len(lines) > rows:
            # keep the first `rows` lines and end the last with an ellipsis
            lines = lines[:rows]
            lines[-1] = lines[-1][:cols - 1].rstrip() + "…"

    return "\n".join(lines + ([belief] if belief else [])), f
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# An edge counts as blocked by a node when it passes closer than this many node
# radii to the node's centre; `_route_rad` then bows it.
_CLEAR_BUFFER = 1.45  # an edge should clear a node's centre by this * R_NODE
# Candidate `arc3` rad magnitudes for bowing a blocked edge. They are tried in
# order, smallest first, so the gentlest curve that clears is chosen.
_CURVE_TRIES = (0.12, 0.2, 0.3, 0.42, 0.58)  # arc3 rad magnitudes, smallest first
|
|
132
216
|
|
|
133
217
|
|
|
134
|
-
def
|
|
135
|
-
"""
|
|
136
|
-
|
|
218
|
+
def _ctrl_for_rad(start, end, rad):
|
|
219
|
+
"""The quadratic control point matplotlib's `arc3,rad` uses: the chord midpoint
|
|
220
|
+
displaced by `rad * (dy, -dx)` (its right normal). Computing it ourselves lets
|
|
221
|
+
the collision check below match the curve that actually gets drawn."""
|
|
222
|
+
mx, my = (start[0] + end[0]) / 2, (start[1] + end[1]) / 2
|
|
223
|
+
dx, dy = end[0] - start[0], end[1] - start[1]
|
|
224
|
+
return (mx + rad * dy, my - rad * dx)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _bezier_point(start, ctrl, end, t):
|
|
228
|
+
mt = 1.0 - t
|
|
229
|
+
return (mt * mt * start[0] + 2 * mt * t * ctrl[0] + t * t * end[0],
|
|
230
|
+
mt * mt * start[1] + 2 * mt * t * ctrl[1] + t * t * end[1])
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _curve_clears(start, ctrl, end, obstacles, buffer):
    """Return whether the quadratic curve `start → ctrl → end` keeps clear of
    every obstacle.

    `obstacles` is an iterable of `(centre, r)` pairs. An obstacle is violated
    when some sample of the curve lies closer than `buffer * r` to its centre.
    The curve is sampled at 19 evenly spaced parameters, t = 0, 1/18, …, 1.
    """
    samples = [_bezier_point(start, ctrl, end, k / 18.0) for k in range(19)]
    return not any(
        min(math.hypot(centre[0] - sx, centre[1] - sy) for sx, sy in samples)
        < buffer * radius
        for centre, radius in obstacles
    )
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _route_rad(start, end, obstacles, *, buffer=_CLEAR_BUFFER):
    """Return the `arc3` rad with which to draw the edge `start → end`.

    `obstacles` is a sequence of `(centre, r)` pairs. It is traversed twice, so
    it must not be a one-shot iterator. An obstacle blocks the edge when its
    centre projects onto the interior of the chord (strictly between 2 % and 98 %
    of the way along) and lies closer than `buffer * r` to it.

    - No blocker: returns `0.0` (a straight edge). This is the common case, so
      trees and short edges are untouched.
    - Otherwise: the edge bows away from the side the blockers sit on, judged by
      the sign of their summed cross products. The first magnitude in
      `_CURVE_TRIES` whose sampled curve keeps every obstacle at
      `0.92 * buffer * r` is used; failing that, the largest magnitude is
      returned as a best effort.
    """
    x0, y0 = start[0], start[1]
    dx = end[0] - x0
    dy = end[1] - y0
    chord_sq = dx * dx + dy * dy or 1.0

    cross_sides = []
    for centre, radius in obstacles:
        rel_x = centre[0] - x0
        rel_y = centre[1] - y0
        t = (rel_x * dx + rel_y * dy) / chord_sq
        if not 0.02 < t < 0.98:
            continue  # beside an endpoint rather than alongside the shaft
        foot_x = x0 + t * dx
        foot_y = y0 + t * dy
        if math.hypot(centre[0] - foot_x, centre[1] - foot_y) < buffer * radius:
            # positive when the centre lies to the left of start → end
            cross_sides.append(dx * rel_y - dy * rel_x)

    if not cross_sides:
        return 0.0

    # A positive rad bows toward the chord's right normal (dy, -dx), i.e. away
    # from blockers on the left.
    sign = 1.0 if sum(cross_sides) >= 0 else -1.0
    for magnitude in _CURVE_TRIES:
        rad = magnitude * sign
        ctrl = _ctrl_for_rad(start, end, rad)
        if _curve_clears(start, ctrl, end, obstacles, buffer * 0.92):
            return rad
    return _CURVE_TRIES[-1] * sign
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _arrow(ax, start, end, *, color, linestyle, lw, rad=0.0):
|
|
270
|
+
"""Draw an arrow, optionally bowed by `rad` (matplotlib `arc3` curvature) to
|
|
271
|
+
route around a node. A dashed/dotted edge is split into a dashed shaft and a
|
|
272
|
+
separate SOLID arrowhead, so the head never looks dashed (kept straight)."""
|
|
137
273
|
from matplotlib.patches import FancyArrowPatch
|
|
138
274
|
|
|
139
275
|
if linestyle in ("-", "solid", None):
|
|
140
276
|
ax.add_patch(FancyArrowPatch(start, end, arrowstyle="-|>", mutation_scale=15,
|
|
141
|
-
shrinkA=0, shrinkB=0, lw=lw, color=color,
|
|
277
|
+
shrinkA=0, shrinkB=0, lw=lw, color=color,
|
|
278
|
+
connectionstyle=f"arc3,rad={rad}", zorder=1))
|
|
142
279
|
return
|
|
143
280
|
dx, dy = end[0] - start[0], end[1] - start[1]
|
|
144
281
|
dist = math.hypot(dx, dy) or 1.0
|
|
@@ -211,9 +348,16 @@ def draw_layered(
|
|
|
211
348
|
_, ax = plt.subplots(figsize=(w, h))
|
|
212
349
|
fig = ax.figure
|
|
213
350
|
ax.set_aspect("equal")
|
|
351
|
+
pad = R_NODE + 0.28
|
|
352
|
+
ax.set_xlim(min(xs) - pad, max(xs) + pad)
|
|
353
|
+
ax.set_ylim(min(ys) - pad, max(ys) + pad)
|
|
354
|
+
ax.axis("off")
|
|
355
|
+
fig.canvas.draw() # finalise the transform so label fitting
|
|
356
|
+
# sees the real on-screen node size
|
|
214
357
|
|
|
215
358
|
for e in edges:
|
|
216
359
|
cu = pos[e["u"]]
|
|
360
|
+
rad = 0.0
|
|
217
361
|
if "attacks" in e: # arrow lands on another edge
|
|
218
362
|
a, b = e["attacks"]
|
|
219
363
|
end = ((pos[a][0] + pos[b][0]) / 2, (pos[a][1] + pos[b][1]) / 2)
|
|
@@ -225,12 +369,15 @@ def draw_layered(
|
|
|
225
369
|
gdx, gdy = start[0] - end[0], start[1] - end[1]
|
|
226
370
|
gd = math.hypot(gdx, gdy) or 1.0
|
|
227
371
|
end = (end[0] + gdx / gd * 0.06, end[1] + gdy / gd * 0.06) # small gap
|
|
372
|
+
# bow the (solid) edge around any node it would otherwise cross.
|
|
373
|
+
obstacles = [(pos[n], R_NODE) for n in nodes
|
|
374
|
+
if n is not e["u"] and n is not e["v"]]
|
|
375
|
+
rad = _route_rad(start, end, obstacles)
|
|
228
376
|
_arrow(ax, start, end, color=e.get("color", _NEUTRAL_EDGE),
|
|
229
|
-
linestyle=e.get("style", "-"), lw=e.get("width", 1.7))
|
|
377
|
+
linestyle=e.get("style", "-"), lw=e.get("width", 1.7), rad=rad)
|
|
230
378
|
if e.get("label"):
|
|
231
379
|
f = e.get("label_pos", 0.5) # fraction from start to end
|
|
232
|
-
mx = start
|
|
233
|
-
my = start[1] + f * (end[1] - start[1])
|
|
380
|
+
mx, my = _bezier_point(start, _ctrl_for_rad(start, end, rad), end, f)
|
|
234
381
|
ax.text(mx, my, e["label"], fontsize=font - 1, ha="center", va="center",
|
|
235
382
|
zorder=3, bbox=dict(boxstyle="round,pad=0.12", fc="white",
|
|
236
383
|
ec="none", alpha=0.85))
|
|
@@ -247,13 +394,9 @@ def draw_layered(
|
|
|
247
394
|
ec=ec, lw=lw, zorder=2))
|
|
248
395
|
else:
|
|
249
396
|
ax.add_patch(Circle((x, y), R_NODE, fc=color_of(n), ec=ec, lw=lw, zorder=2))
|
|
250
|
-
ax
|
|
251
|
-
|
|
397
|
+
txt, fsz = _fit_label(ax, fig, x, y, label_of(n), marker_of(n) == "s")
|
|
398
|
+
ax.text(x, y, txt, fontsize=fsz, ha="center", va="center", zorder=4)
|
|
252
399
|
|
|
253
|
-
pad = R_NODE + 0.28
|
|
254
|
-
ax.set_xlim(min(xs) - pad, max(xs) + pad)
|
|
255
|
-
ax.set_ylim(min(ys) - pad, max(ys) + pad)
|
|
256
|
-
ax.axis("off")
|
|
257
400
|
if title:
|
|
258
401
|
ax.set_title(title, fontsize=font + 5, pad=3)
|
|
259
402
|
|
|
@@ -355,19 +498,26 @@ def _legend_entries(cpds) -> list:
|
|
|
355
498
|
def render(bn, values: Optional[Mapping["Node", float]] = None, evidence=None,
|
|
356
499
|
positions: Optional[Mapping] = None, colorbar: bool = False,
|
|
357
500
|
legend: bool = False, ax=None, orientation: str = "vertical",
|
|
358
|
-
title: Optional[str] = None, path: Optional[str] = None
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
the
|
|
501
|
+
title: Optional[str] = None, path: Optional[str] = None,
|
|
502
|
+
layout: str = "radial"):
|
|
503
|
+
"""Draw a compiled `BayesianNetwork`. `layout="radial"` (default) puts the root
|
|
504
|
+
at the centre with the d-separated evidence branches fanning outward (see
|
|
505
|
+
`layout.py`); `layout="layered"` puts the root at the bottom and grows upward,
|
|
506
|
+
and `orientation="horizontal"` lays that form out left to right.
|
|
362
507
|
|
|
363
508
|
A node with no parents shows its `prior`; a node with parents (or any node
|
|
364
509
|
under evidence) shows its posterior `P`. Observed nodes get a bold amber
|
|
365
510
|
border. Pass `values={}` to skip the numbers. Likelihood-ratio edges are
|
|
366
511
|
always coloured by their LR (red disconfirming, blue confirming); set
|
|
367
|
-
`colorbar=True` for the LR scale and `legend=True` for the node-colour key.
|
|
512
|
+
`colorbar=True` for the LR scale and `legend=True` for the node-colour key.
|
|
513
|
+
A long node name shrinks its font to fit the node (and truncates only if it
|
|
514
|
+
still will not fit); names read best kept short, with the full text in `desc`."""
|
|
368
515
|
_require_matplotlib()
|
|
369
516
|
if values is None:
|
|
370
517
|
values = marginals(bn, evidence)
|
|
518
|
+
if layout == "radial":
|
|
519
|
+
from .layout import radial_positions
|
|
520
|
+
positions = {**radial_positions(bn), **(positions or {})}
|
|
371
521
|
cpd_of = {n: bn.compiled_cpd(n).cpd for n in bn.nodes}
|
|
372
522
|
|
|
373
523
|
all_lrs = [edge_lr(cpd_of[n], inp) for n in bn.nodes for inp in cpd_of[n].inputs]
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""Radial layout: the root at the centre, evidence radiating outward.
|
|
2
|
+
|
|
3
|
+
An alternative to the layered (root-at-bottom) layout in `image._positions`. The
|
|
4
|
+
root sits at the origin and radius grows with depth, so the root repels everything
|
|
5
|
+
outward and depths separate into rings. The angle comes from a radial-tree
|
|
6
|
+
assignment driven by **d-separation**: each independent evidence branch (a
|
|
7
|
+
d-separated group) gets its own angular sector with a gap between sectors, leaves
|
|
8
|
+
spread across the sector in DFS order so siblings stay adjacent, and every internal
|
|
9
|
+
node sits at the circular mean of its cone's leaves — a node shared across branches
|
|
10
|
+
naturally bridges the sectors it spans. A force relaxation then settles related
|
|
11
|
+
nodes (argument edges and co-input siblings) at an equilibrium distance — pushed
|
|
12
|
+
apart when too close, pulled together when too far — so each branch clusters
|
|
13
|
+
instead of spreading maximally, while the radius stays pinned to the depth ring so
|
|
14
|
+
the root remains centred.
|
|
15
|
+
|
|
16
|
+
`radial_positions(bn, target)` returns `{node: (x, y)}` for feeding the renderer
|
|
17
|
+
via `render(..., layout="radial")`.
|
|
18
|
+
"""
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from .image import _layers
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _sink(bn):
|
|
27
|
+
"""The root: the one node that feeds nothing (no outgoing edge)."""
|
|
28
|
+
used = {i for n in bn.nodes for i in bn.compiled_cpd(n).inputs}
|
|
29
|
+
sinks = [n for n in bn.nodes if n not in used]
|
|
30
|
+
return sinks[0] if sinks else bn.nodes[-1]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _radial_angles(bn, target, inputs_of, gap_frac):
    """Return a starting angle, in radians, for every node of `bn` in the radial
    layout, as `{node: angle}`.

    - `inputs_of` maps each node to its inputs; a node with none is a leaf.
    - The d-separated groups around `target` (`d_separated_groups`) each get an
      angular sector proportional to their leaf count.
    - Sectors are separated by gaps of `gap_frac * 2π / n_groups`. There is no
      gap when there is only one group.
    - Leaves are spread evenly across their sector in depth-first order, so
      siblings stay adjacent.
    - Every node then takes the circular mean of the leaf angles in its cone
      (itself plus its `ancestors`). A node whose cone holds no placed leaf
      gets 0.0.
    """
    from ..metrics import d_separated_groups
    from ..metrics._util import ancestors

    groups = d_separated_groups(bn, target)

    def dfs_leaves(parents):
        """Leaves reachable from `parents` through `inputs_of`, each listed once,
        in depth-first (pre-order) discovery order.

        Iterative, with an explicit stack of iterators, so a deep chain of
        inputs cannot exceed Python's recursion limit. The visit order matches
        the straightforward recursive version.
        """
        seen, order = set(), []
        stack = [iter(parents)]
        while stack:
            for n in stack[-1]:
                if n in seen:
                    continue
                seen.add(n)
                if inputs_of[n]:
                    # descend into n's inputs before n's later siblings
                    stack.append(iter(inputs_of[n]))
                    break
                order.append(n)
            else:
                stack.pop()  # this level is exhausted; resume the one above
        return order

    group_leaves = [dfs_leaves(list(g.parents)) for g in groups]
    total = sum(len(gl) for gl in group_leaves) or 1
    ng = max(1, len(groups))
    # gaps separate sectors only with more than one group; a single (fully shared)
    # group gets the whole circle.
    gap = 0.0 if ng <= 1 else gap_frac * 2 * np.pi / ng
    usable = 2 * np.pi - gap * ng

    leaf_angle: dict = {}
    cur = 0.0
    for gl in group_leaves:
        span = usable * (len(gl) / total)
        for k, leaf in enumerate(gl):
            # centre each leaf in its equal slice of the sector
            leaf_angle[leaf] = cur + span * ((k + 0.5) / max(1, len(gl)))
        cur += span + gap

    ang: dict = {}
    for n in bn.nodes:
        cone = ancestors(bn, n) | {n}
        ls = [c for c in cone if not inputs_of[c] and c in leaf_angle]
        if ls:
            # circular mean: average the unit vectors, then take the angle
            xs = float(np.mean([np.cos(leaf_angle[c]) for c in ls]))
            ys = float(np.mean([np.sin(leaf_angle[c]) for c in ls]))
            ang[n] = float(np.arctan2(ys, xs))
        else:
            ang[n] = 0.0
    return ang
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _relax(pos, R, root_i, ea, eb, sa, sb, mod, k_anchor, *, iters, L0,
           k_rep=2.2, k_spring=0.7, k_sib=0.45):
    """Relax 2-D node positions toward equilibrium while keeping each node on its ring.

    Each of the `iters` steps builds a displacement per node from three
    contributions:

    * pairwise repulsion of magnitude `k_rep * mod[i, j] / d**2`, directed
      along the line between the two nodes (`d**2` is softened by `1e-3`).
      `mod` scales it per pair, e.g. harder across d-separated groups;
    * linear springs with rest length `L0`: stiffness `k_spring` on the
      argument edges `(ea[k], eb[k])` and `k_sib` on co-input siblings
      `(sa[k], sb[k])`. A pair closer than `L0` is pushed apart and a pair
      farther than `L0` is pulled together, so a branch clusters instead of
      spreading maximally;
    * a pull of strength `k_anchor` back toward the starting layout, which
      keeps a connected blob from drifting off-centre.

    The displacement is scaled by a step of `0.1 * (1 - 0.8 * it / iters)`
    and capped at length 0.5 per node. Then node `root_i` is pinned to the
    origin, and every node with `R > 0` is projected back onto the circle of
    radius `R`. The net effect is angular motion with the root at the centre.

    Returns the relaxed positions. When `iters == 0` this is the input array
    itself.
    """
    on_ring = R > 0
    start = pos.copy()

    def add_springs(acc, a, b, stiffness, cur):
        # Spring of rest length L0 between every pair (a[k], b[k]).
        # The force on `a` is accumulated into `acc`; `b` gets the opposite.
        if not len(a):
            return
        delta = cur[b] - cur[a]
        length = np.hypot(delta[:, 0], delta[:, 1]) + 1e-9
        force = (stiffness * (length - L0) / length)[:, np.newaxis] * delta
        np.add.at(acc, a, force)
        np.add.at(acc, b, -force)

    for step_no in range(iters):
        # All-pairs repulsion; the diagonal is zeroed so nodes ignore themselves.
        sep = pos[:, np.newaxis, :] - pos[np.newaxis, :, :]
        sq = (sep ** 2).sum(-1) + 1e-3
        push = k_rep * mod / sq
        np.fill_diagonal(push, 0.0)
        disp = (push[..., np.newaxis] * sep / np.sqrt(sq)[..., np.newaxis]).sum(1)

        add_springs(disp, ea, eb, k_spring, pos)
        add_springs(disp, sa, sb, k_sib, pos)

        disp += k_anchor * (start - pos)

        # Decaying step size, with a per-node cap on how far a node may move.
        move = disp * (0.1 * (1.0 - 0.8 * step_no / iters))
        move_len = np.hypot(move[:, 0], move[:, 1]) + 1e-9
        pos = pos + move * np.minimum(1.0, 0.5 / move_len)[:, np.newaxis]

        # Pin the root, then snap every ringed node back to its radius.
        pos[root_i] = [0.0, 0.0]
        radius = np.hypot(pos[:, 0], pos[:, 1]) + 1e-9
        pos[on_ring] = (pos[on_ring] / radius[on_ring, np.newaxis]) * R[on_ring, np.newaxis]
    return pos
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def radial_positions(bn, target=None, *, ring=2.9, relax=True, iters=500,
                     gap_frac=0.22, k_anchor=0.18, dsep_rep=1.8):
    """`{node: (x, y)}` for a root-centred radial layout.

    `target` defaults to the network's sink (its root) and is pinned at the
    origin. Every other node starts on a circle of radius `ring * layer`,
    where `layer` is the node's value from `_layers`. NOTE(review): this
    assumes `_layers` gives the target layer 0 — confirm. The node's starting
    angle comes from `_radial_angles`.

    With `relax`, related nodes then settle at an equilibrium distance
    (clustering each branch) while the root stays centred and depths stay
    ringed. Deterministic.

    Args:
        bn: the network; read through `bn.nodes` and `bn.compiled_cpd(n).inputs`.
        target: node to place at the centre; `None` picks the sink.
        ring: spacing between depth rings, also the spring rest length.
        relax: run the force relaxation after the radial-tree start.
        iters: number of relaxation steps.
        gap_frac: fraction of each group's share of the circle left as a
            gap between the sectors of different d-separated groups.
        k_anchor: base pull back toward the start. It is scaled by
            `1 + 8 / n_groups**2`, so few groups anchor much harder.
        dsep_rep: repulsion multiplier for node pairs in different
            d-separated groups.

    Returns:
        dict mapping each node to an `(x, y)` tuple of floats. The dict is
        empty for a network with no nodes.
    """
    from ..metrics import d_separated_groups

    nodes = list(bn.nodes)
    if not nodes:
        # Nothing to place. `_sink` would otherwise index an empty node list.
        return {}

    target = target if target is not None else _sink(bn)
    idx = {n: i for i, n in enumerate(nodes)}
    inputs_of = {n: list(bn.compiled_cpd(n).inputs) for n in nodes}
    edge_dicts = [{"u": inp, "v": n} for n in nodes for inp in inputs_of[n]]
    # Compute the layering once; its arguments do not depend on the node
    # being looked up.
    layer_of = _layers(nodes, edge_dicts)
    depth = np.array([float(layer_of[n]) for n in nodes])

    ang = _radial_angles(bn, target, inputs_of, gap_frac)
    R = depth * ring
    pos = np.array([[R[idx[n]] * np.cos(ang[n]), R[idx[n]] * np.sin(ang[n])]
                    for n in nodes])
    root_i = idx[target]
    pos[root_i] = [0.0, 0.0]

    if relax:
        groups = d_separated_groups(bn, target)
        ng = max(1, len(groups))
        # Many independent branches balance themselves around the root, so a
        # light anchor lets them cluster. A few groups (a densely shared blob)
        # have nothing to fan, so anchor hard to the radial tree to keep the
        # root central.
        anchor = k_anchor * (1.0 + 8.0 / ng ** 2)
        # Group index per node: the first group listing a node wins, and -1
        # means the node is in no group.
        grp = np.full(len(nodes), -1)
        for gi, g in enumerate(groups):
            for n in g.nodes:
                if grp[idx[n]] == -1:
                    grp[idx[n]] = gi
        same = grp[:, None] == grp[None, :]
        known = (grp[:, None] >= 0) & (grp[None, :] >= 0)
        # Boost repulsion only between nodes known to lie in different groups.
        mod = np.where(known & ~same, dsep_rep, 1.0)

        # Edge springs: each input paired with the node it feeds.
        ea = np.array([idx[i] for n in nodes for i in inputs_of[n]], dtype=int)
        eb = np.array([idx[n] for n in nodes for _ in inputs_of[n]], dtype=int)
        # Sibling springs: each unordered pair of inputs sharing a consumer,
        # deduplicated by storing (lower index, higher index).
        sibs = {(min(idx[i], idx[j]), max(idx[i], idx[j]))
                for n in nodes for a, i in enumerate(inputs_of[n])
                for j in inputs_of[n][a + 1:]}
        sa = np.array([s[0] for s in sibs], dtype=int)
        sb = np.array([s[1] for s in sibs], dtype=int)
        pos = _relax(pos, R, root_i, ea, eb, sa, sb, mod, anchor,
                     iters=iters, L0=ring)
    return {n: (float(pos[idx[n], 0]), float(pos[idx[n], 1])) for n in nodes}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/independent_evidence.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/manipulability.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/visualization/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|