probability-flow 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {probability_flow-0.1.0 → probability_flow-0.2.0}/PKG-INFO +14 -9
  2. {probability_flow-0.1.0 → probability_flow-0.2.0}/README.md +13 -8
  3. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/visualization.py +7 -1
  4. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/__init__.py +15 -1
  5. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/visualization/image.py +194 -44
  6. probability_flow-0.2.0/probability_flow/visualization/layout.py +166 -0
  7. {probability_flow-0.1.0 → probability_flow-0.2.0}/pyproject.toml +1 -1
  8. {probability_flow-0.1.0 → probability_flow-0.2.0}/.gitignore +0 -0
  9. {probability_flow-0.1.0 → probability_flow-0.2.0}/LICENSE +0 -0
  10. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/__init__.py +0 -0
  11. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/__init__.py +0 -0
  12. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/argument.py +0 -0
  13. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/calibrate.py +0 -0
  14. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/compile.py +0 -0
  15. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/generate.py +0 -0
  16. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/aspic/handle.py +0 -0
  17. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/__init__.py +0 -0
  18. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/_logmath.py +0 -0
  19. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/bp/__init__.py +0 -0
  20. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/bp/engine.py +0 -0
  21. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/bp/message.py +0 -0
  22. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/__init__.py +0 -0
  23. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/base.py +0 -0
  24. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/independent_evidence.py +0 -0
  25. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/noisy_and.py +0 -0
  26. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/noisy_or.py +0 -0
  27. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/cpd/tabular.py +0 -0
  28. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/exact.py +0 -0
  29. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/network.py +0 -0
  30. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/core/node.py +0 -0
  31. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/_util.py +0 -0
  32. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/difficulty.py +0 -0
  33. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/dseparation.py +0 -0
  34. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/loopiness.py +0 -0
  35. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/manipulability.py +0 -0
  36. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/metrics/structure.py +0 -0
  37. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/py.typed +0 -0
  38. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/visualization/__init__.py +0 -0
  39. {probability_flow-0.1.0 → probability_flow-0.2.0}/probability_flow/visualization/style.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: probability-flow
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: A from-scratch, modular discrete Bayesian-network library.
5
5
  Project-URL: Homepage, https://github.com/scalable-oversight-benchmarks/probability-flow
6
6
  Project-URL: Repository, https://github.com/scalable-oversight-benchmarks/probability-flow
@@ -237,12 +237,15 @@ target posterior is hit by calibrating the root prior. See `docs/generation.md`.
237
237
  ## Visualization (optional, `[viz]`)
238
238
 
239
239
  With the `[viz]` extra installed, a compiled network and an argument both render to
240
- a matplotlib figure with an in-house layered layout (likelihood-ratio edges
241
- coloured red→blue by their LR):
240
+ a matplotlib figure with an in-house layout — `radial` by default (the root at the
241
+ centre, the d-separated evidence branches fanning out under a force model), or
242
+ `layout="layered"` for the root-at-the-bottom tree. Likelihood-ratio edges are
243
+ coloured red→blue by their LR, and a long node name shrinks its font to fit:
242
244
 
243
245
  ```python
244
- bn.render() # or render(bn) from probability_flow.visualization
245
- guilty.assemble().render() # the argument view
246
+ bn.render() # radial by default
247
+ bn.render(layout="layered") # root-at-bottom tree (orientation="horizontal" too)
248
+ guilty.assemble().render() # the argument view (same layouts)
246
249
  ```
247
250
 
248
251
  matplotlib is imported lazily, only when you draw, so importing `probability_flow`
@@ -286,10 +289,12 @@ Working today: the build/compile flow, all four distributions, both solvers with
286
289
  linear-time structured messages, evidence, the ASPIC argument-compilation layer,
287
290
  argument serialization, the metrics seam (d-separation grouping, depth/size,
288
291
  loopiness, difficulty, manipulability), a random argument generator with structural
289
- and difficulty controls, optional matplotlib renderers, and optional JAX-based
290
- posterior calibration and parameter sensitivities. Planned (see `docs/ROADMAP.md`):
291
- a core-network serializer, the loopy-BP "topology zoo" robustness harness, and the
292
- exact manipulability range.
292
+ and difficulty controls, optional matplotlib renderers (radial and layered
293
+ layouts), and optional JAX-based posterior calibration and parameter sensitivities.
294
+ Loopy BP, the metrics seam, and the renderer are validated against the exact solver
295
+ on non-tree (shared-parent) graphs by the "topology zoo" harness
296
+ (`tools/topology_zoo.py`, `docs/topology_zoo.md`). Planned (see `docs/ROADMAP.md`):
297
+ a core-network serializer and the exact manipulability range.
293
298
 
294
299
  ## Learning more
295
300
 
@@ -203,12 +203,15 @@ target posterior is hit by calibrating the root prior. See `docs/generation.md`.
203
203
  ## Visualization (optional, `[viz]`)
