tramdag 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. tramdag-0.2.0/.claude/scheduled_tasks.lock +1 -0
  2. tramdag-0.2.0/.claude/settings.local.json +26 -0
  3. tramdag-0.2.0/.gitignore +24 -0
  4. tramdag-0.2.0/CHANGELOG.md +93 -0
  5. tramdag-0.2.0/CLAUDE.md +113 -0
  6. tramdag-0.2.0/LICENSE +21 -0
  7. tramdag-0.2.0/PKG-INFO +206 -0
  8. tramdag-0.2.0/README.md +183 -0
  9. tramdag-0.2.0/data/carefl/obs.csv +5001 -0
  10. tramdag-0.2.0/data/carefl/truth.json +206 -0
  11. tramdag-0.2.0/data/magic-mrclean/README.md +101 -0
  12. tramdag-0.2.0/data/magic-mrclean/fit_ls.R +98 -0
  13. tramdag-0.2.0/data/magic-mrclean/ls/obs.csv +1276 -0
  14. tramdag-0.2.0/data/magic-mrclean/ls/rct.csv +501 -0
  15. tramdag-0.2.0/data/magic-mrclean/ls/ref_ls/ate.csv +2 -0
  16. tramdag-0.2.0/data/magic-mrclean/ls/ref_ls/coefficients.csv +24 -0
  17. tramdag-0.2.0/data/magic-mrclean/ls/truth.json +12 -0
  18. tramdag-0.2.0/data/magic-mrclean/nl/obs.csv +1276 -0
  19. tramdag-0.2.0/data/magic-mrclean/nl/rct.csv +501 -0
  20. tramdag-0.2.0/data/magic-mrclean/nl/ref_ls/ate.csv +2 -0
  21. tramdag-0.2.0/data/magic-mrclean/nl/ref_ls/coefficients.csv +24 -0
  22. tramdag-0.2.0/data/magic-mrclean/nl/truth.json +12 -0
  23. tramdag-0.2.0/data/triangle/atan/obs.csv +5001 -0
  24. tramdag-0.2.0/data/triangle/atan/truth.json +19 -0
  25. tramdag-0.2.0/data/triangle/linear/obs.csv +5001 -0
  26. tramdag-0.2.0/data/triangle/linear/truth.json +21 -0
  27. tramdag-0.2.0/data/triangle/sin/obs.csv +5001 -0
  28. tramdag-0.2.0/data/triangle/sin/truth.json +19 -0
  29. tramdag-0.2.0/data/triangle-mixed/exp/obs.csv +5001 -0
  30. tramdag-0.2.0/data/triangle-mixed/exp/truth.json +30 -0
  31. tramdag-0.2.0/data/triangle-mixed/linear/obs.csv +5001 -0
  32. tramdag-0.2.0/data/triangle-mixed/linear/truth.json +32 -0
  33. tramdag-0.2.0/data/vaca/obs.csv +5001 -0
  34. tramdag-0.2.0/data/vaca/truth.json +35 -0
  35. tramdag-0.2.0/docs/img/nll_vs_time_stroke-ls.png +0 -0
  36. tramdag-0.2.0/docs/img/nll_vs_time_vaca-ci.png +0 -0
  37. tramdag-0.2.0/docs/stroke-case-study.md +111 -0
  38. tramdag-0.2.0/docs/training-speed.md +155 -0
  39. tramdag-0.2.0/experiments/all_ls_flow.py +21 -0
  40. tramdag-0.2.0/experiments/all_ls_long.py +14 -0
  41. tramdag-0.2.0/experiments/bench_training.py +290 -0
  42. tramdag-0.2.0/experiments/common.py +370 -0
  43. tramdag-0.2.0/experiments/counterfactual_demo.py +107 -0
  44. tramdag-0.2.0/experiments/nihss6_flow.py +21 -0
  45. tramdag-0.2.0/experiments/paper_carefl.py +75 -0
  46. tramdag-0.2.0/experiments/paper_common.py +115 -0
  47. tramdag-0.2.0/experiments/paper_triangle.py +87 -0
  48. tramdag-0.2.0/experiments/paper_triangle_mixed.py +109 -0
  49. tramdag-0.2.0/experiments/paper_vaca.py +85 -0
  50. tramdag-0.2.0/experiments/sim_flow.py +27 -0
  51. tramdag-0.2.0/experiments/validate_ls.py +112 -0
  52. tramdag-0.2.0/notebooks/README.md +47 -0
  53. tramdag-0.2.0/notebooks/demo_tram_dag_colab.ipynb +456 -0
  54. tramdag-0.2.0/notebooks/demo_tram_dag_colab.py +326 -0
  55. tramdag-0.2.0/notebooks/intro_tram_dag.py +486 -0
  56. tramdag-0.2.0/pyproject.toml +49 -0
  57. tramdag-0.2.0/src/tramdag/__init__.py +22 -0
  58. tramdag-0.2.0/src/tramdag/conditioners.py +74 -0
  59. tramdag-0.2.0/src/tramdag/flow.py +427 -0
  60. tramdag-0.2.0/src/tramdag/simulations/__init__.py +23 -0
  61. tramdag-0.2.0/src/tramdag/simulations/carefl.py +125 -0
  62. tramdag-0.2.0/src/tramdag/simulations/magic_mrclean.py +259 -0
  63. tramdag-0.2.0/src/tramdag/simulations/triangle.py +252 -0
  64. tramdag-0.2.0/src/tramdag/simulations/vaca.py +130 -0
  65. tramdag-0.2.0/src/tramdag/spec.py +99 -0
  66. tramdag-0.2.0/src/tramdag/transforms.py +271 -0
  67. tramdag-0.2.0/tests/test_fit_schedules.py +112 -0
  68. tramdag-0.2.0/tests/test_flow.py +193 -0
  69. tramdag-0.2.0/tests/test_paper_dgps.py +242 -0
  70. tramdag-0.2.0/tests/test_simulations.py +227 -0
  71. tramdag-0.2.0/uv.lock +1856 -0