204
204
 
205
205
  With the `[viz]` extra installed, a compiled network and an argument both render to
206
- a matplotlib figure with an in-house layered layout (likelihood-ratio edges
207
- coloured red→blue by their LR):
206
+ a matplotlib figure with an in-house layout — `radial` by default (the root at the
207
+ centre, the d-separated evidence branches fanning out under a force model), or
208
+ `layout="layered"` for the root-at-the-bottom tree. Likelihood-ratio edges are
209
+ coloured red→blue by their LR, and a long node name shrinks its font to fit:
208
210
 
209
211
  ```python
210
- bn.render() # or render(bn) from probability_flow.visualization
211
- guilty.assemble().render() # the argument view
212
+ bn.render() # radial by default
213
+ bn.render(layout="layered") # root-at-bottom tree (orientation="horizontal" too)
214
+ guilty.assemble().render() # the argument view (same layouts)
212
215
  ```
213
216
 
214
217
  matplotlib is imported lazily, only when you draw, so importing `probability_flow`
@@ -252,10 +255,12 @@ Working today: the build/compile flow, all four distributions, both solvers with
252
255
  linear-time structured messages, evidence, the ASPIC argument-compilation layer,
253
256
  argument serialization, the metrics seam (d-separation grouping, depth/size,
254
257
  loopiness, difficulty, manipulability), a random argument generator with structural
255
- and difficulty controls, optional matplotlib renderers, and optional JAX-based
256
- posterior calibration and parameter sensitivities. Planned (see `docs/ROADMAP.md`):
257
- a core-network serializer, the loopy-BP "topology zoo" robustness harness, and the
258
- exact manipulability range.
258
+ and difficulty controls, optional matplotlib renderers (radial and layered
259
+ layouts), and optional JAX-based posterior calibration and parameter sensitivities.
260
+ Loopy BP, the metrics seam, and the renderer are validated against the exact solver
261
+ on non-tree (shared-parent) graphs by the "topology zoo" harness
262
+ (`tools/topology_zoo.py`, `docs/topology_zoo.md`). Planned (see `docs/ROADMAP.md`):
263
+ a core-network serializer and the exact manipulability range.
259
264
 
260
265
  ## Learning more
261
266
 
@@ -37,15 +37,21 @@ def render_argument(
37
37
  orientation: str = "vertical",
38
38
  title: Optional[str] = None,
39
39
  path: Optional[str] = None,
40
+ layout: str = "radial",
40
41
  ):
41
42
  """Draw the argument rooted at `target`. Support / rebut / undermine edges are
42
43
  coloured by their likelihood ratio on the same confirming/disconfirming scale
43
44
  as the Bayesian-network view (strict is bold black, undercut purple-dotted onto
44
45
  the edge it attacks). Each claim is annotated with its compiled belief
45
46
  `P(=1 | evidence)`; pass `beliefs=False` to skip. `colorbar` / `legend` add the
46
- LR scale and the colour key (default off)."""
47
+ LR scale and the colour key (default off). `layout="radial"` puts the target at
48
+ the centre with the d-separated branches fanning outward (default "layered")."""
47
49
  _require_matplotlib()
48
50
  claims = [c for c in _reachable(target) if isinstance(c, _Claim)]
51
+ if layout == "radial":
52
+ from ..visualization.layout import radial_positions
53
+ positions = {**radial_positions(target.compile(), target),
54
+ **(positions or {})}
49
55
  all_lrs = [lr for c in claims for _src, lr, _kind in c._edges]
50
56
  color_for_lr, _ = _lr_colormap(all_lrs)
51
57
 
@@ -20,6 +20,7 @@ and docs/metrics.md.
20
20
  """
21
21
  from __future__ import annotations
22
22
 
23
+ import warnings
23
24
  from typing import TYPE_CHECKING, Optional
24
25
 
25
26
  from ..core import ExactSolver