@@ -0,0 +1 @@
1
+ {"sessionId":"ea0171fe-94d9-4044-b1c8-81400b3676ad","pid":50879,"procStart":"Thu Jun 11 09:52:35 2026","acquiredAt":1781179094640}
@@ -0,0 +1,26 @@
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(git add *)",
5
+ "Bash(git commit -m ' *)",
6
+ "Bash(git checkout *)",
7
+ "Bash(uv run *)",
8
+ "Bash(uv venv *)",
9
+ "Bash(uv pip *)",
10
+ "Bash(/tmp/zdtest/bin/python -c ' *)",
11
+ "Bash(MPLBACKEND=Agg uv run python notebooks/demo_tram_dag_colab.py)",
12
+ "Bash(uvx jupytext *)",
13
+ "Bash(git check-ignore *)",
14
+ "Bash(git commit -q -m 'fit\\(\\): lr schedules + per-node freezing \\(defaults unchanged\\) *)",
15
+ "Bash(git commit -q -m 'Training-speed benchmark + report *)",
16
+ "Bash(git commit -q -m 'Colab GPU demo \\(bimodal VACA benchmark\\) + README badge *)",
17
+ "Bash(git push *)",
18
+ "Bash(gh auth *)",
19
+ "Bash(gh api *)",
20
+ "Bash(git commit -q -m 'Address review: option guide in report + notebooks README *)",
21
+ "Bash(gh pr *)",
22
+ "Bash(git pull *)",
23
+ "Bash(git branch *)"
24
+ ]
25
+ }
26
+ }
@@ -0,0 +1,24 @@
1
+ # Python / uv
2
+ .venv/
3
+ __pycache__/
4
+ *.py[cod]
5
+ .pytest_cache/
6
+ *.egg-info/
7
+
8
+ # Experiment outputs — regenerable, and (for the clinical 'magic' source) derived
9
+ # from patient data, so never committed. The synthetic data/ folder IS tracked.
10
+ results/
11
+
12
+ # R
13
+ .Rhistory
14
+ .RData
15
+
16
+ # Notebooks: jupytext py:percent files are the source of truth;
17
+ # generated .ipynb (with embedded image outputs) stay out of git.
18
+ notebooks/*.ipynb
19
+ # exception: the Colab demo needs a tracked (output-stripped) ipynb for the
20
+ # "Open in Colab" badge; regenerate with `uvx jupytext --to ipynb <demo>.py`
21
+ !notebooks/demo_tram_dag_colab.ipynb
22
+
23
+ # build artifacts
24
+ dist/
@@ -0,0 +1,93 @@
1
+ # Changelog
2
+
3
+ ## 0.2.0 (2026-06-12)
4
+
5
+ First PyPI release: `pip install tramdag`.
6
+
7
+ ### Changed (naming & packaging)
8
+
9
+ - **Renamed**: Python package `zuko_dag` → **`tramdag`** (conventional alias
10
+ `import tramdag as td`); GitHub repo `tram-dag-zuko` → `tensorchiefs/tramdag`
11
+ (old URLs redirect). The package implements TRAM-DAGs; zuko names the backend.
12
+ No API changes; old checkpoints still load. References to the original
13
+ Keras/TF implementation (tensorchiefs/tram-dag) reworded to avoid
14
+ self-reference.
15
+ - **MIT license** added; PyPI metadata (authors, urls, classifiers); runtime
16
+ dependencies trimmed to `torch`, `zuko`, `numpy`, `pandas` (pytest/scipy/
17
+ statsmodels/scikit-learn/matplotlib moved to the `dev` dependency group).
18
+ - **README rewritten method-first**: the repo is the reference implementation of
19
+ the CLeaR 2025 paper (arXiv:2503.16206); the stroke analysis is the case study
20
+ (arXiv:2606.12623) with its detail moved to `docs/stroke-case-study.md`.
21
+ Citation BibTeX added for both papers.
22
+
23
+ ### Added
24
+
25
+ - **`fit(schedule=..., freeze_patience=...)`** — learning-rate schedules and
26
+ per-node early stopping (defaults unchanged). The optimizer now holds one
27
+ param group per node; `schedule="plateau"` decays each node's lr off its own
28
+ validation NLL, and `freeze_patience` drops converged nodes from the loss
29
+ (real FLOP savings — per-node gradients are independent) with early exit when
30
+ all nodes have frozen. Also supports `"onecycle"`/`"cosine"` schedules. Benchmarks + recommendation in
31
+ `docs/training-speed.md` (`experiments/bench_training.py`): plateau+freeze
32
+ matches the hand-tuned two-phase recipe's time-to-accuracy with **no budget
33
+ tuning and ~3× less total compute**; full-batch LBFGS solves the classical
34
+ all-`ls` MLE in <2 s (2/3 seeds). Existing defaults intentionally untouched.
35
+ - **Colab demo** `notebooks/demo_tram_dag_colab.py` (+ tracked output-stripped
36
+ `.ipynb` for the badge): the paper's bimodal VACA benchmark fitted live
37
+ (cuda/cpu auto-detect), L1 pairs plot, analytic do-checks, per-individual
38
+ counterfactuals vs DGP truth, GPU-vs-CPU race.
39
+
40
+ - **The TRAM-DAG paper's DGPs** (Sick & Dürr, CLeaR 2025, arXiv:2503.16206) as
41
+ simulation registry families, each a numpy-only SCM with known/analytic ground
42
+ truth + frozen n=5000 CSVs (`data/<name>/`, the test contract) and CLIs:
43
+ - `simulations/triangle.py` — `TriangleContinuous` (§6.1: logistic-latent TRAM
44
+ DGP, h₂=5x₂+2x₁, h₃=0.63x₃−0.2x₁−f(x₂)) and `TriangleMixed` (§6.2: ordinal x₃,
45
+ θ=(−2, 0.42, 1.02)); f variants `linear`/`cubic`/`exp`/`atan`/`sin`; supports
46
+ array-valued `do` (C.4 soft interventions).
47
+ - `simulations/vaca.py` — `VacaTriangle` (App. C.1 bimodal Gaussian L1/L2
48
+ benchmark vs CNF).
49
+ - `simulations/carefl.py` — `Carefl4` (App. C.2 Laplace SCM; **analytic**
50
+ counterfactuals via `abduct_noise`/`true_counterfactual`).
51
+ - `experiments/paper_{triangle,triangle_mixed,vaca,carefl}.py` (+ `paper_common.py`)
52
+ — replicate the paper's figures: coefficient trajectories (Fig. 14/15/19), CS-curve
53
+ recovery (Fig. 7), L1/L2 distribution overlays (Fig. 4/5/9/16/20), counterfactual
54
+ curves at the paper's x_obs (Fig. 6), and the C.4 odds-ratio check (OR ≈ 7.4).
55
+ - `tests/test_paper_dgps.py` — generator pinning (KS TRAM-identities, frozen-CSV
56
+ contract, analytic ground truth) + flow recovery (coefficients with the ordinal
57
+ sign-flip, CS curve, VACA do-moments, CAREFL counterfactual MAE).
58
+
59
+ ### Changed (behavior)
60
+
61
+ - **`CausalFlowDAG.fit(..., restore_best=False)` is now the default.** Training keeps
62
+ the **final converged weights** instead of restoring per-node best-validation
63
+ weights. Rationale:
64
+ - *Least surprise* — `fit()` returns the model you trained, not a silently
65
+ swapped earlier epoch.
66
+ - *Exact classical comparison* — an all-`ls` model trained to convergence is now
67
+ exactly the maximum-likelihood (proportional-odds) estimate, matching
68
+ `statsmodels` `OrderedModel` and R `MASS::polr` to ~1e-3 (see
69
+ `experiments/validate_ls.py`, `tests/test_simulations.py::test_all_ls_flow_is_exact_mle`).
70
+ This was **not achievable before**: best-validation restoration pinned the fit
71
+ off the training optimum.
72
+ - Early stopping is now an explicit, opt-in regularization choice.
73
+
74
+ To restore the previous behavior, pass `restore_best=True`.
75
+
76
+ **Note for flexible (`ci`/`cs`) models:** their MLE *overfits the observational
77
+ confounding*, so they need `restore_best=True` to recover the causal effect (lower
78
+ validation NLL confirms it generalizes better). `experiments/run_experiment`
79
+ therefore defaults `restore_best` per style — off for all-`ls`, on for flexible.
80
+
81
+ ### Added (synthetic stroke cohort)
82
+
83
+ - `src/tramdag/simulations/magic_mrclean.py` — synthetic stroke cohort (SCM with
84
+ known ground truth); `ls`/`nl` variants; CLI to (re)generate frozen CSVs.
85
+ - `data/magic-mrclean/` — frozen public CSVs + `fit_ls.R` classical R reference and
86
+ committed `ref_ls/` outputs. The public, reproducible substitute for the private
87
+ clinical data.
88
+ - `experiments/common.py::load_data(source)` — switch between `"magic"` (private) and
89
+ `"magic-mrclean/{ls,nl}"` (synthetic, default).
90
+ - `experiments/sim_flow.py` — known-truth recovery storyline; `validate_ls.py`
91
+ rewritten as a spot-on flow-vs-MLE-vs-R comparison.
92
+ - `tests/test_simulations.py` — generator, known-truth recovery, the all-`ls`
93
+ spot-on MLE check, and the Python-vs-R regression.
@@ -0,0 +1,113 @@
1
+ # CLAUDE.md — working context for tramdag
2
+
3
+ ## What this is
4
+
5
+ A causal normalizing-flow implementation of **TRAM-DAG** (transformation models on a
6
+ DAG) built on [zuko](https://zuko.readthedocs.io/stable/). One triangular flow from iid
7
+ standard-logistic latents to the observed variables; Jacobian sparsity = the DAG.
8
+ Supports the do-operator, Pearl abduction (counterfactuals), analytic interventional
9
+ PMFs, and per-node configurable monotone transforms (Bernstein / RQ-spline / affine).
10
+
11
+ Origin: extracted from the private `tensorchiefs/tram-dag-stroke` paper repo (as
12
+ `zuko_dag`; renamed to `tramdag` in June 2026, repo `tensorchiefs/tramdag`). The paper analyzed the MAGIC stroke cohort against the MR CLEAN RCT;
13
+ that **clinical data is NOT in this repo** and never should be. The synthetic
14
+ `data/magic-mrclean/` cohort is the public stand-in (same schema, known ground truth).
15
+
16
+ ## Commands
17
+
18
+ ```bash
19
+ uv sync # install (uv.lock pinned: zuko, torch, statsmodels, ...)
20
+ uv run pytest tests/ -q # full suite ~11 min; tests/test_flow.py alone ~20 s
21
+ cd experiments
22
+ uv run python sim_flow.py nl # headline storyline (all-ls vs flexible vs known truth)
23
+ uv run python validate_ls.py # spot-on flow == statsmodels == R polr check
24
+ uv run python paper_triangle.py atan cs # TRAM-DAG paper replications (paper_*.py)
25
+ ```
26
+
27
+ Experiments default to the synthetic data (`magic-mrclean/nl`). The `magic` source
28
+ (private clinical data) only works inside the original paper monorepo.
29
+
30
+ ## Architecture (src/tramdag/)
31
+
32
+ - `spec.py` — user-facing DAG spec: `{name: ContinuousNode|OrdinalNode}`, each node
33
+ declares `parents={parent: term}` with term ∈ `ls` (linear shift), `cs` (complex
34
+ shift MLP), `ci` (complex intercept — transform params from parents; multiple ci
35
+ parents feed ONE joint network).
36
+ - `transforms.py` — monotone 1-D transforms wrapping zuko (`BernsteinUT`, `SplineUT`,
37
+ `AffineUT`; pre-scaled from train 5%/95% quantiles to [-5,5], expanding-bracket
38
+ bisection inverse) + the ordinal ordered-logit transform
39
+ (`P(Y<=k) = sigmoid(theta_k - shift)`, cutpoints `[t0, t0+cumsum(exp(...))]`).
40
+ - `conditioners.py` — ls/cs/ci networks (widths replicate the original Keras/TF implementation).
41
+ - `flow.py` — `CausalFlowDAG`: `fit`, `sample(n, do=, u=)`, `abduct`, `pmf`,
42
+ `log_prob`, `save/load`. NLL decomposes per node → one Adam fits all nodes jointly.
43
+ - `simulations/` — numpy-only SCM generators with known ground truth, looked up via
44
+ `REGISTRY`; each module has a CLI that regenerates its frozen `data/<name>/` CSVs:
45
+ `magic_mrclean.py` (stroke SCM, `ls`/`nl`), `triangle.py` (paper §6 continuous +
46
+ ordinal triangles, f variants linear/cubic/exp/atan/sin), `vaca.py` (App. C.1
47
+ bimodal L1/L2 benchmark), `carefl.py` (App. C.2 Laplace SCM, **analytic**
48
+ counterfactuals).
49
+
50
+ ## Conventions that matter (easy to get wrong)
51
+
52
+ - **Latent scale**: continuous `z = h(x) + shift` (shifts ADDED); ordinal
53
+ `P(Y<=k) = sigmoid(theta_k − shift)` (shift SUBTRACTED). Both follow the original TRAM-DAG
54
+ conventions; tests pin them.
55
+ - **Parent encoding**: continuous parents enter RAW (no standardization); ordinal
56
+ parents one-hot (all levels). With cutpoints, only shift *differences* between
57
+ one-hot levels are identified — compare `w[k] − w[0]` against classical references.
58
+ - **Ordinal log-prob is computed in log-space** (`logsigmoid` + stable `log1mexp`,
59
+ better-conditioned side chosen per element). The naive sigmoid difference saturates
60
+ in float32 → *exactly zero* gradients → a node can freeze at init forever. Do not
61
+ "simplify" it back.
62
+ - **Seeding**: weight init happens at construction — call `torch.manual_seed` BEFORE
63
+ `CausalFlowDAG(spec)`, not just in `fit`.
64
+ - **`fit(restore_best=False)` is the default** (keeps final converged weights = exact
65
+ MLE; an all-`ls` model then matches statsmodels/R-polr to ~1e-3). `restore_best=True`
66
+ = per-node best-validation restoration (early stopping). Key empirical finding:
67
+ **flexible (ci/cs) models overfit observational confounding at the MLE and need
68
+ `restore_best=True` to recover the causal effect; all-`ls` models don't.**
69
+ `run_experiment` defaults per style. See CHANGELOG.md.
70
+
71
+ ## Ground truth & reference numbers (seed 7 synthetic data)
72
+
73
+ - `data/magic-mrclean/{ls,nl}/truth.json` — true ATE from the SCM: `ls` +0.132,
74
+ `nl` +0.104; naive confounded contrast +0.26/+0.30.
75
+ - `nl` storyline: all-`ls` flow ≈ +0.076 (biased — can't extrapolate the age-fading
76
+ treatment effect to the younger RCT population), flexible flow ≈ +0.10 (recovers).
77
+ - Spot-on check (`ls` variant, full-data, restore_best=False): flow = statsmodels =
78
+ R polr at Age 0.0526, NIHSSa 0.1630, T −0.9424; ATE +0.1429 vs +0.1428.
79
+ - R reference: `data/magic-mrclean/fit_ls.R` (needs `tram`, `MASS`); its committed
80
+ `ref_ls/` outputs let tests run without R.
81
+ - Original clinical-data numbers (context only, not reproducible here): TRAM-DAG
82
+ nihss6 +0.108, md_dag_ls +0.054, MR CLEAN RCT +0.135 [0.057, 0.213].
83
+ - **Paper DGPs** (seed 42, arXiv:2503.16206): `triangle` true coefficients β12=+2,
84
+ β13=−0.2 (+0.3 on x2 for `linear`); a fitted `cs` learns −f(x2)+const.
85
+ `triangle-mixed` cutpoints θ=(−2, 0.42, 1.02); **ordinal sign flip**: the paper
86
+ ADDS the ordinal shift, the flow SUBTRACTS → fitted weights −0.2 / +0.3; the C.4
87
+ odds-ratio check gives OR ≈ e² ≈ 7.4. `vaca`: E[x3|do(x2=a)] = −0.25 + 0.25a
88
+ (do(x2=−3) is off-manifold extrapolation — looser tolerance). `carefl`:
89
+ counterfactuals are analytic (`Carefl4.true_counterfactual`); the paper's x_obs has
90
+ a ~4σ abducted noise, so tests score 300 typical rows instead of that single point.
91
+
92
+ ## Testing policy
93
+
94
+ - Frozen CSVs in `data/` (`magic-mrclean`, `triangle*`, `vaca`, `carefl`) are a
95
+ contract — **never regenerate silently**; a new seed/equations → new folder
96
+ (sim2-style), regenerate `ref_ls/` with R where applicable, update
97
+ truth-dependent tests. `test_paper_dgps.py::test_frozen_csv_contract` pins the
98
+ paper-DGP CSVs to their generators bit-exactly.
99
+ - Fit tests for the paper DGPs train on **regenerated n=20k** (deterministic
100
+ `observational(n, seed_offset=100)`), not the frozen n=5k — β13 multiplies the
101
+ low-variance x1 ∈ [0.25, 0.73] and is too weakly identified at n=5k.
102
+ - New causal features should be validated against the simulator's known truth
103
+ (`MagicMrClean.true_ate`, `counterfactual_pair` gives true individual
104
+ counterfactuals via shared latents).
105
+
106
+ ## Roadmap notes
107
+
108
+ - ~~Generalize `simulations/` registry beyond the stroke DAG~~ — done for the
109
+ TRAM-DAG paper's DGPs (triangle/triangle-mixed/vaca/carefl, June 2026). Still
110
+ open: hidden confounding à la DeCaFlow.
111
+ - ~~Package for PyPI~~ — published as `tramdag` 0.2.0 (June 2026); release flow:
112
+ bump version in pyproject + `__init__`, `uv build`, `uv publish` (Oliver's
113
+ PyPI token), CHANGELOG section.
tramdag-0.2.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Oliver Dürr, Beate Sick (tensorchiefs)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
tramdag-0.2.0/PKG-INFO ADDED
@@ -0,0 +1,206 @@
1
+ Metadata-Version: 2.4
2
+ Name: tramdag
3
+ Version: 0.2.0
4
+ Summary: Interpretable Neural Causal Models (TRAM-DAGs) in PyTorch: one causal normalizing flow for observational, interventional and counterfactual queries
5
+ Project-URL: Homepage, https://github.com/tensorchiefs/tramdag
6
+ Project-URL: Method paper (CLeaR 2025), https://arxiv.org/abs/2503.16206
7
+ Project-URL: Stroke case study, https://arxiv.org/abs/2606.12623
8
+ Author: Beate Sick
9
+ Author-email: Oliver Dürr <oliver.duerr@gmail.com>
10
+ License-Expression: MIT
11
+ License-File: LICENSE
12
+ Keywords: causal-inference,counterfactuals,interpretability,normalizing-flows,structural-causal-models,transformation-models
13
+ Classifier: Development Status :: 4 - Beta
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Requires-Python: >=3.10
18
+ Requires-Dist: numpy
19
+ Requires-Dist: pandas
20
+ Requires-Dist: torch>=2.0
21
+ Requires-Dist: zuko>=1.3
22
+ Description-Content-Type: text/markdown
23
+
24
+ # tramdag — Interpretable Neural Causal Models (TRAM-DAGs) in PyTorch
25
+
26
+ [![Open the demo in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tensorchiefs/tramdag/blob/main/notebooks/demo_tram_dag_colab.ipynb)
27
+ [![PyPI](https://img.shields.io/pypi/v/tramdag)](https://pypi.org/project/tramdag/)
28
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
29
+
30
+ **TRAM-DAGs** model each variable of a structural causal model with a
31
+ (transformation-model) flow: one triangular normalizing flow from iid
32
+ standard-logistic latents to the observed variables, whose Jacobian sparsity is
33
+ exactly your causal DAG. Fit it **once** on observational data and answer all
34
+ three rungs of Pearl's causal hierarchy — observational (L1), interventional
35
+ (L2, the do-operator), and counterfactual (L3, Pearl abduction) — while keeping
36
+ **interpretable effects**: every linear-shift coefficient is a log-odds ratio,
37
+ exactly as in classical proportional-odds models.
38
+
39
+ > Beate Sick & Oliver Dürr, *Interpretable Neural Causal Models with TRAM-DAGs*,
40
+ > CLeaR 2025 ([arXiv:2503.16206](https://arxiv.org/abs/2503.16206)).
41
+ > This repo is the reference implementation (PyTorch, built on
42
+ > [zuko](https://zuko.readthedocs.io/stable/)); all of the paper's experiments are
43
+ > replicated here with pinned tests.
44
+
45
+ **5-minute showcase**: the Colab badge above fits the paper's bimodal benchmark
46
+ live (GPU-ready) and walks L1 → L2 → L3, every answer checked against analytic
47
+ ground truth. Didactic walkthrough of the model:
48
+ [`notebooks/intro_tram_dag.py`](notebooks/intro_tram_dag.py).
49
+
50
+ ## Install
51
+
52
+ ```bash
53
+ pip install tramdag # PyPI
54
+ uv sync # or: dev setup from a clone (tests, experiments)
55
+ ```
56
+
57
+ ## 30 seconds of API
58
+
59
+ ```python
60
+ import tramdag as td
61
+ from tramdag import CausalFlowDAG, ContinuousNode, OrdinalNode
62
+
63
+ spec = { # the spec IS the labelled DAG
64
+ "Age": ContinuousNode(),
65
+ "mRS_pre": OrdinalNode(levels=6, parents={"Age": "ci"}),
66
+ "NIHSSa": ContinuousNode(parents={"Age": "ci", "mRS_pre": "ls"}),
67
+ "T": OrdinalNode(levels=2,
68
+ parents={"Age": "ci", "mRS_pre": "ls", "NIHSSa": "cs"}),
69
+ "mRS_3m": OrdinalNode(levels=7,
70
+ parents={"Age": "ci", "mRS_pre": "ls",
71
+ "NIHSSa": "cs", "T": "ls"}),
72
+ }
73
+ flow = CausalFlowDAG(spec) # validates acyclicity, builds the flow
74
+
75
+ # self-stopping training: per-node plateau lr decay + freezing of converged
76
+ # nodes (exact, since the per-node NLLs have independent gradients);
77
+ # see docs/training-speed.md for benchmarks and the classic two-phase recipe
78
+ flow.fit(train_df, val_df, epochs=4000, learning_rate=1e-2,
79
+ schedule="plateau", plateau_patience=30, freeze_patience=120)
80
+
81
+ flow.log_prob(df) # L1: joint log-likelihood per row
82
+ flow.sample(1000) # L1: observational sampling
83
+ flow.sample(1000, do={"T": 1}) # L2: interventional (graph mutilation)
84
+ flow.pmf(df, node="mRS_3m", do={"T": 1}) # L2: analytic interventional PMF
85
+
86
+ u = flow.abduct(df) # L3 step 1: latents from observations
87
+ cf = flow.sample(do={"T": 1}, u=u) # L3 steps 2+3: counterfactuals
88
+
89
+ flow.save("flow.pt"); flow = CausalFlowDAG.load("flow.pt")
90
+
91
+ td.simulations.REGISTRY # synthetic DGPs with known ground truth
92
+ ```
93
+
94
+ ## The model in one table
95
+
96
+ Per node, the transformation is additive on the latent (log-odds) scale —
97
+ `u = h(x; θ) + Σ β·x_pa + Σ g(x_pa)` — and each parent edge declares how it enters:
98
+
99
+ | edge term | meaning | interpretability |
100
+ |---|---|---|
101
+ | `ls` | linear shift `β·x_pa` | `exp(β)` is an odds ratio — one number per edge |
102
+ | `cs` | complex shift `g(x_pa)` (MLP), still additive | odds-ratio *function*, plot `g` |
103
+ | `ci` | complex intercept: the transform's parameters depend on the parents (several `ci` parents feed one joint network) | maximal flexibility, interactions |
104
+
105
+ Continuous nodes carry a monotone 1-D transform (`bernstein` — TRAM-faithful
106
+ default, `spline`, `affine`; `ContinuousNode(transform=..., transform_kwargs=...)`);
107
+ ordinal nodes an ordered-logit head `P(x ≤ k) = σ(θ_k − shift)`. Abduction is exact
108
+ for continuous nodes and truncated-logistic for ordinal ones, so
109
+ `flow.sample(u=flow.abduct(df))` reproduces `df` exactly / level-exactly.
110
+
111
+ ## Validation (all pinned by tests)
112
+
113
+ - **Paper replication** — every experiment of the CLeaR paper is a registry
114
+ family (numpy-only SCM + frozen CSVs + replication script):
115
+
116
+ | family | paper | demonstrates |
117
+ |---|---|---|
118
+ | `triangle` (`linear`,`atan`,`sin`) | §6.1 | LS coefficient recovery (β = 2, −0.2, +0.3), CS curve ≡ −f(x₂), non-monotone f |
119
+ | `triangle-mixed` (`linear`,`exp`) | §6.2 | mixed data L1/L2 + the C.4 odds-ratio check (OR ≈ 7.4) |
120
+ | `vaca` | §5.1–5.2 | the bimodal L1 case a default CNF misses; L2 `p(x₃ \| do(x₂))` |
121
+ | `carefl` | §5.3 | L3 counterfactual curves vs **analytic** truth |
122
+
123
+ ```bash
124
+ cd experiments && uv run python paper_triangle.py atan cs # etc., see paper_*.py
125
+ ```
126
+
127
+ Sign note: ordinal shifts are *subtracted* here but *added* in the paper, so
128
+ fitted ordinal weights are the paper's with flipped sign (`truth.json` records
129
+ both conventions per family).
130
+
131
+ - **Exact classical equivalence** — an all-`ls` flow trained to convergence *is*
132
+ the proportional-odds MLE: coefficients match `statsmodels` **and** R
133
+ `MASS::polr` to ~4 decimals (`experiments/validate_ls.py`, R reference committed
134
+ under `data/magic-mrclean/*/ref_ls/`).
135
+
136
+ - **Training speed** — schedules, per-node freezing, LBFGS and device benchmarks:
137
+ [`docs/training-speed.md`](docs/training-speed.md).
138
+
139
+ ## Case study: individualized treatment effects in stroke
140
+
141
+ The method's flagship application estimates individualized thrombectomy effects
142
+ from the observational MAGIC cohort with external validation against the
143
+ MR CLEAN trial:
144
+
145
+ > Dürr, Herzog, Bühler, Wegener & Sick, *Estimating Individualized Treatment
146
+ > Effects in Acute Ischemic Stroke with Causal Transformation Models (TRAM-DAG)*
147
+ > ([arXiv:2606.12623](https://arxiv.org/abs/2606.12623)).
148
+
149
+ The clinical data is private and **never** part of this repo. Its public
150
+ stand-in is `data/magic-mrclean/` — a fully synthetic cohort with the same
151
+ schema and **known ground truth** (true ATE, true individual counterfactuals),
152
+ including an `nl` variant where an all-`ls` model is provably misspecified:
153
+
154
+ | `nl` variant | ATE | vs true **+0.104** |
155
+ |---|---|---|
156
+ | naive observational contrast | +0.303 | confounded (overstates 2.9×) |
157
+ | all-`ls` flow | +0.076 | undershoots (misses the age-varying effect) |
158
+ | flexible (`ci`/`cs`) flow | +0.101 | **recovers the truth** |
159
+
160
+ Full storyline, clinical-data context, R cross-check and reading notes:
161
+ [`docs/stroke-case-study.md`](docs/stroke-case-study.md).
162
+
163
+ ## Layout
164
+
165
+ ```
166
+ src/tramdag/ spec.py transforms.py conditioners.py flow.py
167
+ simulations/ (magic_mrclean, triangle, vaca, carefl + CLIs)
168
+ data/ frozen synthetic CSVs + truth.json — a test contract
169
+ experiments/ stroke pipeline, paper replications, training benchmark
170
+ notebooks/ intro (didactic) + Colab demo (jupytext .py — see README there)
171
+ tests/ 66 tests: unit, known-truth recovery, R regression
172
+ docs/ training-speed.md, stroke-case-study.md
173
+ ```
174
+
175
+ Implementation conventions (latent-scale signs, raw/one-hot parent encoding,
176
+ log-space ordinal likelihood, seeding) are documented in
177
+ [`CLAUDE.md`](CLAUDE.md) and pinned by tests.
178
+
179
+ ## Citation
180
+
181
+ If you use `tramdag`, please cite the method paper:
182
+
183
+ ```bibtex
184
+ @inproceedings{sick2025tramdag,
185
+ title = {Interpretable Neural Causal Models with TRAM-DAGs},
186
+ author = {Sick, Beate and D{\"u}rr, Oliver},
187
+ booktitle = {Proceedings of the 4th Conference on Causal Learning and Reasoning (CLeaR)},
188
+ series = {Proceedings of Machine Learning Research},
189
+ volume = {275},
190
+ year = {2025},
191
+ }
192
+ ```
193
+
194
+ For the stroke application (and the `magic-mrclean` cohort design) additionally:
195
+
196
+ ```bibtex
197
+ @article{duerr2026stroke,
198
+ title = {Estimating Individualized Treatment Effects in Acute Ischemic Stroke
199
+ with Causal Transformation Models (TRAM-DAG): A Multi-Centre
200
+ Observational Study with External RCT Validation},
201
+ author = {D{\"u}rr, Oliver and Herzog, Lisa and B{\"u}hler, Pascal and
202
+ Wegener, Susanne and Sick, Beate},
203
+ journal = {arXiv preprint arXiv:2606.12623},
204
+ year = {2026},
205
+ }
206
+ ```