@@ -43,7 +44,20 @@ def posterior(bn_or_target, node: Optional["Node"] = None, evidence=None, *,
43
44
  target = node if node is not None else default
44
45
  if target is None:
45
46
  raise ValueError("posterior needs a target node")
46
- return solver(bn).prob(target, 1, evidence=evidence)
47
+ s = solver(bn)
48
+ p = s.prob(target, 1, evidence=evidence)
49
+ # `ExactSolver` is always exact; an iterative solver (`LoopySolver`, the "for
50
+ # scale" path) is exact only on a polytree and merely a tight approximation on
51
+ # a loopy graph — and on a graph it cannot settle, a silently non-converged
52
+ # value is worse than that. Surface it; the flag was otherwise unread.
53
+ if getattr(s, "last_converged", True) is False:
54
+ warnings.warn(
55
+ f"the iterative solver did not converge in {getattr(s, 'last_iters', '?')} "
56
+ "iterations, so this posterior is an unconverged approximation; raise "
57
+ "max_iters or add damping, or use the exact solver on a smaller graph.",
58
+ stacklevel=2,
59
+ )
60
+ return p
47
61
 
48
62
 
49
63
  __all__ = [
@@ -1,12 +1,18 @@
1
- """matplotlib rendering: an in-house layered (DAG) layout and the network
2
- renderer. matplotlib is imported lazily inside the drawing functions, so importing
3
- this module (or `probability_flow`) never pulls it in.
4
-
5
- Layout: the root (sink) sits at the bottom and the graph grows upward to the
6
- leaves. A node's x is the mean of its inputs' x, so each node's inputs sit
7
- centred over it (a tidy tree, no edges reaching across the figure). Nodes are
8
- drawn as data-coordinate patches (circle, or box for a leaf premise), so their
9
- size relative to the spacing is fixed and arrowheads land exactly on the border.
1
+ """matplotlib rendering: the layered (DAG) layout and the network renderer.
2
+ matplotlib is imported lazily inside the drawing functions, so importing this
3
+ module (or `probability_flow`) never pulls it in.
4
+
5
+ Two layouts: the default `radial` (root at the centre, evidence radiating outward,
6
+ see `layout.py`) and the `layered` one here (`_positions`: root at the bottom, the
7
+ graph growing upward; a node's x is the mean of its inputs' x, so its inputs sit
8
+ centred over it). Nodes are drawn as data-coordinate patches (circle, or box for a
9
+ leaf premise), so their size relative to the spacing is fixed and arrowheads land
10
+ exactly on the border.
11
+
12
+ A node label is fit to the node by shrinking its font down to `MIN_FONT` (see
13
+ `_fit_label`); only a name too long to fit even then is truncated with an ellipsis.
14
+ Names still read best kept short — the full claim text belongs in `desc`, not the
15
+ name.
10
16
  """
11
17
  from __future__ import annotations
12
18
 
@@ -21,12 +27,12 @@ if TYPE_CHECKING:
21
27
  from ..core.node import Node
22
28
 
23
29
  # geometry, in data units (the axes use equal aspect)
24
- R_NODE = 0.72 # node radius / box half-extent
25
- X_SPACING = 2.1 # horizontal gap between adjacent leaves
26
- Y_SPACING = 2.5 # vertical gap between layers
30
+ R_NODE = 0.80 # node radius / box half-extent
31
+ X_SPACING = 2.2 # horizontal gap between adjacent leaves
32
+ Y_SPACING = 2.6 # vertical gap between layers
27
33
  IN_PER_UNIT = 0.62 # figure inches per data unit
28
- FONT = 6.5 # default label point size (the BN renderer's)
29
- WRAP = 11 # characters before a label line wraps
34
+ FONT = 6.5 # preferred label point size; shrinks (to MIN_FONT) to fit
35
+ MIN_FONT = 3.6 # smallest label font before a name is truncated instead
30
36
  DPI = 200 # saved-image resolution
31
37
  _NEUTRAL_EDGE = "#9aa0a6"
32
38
  _OBSERVED_EDGE = "#c77d00" # border for an observed (evidence) node
@@ -48,10 +54,18 @@ def _require_matplotlib():
48
54
 
49
55
 
50
56
  def marginals(bn, evidence=None) -> dict:
51
- """`P(node = 1 | evidence)` for every node, via the exact solver (small graphs
52
- only). With `evidence=None` these are the prior-propagated marginals."""
53
- from ..core import ExactSolver
54
-
57
+ """`P(node = 1 | evidence)` for every node, for the node labels. Polytree-aware,
58
+ like the `Argument` handle: one exact `LoopySolver` propagation on a polytree
59
+ (linear, every node in a single pass), the brute-force `ExactSolver` only on a
60
+ loop (small graphs — it re-enumerates `2**n` per node, so it is what made
61
+ rendering a large graph hang). With `evidence=None` these are the
62
+ prior-propagated marginals."""
63
+ from ..core import ExactSolver, LoopySolver
64
+ from ..metrics import is_polytree
65
+
66
+ if is_polytree(bn):
67
+ m = LoopySolver(bn).marginals(evidence)
68
+ return {n: float(m[n][1]) for n in bn.nodes}
55
69
  solver = ExactSolver(bn)
56
70
  return {n: solver.prob(n, 1, evidence=evidence) for n in bn.nodes}
57
71
 
@@ -107,6 +121,28 @@ def _positions(nodes: Sequence, edges: Sequence[dict],
107
121
  if n not in x:
108
122
  x[n] = counter[0] * X_SPACING
109
123
  counter[0] += 1
124
+
125
+ # De-collide within each layer. The mean-of-inputs rule places a node centred
126
+ # over its parents, which is tidy for a tree but collapses distinct nodes onto
127
+ # the same coordinate once parents are *shared* (a non-polytree): two siblings
128
+ # that share an input get the same mean and draw on top of each other. Sweep
129
+ # each layer left to right enforcing a minimum gap, then recentre so the tidy
130
+ # shape is preserved. A graph with no overlaps (every tree) is unchanged.
131
+ by_layer: dict = defaultdict(list)
132
+ for n in nodes:
133
+ by_layer[layer[n]].append(n)
134
+ for ns in by_layer.values():
135
+ if len(ns) < 2:
136
+ continue
137
+ ns.sort(key=lambda n: (x[n], getattr(n, "id", id(n))))
138
+ before = sum(x[n] for n in ns) / len(ns)
139
+ for i in range(1, len(ns)):
140
+ if x[ns[i]] - x[ns[i - 1]] < X_SPACING:
141
+ x[ns[i]] = x[ns[i - 1]] + X_SPACING
142
+ shift = before - sum(x[n] for n in ns) / len(ns)
143
+ for n in ns:
144
+ x[n] += shift
145
+
110
146
  if orientation == "horizontal": # root (layer 0) on the right, leaves left
111
147
  return {n: (-layer[n] * Y_SPACING, x[n]) for n in nodes}
112
148
  return {n: (x[n], layer[n] * Y_SPACING) for n in nodes}
@@ -121,24 +157,125 @@ def _boundary(center, toward, marker, r=R_NODE):
121
157
  return (center[0] + ux * t, center[1] + uy * t)
122
158
 
123
159
 
124
- def _wrap(text: str) -> str:
125
- """Wrap a label to the node width. `->` and `/` (common in auto-named helper
126
- nodes) become break opportunities, and words are never split mid-token."""
127
- out = []
128
- for line in text.split("\n"):
129
- line = line.replace("->", " → ").replace("/", " / ")
130
- out.append(textwrap.fill(line, WRAP, break_long_words=False))
131
- return "\n".join(out)
160
+ def _node_pt_radius(ax, fig, x, y) -> float:
161
+ """`R_NODE`'s on-screen size in points, so label fitting works at whatever scale
162
+ the figure ends up (the auto-sized canvas, or a big user-supplied one)."""
163
+ (x0, _), (x1, _) = ax.transData.transform([(x, y), (x + R_NODE, y)])
164
+ return abs(x1 - x0) * 72.0 / float(fig.dpi)
165
+
166
+
167
+ def _split_label(text: str):
168
+ """Separate a node label into its name and a trailing `P=…` / `prior=…` belief
169
+ line (kept intact, drawn below the name)."""
170
+ lines = text.split("\n")
171
+ if len(lines) >= 2 and "=" in lines[-1]:
172
+ return "\n".join(lines[:-1]), lines[-1]
173
+ return text, ""
174
+
175
+
176
+ def _wrap_to(name: str, cols: int) -> list[str]:
177
+ name = name.replace("->", " → ").replace("/", " / ")
178
+ return textwrap.fill(name, max(4, cols), break_long_words=False).split("\n")
179
+
180
+
181
+ # a serif glyph is ~0.52 of the point size wide; lines sit ~1.2 point sizes apart.
182
+ _CHAR_W, _LINE_H = 0.52, 1.2
183
+
184
+
185
+ def _fit_label(ax, fig, x, y, text: str, is_box: bool):
186
+ """Fit a node label inside the node by shrinking the font (down to `MIN_FONT`)
187
+ so a long name stays readable rather than overflowing; only a name too long even
188
+ at `MIN_FONT` is truncated with an ellipsis. Returns `(text, fontsize)`."""
189
+ name, belief = _split_label(text)
190
+ half = _node_pt_radius(ax, fig, x, y) * (0.92 if is_box else 0.74) # usable radius
191
+ extra = 1 if belief else 0
192
+
193
+ f = FONT
194
+ while f >= MIN_FONT - 1e-9:
195
+ cols = max(4, int(2 * half / (_CHAR_W * f)))
196
+ lines = _wrap_to(name, cols)
197
+ widest = max((len(s) for s in lines + ([belief] if belief else [])), default=0)
198
+ if ((len(lines) + extra) * _LINE_H * f <= 2 * half
199
+ and widest * _CHAR_W * f <= 2 * half):
200
+ break
201
+ f -= 0.5
202
+ else: # doesn't fit even at MIN_FONT: truncate
203
+ f = MIN_FONT
204
+ cols = max(4, int(2 * half / (_CHAR_W * f)))
205
+ rows = max(1, int(2 * half / (_LINE_H * f)) - extra)
206
+ lines = _wrap_to(name, cols)
207
+ if len(lines) > rows:
208
+ lines = lines[:rows]
209
+ lines[-1] = lines[-1][:cols - 1].rstrip() + "…"
210
+
211
+ return "\n".join(lines + ([belief] if belief else [])), f
212
+
213
+
214
+ _CLEAR_BUFFER = 1.45 # an edge should clear a node's centre by this * R_NODE
215
+ _CURVE_TRIES = (0.12, 0.2, 0.3, 0.42, 0.58) # arc3 rad magnitudes, smallest first
132
216
 
133
217
 
134
- def _arrow(ax, start, end, *, color, linestyle, lw):
135
- """Draw an arrow. A dashed/dotted edge is split into a dashed shaft and a
136
- separate SOLID arrowhead, so the head never looks dashed."""
218
+ def _ctrl_for_rad(start, end, rad):
219
+ """The quadratic control point matplotlib's `arc3,rad` uses: the chord midpoint
220
+ displaced by `rad * (dy, -dx)` (its right normal). Computing it ourselves lets
221
+ the collision check below match the curve that actually gets drawn."""
222
+ mx, my = (start[0] + end[0]) / 2, (start[1] + end[1]) / 2
223
+ dx, dy = end[0] - start[0], end[1] - start[1]
224
+ return (mx + rad * dy, my - rad * dx)
225
+
226
+
227
+ def _bezier_point(start, ctrl, end, t):
228
+ mt = 1.0 - t
229
+ return (mt * mt * start[0] + 2 * mt * t * ctrl[0] + t * t * end[0],
230
+ mt * mt * start[1] + 2 * mt * t * ctrl[1] + t * t * end[1])
231
+
232
+
233
+ def _curve_clears(start, ctrl, end, obstacles, buffer):
234
+ pts = [_bezier_point(start, ctrl, end, i / 18.0) for i in range(19)]
235
+ for q, r in obstacles:
236
+ if min(math.hypot(q[0] - x, q[1] - y) for x, y in pts) < buffer * r:
237
+ return False
238
+ return True
239
+
240
+
241
+ def _route_rad(start, end, obstacles, *, buffer=_CLEAR_BUFFER):
242
+ """Pick an `arc3` rad that bows the edge clear of any node its straight line
243
+ would pass through. `0.0` when the straight edge is already clear (the common
244
+ case, so trees and short edges are untouched). Bows away from the side the
245
+ blocking nodes sit on, taking the smallest curvature that clears them."""
246
+ dx, dy = end[0] - start[0], end[1] - start[1]
247
+ L2 = dx * dx + dy * dy or 1.0
248
+ blockers = []
249
+ for q, r in obstacles:
250
+ t = ((q[0] - start[0]) * dx + (q[1] - start[1]) * dy) / L2
251
+ if not (0.02 < t < 0.98): # only nodes alongside the shaft
252
+ continue
253
+ px, py = start[0] + t * dx, start[1] + t * dy
254
+ if math.hypot(q[0] - px, q[1] - py) < buffer * r:
255
+ blockers.append(dx * (q[1] - start[1]) - dy * (q[0] - start[0]))
256
+ if not blockers:
257
+ return 0.0
258
+ # `(dy, -dx)` is the chord's right normal, so a positive rad bows right; a node
259
+ # on the left (positive cross product) is escaped by bowing right.
260
+ sign = 1.0 if sum(blockers) >= 0 else -1.0
261
+ for mag in _CURVE_TRIES:
262
+ rad = mag * sign
263
+ if _curve_clears(start, _ctrl_for_rad(start, end, rad), end, obstacles,
264
+ buffer * 0.92):
265
+ return rad
266
+ return _CURVE_TRIES[-1] * sign # best effort if nothing fully clears
267
+
268
+
269
+ def _arrow(ax, start, end, *, color, linestyle, lw, rad=0.0):
270
+ """Draw an arrow, optionally bowed by `rad` (matplotlib `arc3` curvature) to
271
+ route around a node. A dashed/dotted edge is split into a dashed shaft and a
272
+ separate SOLID arrowhead, so the head never looks dashed (kept straight)."""
137
273
  from matplotlib.patches import FancyArrowPatch
138
274
 
139
275
  if linestyle in ("-", "solid", None):
140
276
  ax.add_patch(FancyArrowPatch(start, end, arrowstyle="-|>", mutation_scale=15,
141
- shrinkA=0, shrinkB=0, lw=lw, color=color, zorder=1))
277
+ shrinkA=0, shrinkB=0, lw=lw, color=color,
278
+ connectionstyle=f"arc3,rad={rad}", zorder=1))
142
279
  return
143
280
  dx, dy = end[0] - start[0], end[1] - start[1]
144
281
  dist = math.hypot(dx, dy) or 1.0
@@ -211,9 +348,16 @@ def draw_layered(
211
348
  _, ax = plt.subplots(figsize=(w, h))
212
349
  fig = ax.figure
213
350
  ax.set_aspect("equal")
351
+ pad = R_NODE + 0.28
352
+ ax.set_xlim(min(xs) - pad, max(xs) + pad)
353
+ ax.set_ylim(min(ys) - pad, max(ys) + pad)
354
+ ax.axis("off")
355
+ fig.canvas.draw() # finalise the transform so label fitting
356
+ # sees the real on-screen node size
214
357
 
215
358
  for e in edges:
216
359
  cu = pos[e["u"]]
360
+ rad = 0.0
217
361
  if "attacks" in e: # arrow lands on another edge
218
362
  a, b = e["attacks"]
219
363
  end = ((pos[a][0] + pos[b][0]) / 2, (pos[a][1] + pos[b][1]) / 2)
@@ -225,12 +369,15 @@ def draw_layered(
225
369
  gdx, gdy = start[0] - end[0], start[1] - end[1]
226
370
  gd = math.hypot(gdx, gdy) or 1.0
227
371
  end = (end[0] + gdx / gd * 0.06, end[1] + gdy / gd * 0.06) # small gap
372
+ # bow the (solid) edge around any node it would otherwise cross.
373
+ obstacles = [(pos[n], R_NODE) for n in nodes
374
+ if n is not e["u"] and n is not e["v"]]
375
+ rad = _route_rad(start, end, obstacles)
228
376
  _arrow(ax, start, end, color=e.get("color", _NEUTRAL_EDGE),
229
- linestyle=e.get("style", "-"), lw=e.get("width", 1.7))
377
+ linestyle=e.get("style", "-"), lw=e.get("width", 1.7), rad=rad)
230
378
  if e.get("label"):
231
379
  f = e.get("label_pos", 0.5) # fraction from start to end
232
- mx = start[0] + f * (end[0] - start[0])
233
- my = start[1] + f * (end[1] - start[1])
380
+ mx, my = _bezier_point(start, _ctrl_for_rad(start, end, rad), end, f)
234
381
  ax.text(mx, my, e["label"], fontsize=font - 1, ha="center", va="center",
235
382
  zorder=3, bbox=dict(boxstyle="round,pad=0.12", fc="white",
236
383
  ec="none", alpha=0.85))
@@ -247,13 +394,9 @@ def draw_layered(
247
394
  ec=ec, lw=lw, zorder=2))
248
395
  else:
249
396
  ax.add_patch(Circle((x, y), R_NODE, fc=color_of(n), ec=ec, lw=lw, zorder=2))
250
- ax.text(x, y, _wrap(label_of(n)), fontsize=font, ha="center", va="center",
251
- zorder=4)
397
+ txt, fsz = _fit_label(ax, fig, x, y, label_of(n), marker_of(n) == "s")
398
+ ax.text(x, y, txt, fontsize=fsz, ha="center", va="center", zorder=4)
252
399
 
253
- pad = R_NODE + 0.28
254
- ax.set_xlim(min(xs) - pad, max(xs) + pad)
255
- ax.set_ylim(min(ys) - pad, max(ys) + pad)
256
- ax.axis("off")
257
400
  if title:
258
401
  ax.set_title(title, fontsize=font + 5, pad=3)
259
402
 
@@ -355,19 +498,26 @@ def _legend_entries(cpds) -> list:
355
498
  def render(bn, values: Optional[Mapping["Node", float]] = None, evidence=None,
356
499
  positions: Optional[Mapping] = None, colorbar: bool = False,
357
500
  legend: bool = False, ax=None, orientation: str = "vertical",
358
- title: Optional[str] = None, path: Optional[str] = None):
359
- """Draw a compiled `BayesianNetwork`. By default the root sits at the bottom;
360
- pass `orientation="horizontal"` to lay it out left to right with the root on
361
- the right.
501
+ title: Optional[str] = None, path: Optional[str] = None,
502
+ layout: str = "radial"):
503
+ """Draw a compiled `BayesianNetwork`. `layout="radial"` (default) puts the root
504
+ at the centre with the d-separated evidence branches fanning outward (see
505
+ `layout.py`); `layout="layered"` puts the root at the bottom and grows upward,
506
+ and `orientation="horizontal"` lays that form out left to right.
362
507
 
363
508
  A node with no parents shows its `prior`; a node with parents (or any node
364
509
  under evidence) shows its posterior `P`. Observed nodes get a bold amber
365
510
  border. Pass `values={}` to skip the numbers. Likelihood-ratio edges are
366
511
  always coloured by their LR (red disconfirming, blue confirming); set
367
- `colorbar=True` for the LR scale and `legend=True` for the node-colour key."""
512
+ `colorbar=True` for the LR scale and `legend=True` for the node-colour key.
513
+ A long node name shrinks its font to fit the node (and truncates only if it
514
+ still will not fit); names read best kept short, with the full text in `desc`."""
368
515
  _require_matplotlib()
369
516
  if values is None:
370
517
  values = marginals(bn, evidence)
518
+ if layout == "radial":
519
+ from .layout import radial_positions
520
+ positions = {**radial_positions(bn), **(positions or {})}
371
521
  cpd_of = {n: bn.compiled_cpd(n).cpd for n in bn.nodes}
372
522
 
373
523
  all_lrs = [edge_lr(cpd_of[n], inp) for n in bn.nodes for inp in cpd_of[n].inputs]
@@ -0,0 +1,166 @@
1
+ """Radial layout: the root at the centre, evidence radiating outward.
2
+
3
+ An alternative to the layered (root-at-bottom) layout in `image._positions`. The
4
+ root sits at the origin and radius grows with depth, so the root repels everything
5
+ outward and depths separate into rings. The angle comes from a radial-tree
6
+ assignment driven by **d-separation**: each independent evidence branch (a
7
+ d-separated group) gets its own angular sector with a gap between sectors, leaves
8
+ spread across the sector in DFS order so siblings stay adjacent, and every internal
9
+ node sits at the circular mean of its cone's leaves — a node shared across branches
10
+ naturally bridges the sectors it spans. A force relaxation then settles related
11
+ nodes (argument edges and co-input siblings) at an equilibrium distance — pushed
12
+ apart when too close, pulled together when too far — so each branch clusters
13
+ instead of spreading maximally, while the radius stays pinned to the depth ring so
14
+ the root remains centred.
15
+
16
+ `radial_positions(bn, target)` returns `{node: (x, y)}` for feeding the renderer
17
+ via `render(..., layout="radial")`.
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import numpy as np
22
+
23
+ from .image import _layers
24
+
25
+
26
+ def _sink(bn):
27
+ """The root: the one node that feeds nothing (no outgoing edge)."""
28
+ used = {i for n in bn.nodes for i in bn.compiled_cpd(n).inputs}
29
+ sinks = [n for n in bn.nodes if n not in used]
30
+ return sinks[0] if sinks else bn.nodes[-1]
31
+
32
+
33
+ def _radial_angles(bn, target, inputs_of, gap_frac):
34
+ from ..metrics import d_separated_groups
35
+ from ..metrics._util import ancestors
36
+
37
+ groups = d_separated_groups(bn, target)
38
+
39
+ def dfs_leaves(parents):
40
+ seen, order = set(), []
41
+
42
+ def rec(n):
43
+ if n in seen:
44
+ return
45
+ seen.add(n)
46
+ if not inputs_of[n]:
47
+ order.append(n)
48
+ else:
49
+ for i in inputs_of[n]:
50
+ rec(i)
51
+ for p in parents:
52
+ rec(p)
53
+ return order
54
+
55
+ group_leaves = [dfs_leaves(list(g.parents)) for g in groups]
56
+ total = sum(len(gl) for gl in group_leaves) or 1
57
+ ng = max(1, len(groups))
58
+ # gaps separate sectors only with more than one group; a single (fully shared)
59
+ # group gets the whole circle.
60
+ gap = 0.0 if ng <= 1 else gap_frac * 2 * np.pi / ng
61
+ usable = 2 * np.pi - gap * ng
62
+
63
+ leaf_angle: dict = {}
64
+ cur = 0.0
65
+ for gl in group_leaves:
66
+ span = usable * (len(gl) / total)
67
+ for k, leaf in enumerate(gl):
68
+ leaf_angle[leaf] = cur + span * ((k + 0.5) / max(1, len(gl)))
69
+ cur += span + gap
70
+
71
+ ang: dict = {}
72
+ for n in bn.nodes:
73
+ cone = ancestors(bn, n) | {n}
74
+ ls = [c for c in cone if not inputs_of[c] and c in leaf_angle]
75
+ if ls:
76
+ xs = float(np.mean([np.cos(leaf_angle[c]) for c in ls]))
77
+ ys = float(np.mean([np.sin(leaf_angle[c]) for c in ls]))
78
+ ang[n] = float(np.arctan2(ys, xs))
79
+ else:
80
+ ang[n] = 0.0
81
+ return ang
82
+
83
+
84
+ def _relax(pos, R, root_i, ea, eb, sa, sb, mod, k_anchor, *, iters, L0,
85
+ k_rep=2.2, k_spring=0.7, k_sib=0.45):
86
+ """Equilibrium-distance relaxation on the angle. All pairs repel (short range,
87
+ harder across d-separated groups via `mod`); related pairs — argument edges
88
+ (`ea→eb`) and co-input siblings (`sa,sb`) — sit on springs with rest length
89
+ `L0`, so they are pushed apart when closer and pulled together when farther,
90
+ which clusters a branch instead of spreading it maximally. A weak anchor to the
91
+ radial-tree start (`pos0`) keeps a connected blob from collapsing off-centre,
92
+ and the radius is re-snapped to the depth ring each step so all of this acts
93
+ angularly and the root stays at the centre."""
94
+ keep = R > 0
95
+ pos0 = pos.copy()
96
+ for it in range(iters):
97
+ diff = pos[:, None, :] - pos[None, :, :]
98
+ d2 = (diff ** 2).sum(-1) + 1e-3
99
+ rep = k_rep * mod / d2
100
+ np.fill_diagonal(rep, 0.0)
101
+ disp = (rep[..., None] * diff / np.sqrt(d2)[..., None]).sum(1)
102
+ for ia, ib, k in ((ea, eb, k_spring), (sa, sb, k_sib)):
103
+ if len(ia):
104
+ dv = pos[ib] - pos[ia]
105
+ L = np.hypot(dv[:, 0], dv[:, 1]) + 1e-9
106
+ f = (k * (L - L0) / L)[:, None] * dv
107
+ np.add.at(disp, ia, f)
108
+ np.add.at(disp, ib, -f)
109
+ disp += k_anchor * (pos0 - pos)
110
+ dv = disp * (0.1 * (1.0 - 0.8 * it / iters))
111
+ dn = np.hypot(dv[:, 0], dv[:, 1]) + 1e-9
112
+ pos = pos + dv * np.minimum(1.0, 0.5 / dn)[:, None]
113
+ pos[root_i] = [0.0, 0.0]
114
+ r = np.hypot(pos[:, 0], pos[:, 1]) + 1e-9
115
+ pos[keep] = (pos[keep] / r[keep, None]) * R[keep, None]
116
+ return pos
117
+
118
+
119
+ def radial_positions(bn, target=None, *, ring=2.9, relax=True, iters=500,
120
+ gap_frac=0.22, k_anchor=0.18, dsep_rep=1.8):
121
+ """`{node: (x, y)}` for a root-centred radial layout. `target` defaults to the
122
+ network's sink (its root). With `relax`, related nodes settle at an equilibrium
123
+ distance (clustering each branch) while the root stays centred and depths stay
124
+ ringed. Deterministic."""
125
+ from ..metrics import d_separated_groups
126
+
127
+ target = target if target is not None else _sink(bn)
128
+ nodes = list(bn.nodes)
129
+ idx = {n: i for i, n in enumerate(nodes)}
130
+ inputs_of = {n: list(bn.compiled_cpd(n).inputs) for n in nodes}
131
+ edge_dicts = [{"u": inp, "v": n} for n in nodes for inp in inputs_of[n]]
132
+ depth = np.array([float(_layers(nodes, edge_dicts)[n]) for n in nodes])
133
+
134
+ ang = _radial_angles(bn, target, inputs_of, gap_frac)
135
+ R = depth * ring
136
+ pos = np.array([[R[idx[n]] * np.cos(ang[n]), R[idx[n]] * np.sin(ang[n])]
137
+ for n in nodes])
138
+ root_i = idx[target]
139
+ pos[root_i] = [0.0, 0.0]
140
+
141
+ if relax:
142
+ groups = d_separated_groups(bn, target)
143
+ ng = max(1, len(groups))
144
+ # many independent branches balance themselves around the root, so a light
145
+ # anchor lets them cluster; few (a densely shared blob) have nothing to fan,
146
+ # so anchor hard to the radial-tree to keep the root central.
147
+ anchor = k_anchor * (1.0 + 8.0 / ng ** 2)
148
+ grp = np.full(len(nodes), -1)
149
+ for gi, g in enumerate(groups):
150
+ for n in g.nodes:
151
+ if grp[idx[n]] == -1:
152
+ grp[idx[n]] = gi
153
+ same = grp[:, None] == grp[None, :]
154
+ known = (grp[:, None] >= 0) & (grp[None, :] >= 0)
155
+ mod = np.where(known & ~same, dsep_rep, 1.0)
156
+
157
+ ea = np.array([idx[i] for n in nodes for i in inputs_of[n]], dtype=int)
158
+ eb = np.array([idx[n] for n in nodes for _ in inputs_of[n]], dtype=int)
159
+ sibs = {(min(idx[i], idx[j]), max(idx[i], idx[j]))
160
+ for n in nodes for a, i in enumerate(inputs_of[n])
161
+ for j in inputs_of[n][a + 1:]}
162
+ sa = np.array([s[0] for s in sibs], dtype=int)
163
+ sb = np.array([s[1] for s in sibs], dtype=int)
164
+ pos = _relax(pos, R, root_i, ea, eb, sa, sb, mod, anchor,
165
+ iters=iters, L0=ring)
166
+ return {n: (float(pos[idx[n], 0]), float(pos[idx[n], 1])) for n in nodes}
@@ -10,7 +10,7 @@ build-backend = "hatchling.build"
10
10
 
11
11
  [project]
12
12
  name = "probability-flow"
13
- version = "0.1.0"
13
+ version = "0.2.0"
14
14
  description = "A from-scratch, modular discrete Bayesian-network library."
15
15
  readme = "README.md"
16
16
  requires-python = ">=3.11"