tramdag 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tramdag-0.2.0/.claude/scheduled_tasks.lock +1 -0
- tramdag-0.2.0/.claude/settings.local.json +26 -0
- tramdag-0.2.0/.gitignore +24 -0
- tramdag-0.2.0/CHANGELOG.md +93 -0
- tramdag-0.2.0/CLAUDE.md +113 -0
- tramdag-0.2.0/LICENSE +21 -0
- tramdag-0.2.0/PKG-INFO +206 -0
- tramdag-0.2.0/README.md +183 -0
- tramdag-0.2.0/data/carefl/obs.csv +5001 -0
- tramdag-0.2.0/data/carefl/truth.json +206 -0
- tramdag-0.2.0/data/magic-mrclean/README.md +101 -0
- tramdag-0.2.0/data/magic-mrclean/fit_ls.R +98 -0
- tramdag-0.2.0/data/magic-mrclean/ls/obs.csv +1276 -0
- tramdag-0.2.0/data/magic-mrclean/ls/rct.csv +501 -0
- tramdag-0.2.0/data/magic-mrclean/ls/ref_ls/ate.csv +2 -0
- tramdag-0.2.0/data/magic-mrclean/ls/ref_ls/coefficients.csv +24 -0
- tramdag-0.2.0/data/magic-mrclean/ls/truth.json +12 -0
- tramdag-0.2.0/data/magic-mrclean/nl/obs.csv +1276 -0
- tramdag-0.2.0/data/magic-mrclean/nl/rct.csv +501 -0
- tramdag-0.2.0/data/magic-mrclean/nl/ref_ls/ate.csv +2 -0
- tramdag-0.2.0/data/magic-mrclean/nl/ref_ls/coefficients.csv +24 -0
- tramdag-0.2.0/data/magic-mrclean/nl/truth.json +12 -0
- tramdag-0.2.0/data/triangle/atan/obs.csv +5001 -0
- tramdag-0.2.0/data/triangle/atan/truth.json +19 -0
- tramdag-0.2.0/data/triangle/linear/obs.csv +5001 -0
- tramdag-0.2.0/data/triangle/linear/truth.json +21 -0
- tramdag-0.2.0/data/triangle/sin/obs.csv +5001 -0
- tramdag-0.2.0/data/triangle/sin/truth.json +19 -0
- tramdag-0.2.0/data/triangle-mixed/exp/obs.csv +5001 -0
- tramdag-0.2.0/data/triangle-mixed/exp/truth.json +30 -0
- tramdag-0.2.0/data/triangle-mixed/linear/obs.csv +5001 -0
- tramdag-0.2.0/data/triangle-mixed/linear/truth.json +32 -0
- tramdag-0.2.0/data/vaca/obs.csv +5001 -0
- tramdag-0.2.0/data/vaca/truth.json +35 -0
- tramdag-0.2.0/docs/img/nll_vs_time_stroke-ls.png +0 -0
- tramdag-0.2.0/docs/img/nll_vs_time_vaca-ci.png +0 -0
- tramdag-0.2.0/docs/stroke-case-study.md +111 -0
- tramdag-0.2.0/docs/training-speed.md +155 -0
- tramdag-0.2.0/experiments/all_ls_flow.py +21 -0
- tramdag-0.2.0/experiments/all_ls_long.py +14 -0
- tramdag-0.2.0/experiments/bench_training.py +290 -0
- tramdag-0.2.0/experiments/common.py +370 -0
- tramdag-0.2.0/experiments/counterfactual_demo.py +107 -0
- tramdag-0.2.0/experiments/nihss6_flow.py +21 -0
- tramdag-0.2.0/experiments/paper_carefl.py +75 -0
- tramdag-0.2.0/experiments/paper_common.py +115 -0
- tramdag-0.2.0/experiments/paper_triangle.py +87 -0
- tramdag-0.2.0/experiments/paper_triangle_mixed.py +109 -0
- tramdag-0.2.0/experiments/paper_vaca.py +85 -0
- tramdag-0.2.0/experiments/sim_flow.py +27 -0
- tramdag-0.2.0/experiments/validate_ls.py +112 -0
- tramdag-0.2.0/notebooks/README.md +47 -0
- tramdag-0.2.0/notebooks/demo_tram_dag_colab.ipynb +456 -0
- tramdag-0.2.0/notebooks/demo_tram_dag_colab.py +326 -0
- tramdag-0.2.0/notebooks/intro_tram_dag.py +486 -0
- tramdag-0.2.0/pyproject.toml +49 -0
- tramdag-0.2.0/src/tramdag/__init__.py +22 -0
- tramdag-0.2.0/src/tramdag/conditioners.py +74 -0
- tramdag-0.2.0/src/tramdag/flow.py +427 -0
- tramdag-0.2.0/src/tramdag/simulations/__init__.py +23 -0
- tramdag-0.2.0/src/tramdag/simulations/carefl.py +125 -0
- tramdag-0.2.0/src/tramdag/simulations/magic_mrclean.py +259 -0
- tramdag-0.2.0/src/tramdag/simulations/triangle.py +252 -0
- tramdag-0.2.0/src/tramdag/simulations/vaca.py +130 -0
- tramdag-0.2.0/src/tramdag/spec.py +99 -0
- tramdag-0.2.0/src/tramdag/transforms.py +271 -0
- tramdag-0.2.0/tests/test_fit_schedules.py +112 -0
- tramdag-0.2.0/tests/test_flow.py +193 -0
- tramdag-0.2.0/tests/test_paper_dgps.py +242 -0
- tramdag-0.2.0/tests/test_simulations.py +227 -0
- tramdag-0.2.0/uv.lock +1856 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"sessionId":"ea0171fe-94d9-4044-b1c8-81400b3676ad","pid":50879,"procStart":"Thu Jun 11 09:52:35 2026","acquiredAt":1781179094640}
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
{
|
|
2
|
+
"permissions": {
|
|
3
|
+
"allow": [
|
|
4
|
+
"Bash(git add *)",
|
|
5
|
+
"Bash(git commit -m ' *)",
|
|
6
|
+
"Bash(git checkout *)",
|
|
7
|
+
"Bash(uv run *)",
|
|
8
|
+
"Bash(uv venv *)",
|
|
9
|
+
"Bash(uv pip *)",
|
|
10
|
+
"Bash(/tmp/zdtest/bin/python -c ' *)",
|
|
11
|
+
"Bash(MPLBACKEND=Agg uv run python notebooks/demo_tram_dag_colab.py)",
|
|
12
|
+
"Bash(uvx jupytext *)",
|
|
13
|
+
"Bash(git check-ignore *)",
|
|
14
|
+
"Bash(git commit -q -m 'fit\\(\\): lr schedules + per-node freezing \\(defaults unchanged\\) *)",
|
|
15
|
+
"Bash(git commit -q -m 'Training-speed benchmark + report *)",
|
|
16
|
+
"Bash(git commit -q -m 'Colab GPU demo \\(bimodal VACA benchmark\\) + README badge *)",
|
|
17
|
+
"Bash(git push *)",
|
|
18
|
+
"Bash(gh auth *)",
|
|
19
|
+
"Bash(gh api *)",
|
|
20
|
+
"Bash(git commit -q -m 'Address review: option guide in report + notebooks README *)",
|
|
21
|
+
"Bash(gh pr *)",
|
|
22
|
+
"Bash(git pull *)",
|
|
23
|
+
"Bash(git branch *)"
|
|
24
|
+
]
|
|
25
|
+
}
|
|
26
|
+
}
|
tramdag-0.2.0/.gitignore
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Python / uv
|
|
2
|
+
.venv/
|
|
3
|
+
__pycache__/
|
|
4
|
+
*.py[cod]
|
|
5
|
+
.pytest_cache/
|
|
6
|
+
*.egg-info/
|
|
7
|
+
|
|
8
|
+
# Experiment outputs — regenerable, and (for the clinical 'magic' source) derived
|
|
9
|
+
# from patient data, so never committed. The synthetic data/ folder IS tracked.
|
|
10
|
+
results/
|
|
11
|
+
|
|
12
|
+
# R
|
|
13
|
+
.Rhistory
|
|
14
|
+
.RData
|
|
15
|
+
|
|
16
|
+
# Notebooks: jupytext py:percent files are the source of truth;
|
|
17
|
+
# generated .ipynb (with embedded image outputs) stay out of git.
|
|
18
|
+
notebooks/*.ipynb
|
|
19
|
+
# exception: the Colab demo needs a tracked (output-stripped) ipynb for the
|
|
20
|
+
# "Open in Colab" badge; regenerate with `uvx jupytext --to ipynb <demo>.py`
|
|
21
|
+
!notebooks/demo_tram_dag_colab.ipynb
|
|
22
|
+
|
|
23
|
+
# build artifacts
|
|
24
|
+
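dist/

# local Claude Code session state (machine-specific; keep out of git and sdists)
.claude/settings.local.json
.claude/scheduled_tasks.lock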
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
## 0.2.0 (2026-06-12)
|
|
4
|
+
|
|
5
|
+
First PyPI release: `pip install tramdag`.
|
|
6
|
+
|
|
7
|
+
### Changed (naming & packaging)
|
|
8
|
+
|
|
9
|
+
- **Renamed**: Python package `zuko_dag` → **`tramdag`** (conventional alias
|
|
10
|
+
`import tramdag as td`); GitHub repo `tram-dag-zuko` → `tensorchiefs/tramdag`
|
|
11
|
+
(old URLs redirect). The new name reflects the method the package implements (TRAM-DAGs); zuko is only the backend.
|
|
12
|
+
No API changes; old checkpoints still load. References to the original
|
|
13
|
+
Keras/TF implementation (tensorchiefs/tram-dag) reworded to avoid
|
|
14
|
+
self-reference.
|
|
15
|
+
- **MIT license** added; PyPI metadata (authors, urls, classifiers); runtime
|
|
16
|
+
dependencies trimmed to `torch`, `zuko`, `numpy`, `pandas` (pytest/scipy/
|
|
17
|
+
statsmodels/scikit-learn/matplotlib moved to the `dev` dependency group).
|
|
18
|
+
- **README rewritten method-first**: the repo is the reference implementation of
|
|
19
|
+
the CLeaR 2025 paper (arXiv:2503.16206); the stroke analysis is the case study
|
|
20
|
+
(arXiv:2606.12623) with its detail moved to `docs/stroke-case-study.md`.
|
|
21
|
+
Citation BibTeX added for both papers.
|
|
22
|
+
|
|
23
|
+
### Added
|
|
24
|
+
|
|
25
|
+
- **`fit(schedule=..., freeze_patience=...)`** — learning-rate schedules and
|
|
26
|
+
per-node early stopping (defaults unchanged). The optimizer now holds one
|
|
27
|
+
param group per node; `schedule="plateau"` decays each node's lr off its own
|
|
28
|
+
validation NLL, and `freeze_patience` drops converged nodes from the loss
|
|
29
|
+
(real FLOP savings — per-node gradients are independent), with an early exit when
|
|
30
|
+
all nodes have frozen. Also `"onecycle"`/`"cosine"` schedules. Benchmarks + recommendation in
|
|
31
|
+
`docs/training-speed.md` (`experiments/bench_training.py`): plateau+freeze
|
|
32
|
+
matches the hand-tuned two-phase recipe's time-to-accuracy with **no budget
|
|
33
|
+
tuning and ~3× less total compute**; full-batch LBFGS solves the classical
|
|
34
|
+
all-`ls` MLE in <2 s (in 2 of 3 seeds). Existing defaults intentionally untouched.
|
|
35
|
+
- **Colab demo** `notebooks/demo_tram_dag_colab.py` (+ tracked output-stripped
|
|
36
|
+
`.ipynb` for the badge): the paper's bimodal VACA benchmark fitted live
|
|
37
|
+
(cuda/cpu auto-detect), L1 pairs plot, analytic do-checks, per-individual
|
|
38
|
+
counterfactuals vs DGP truth, GPU-vs-CPU race.
|
|
39
|
+
|
|
40
|
+
- **The TRAM-DAG paper's DGPs** (Sick & Dürr, CLeaR 2025, arXiv:2503.16206) as
|
|
41
|
+
simulation registry families, each a numpy-only SCM with known/analytic ground
|
|
42
|
+
truth + frozen n=5000 CSVs (`data/<name>/`, the test contract) and CLIs:
|
|
43
|
+
- `simulations/triangle.py` — `TriangleContinuous` (§6.1: logistic-latent TRAM
|
|
44
|
+
DGP, h₂=5x₂+2x₁, h₃=0.63x₃−0.2x₁−f(x₂)) and `TriangleMixed` (§6.2: ordinal x₃,
|
|
45
|
+
θ=(−2, 0.42, 1.02)); f variants `linear`/`cubic`/`exp`/`atan`/`sin`; supports
|
|
46
|
+
array-valued `do` (C.4 soft interventions).
|
|
47
|
+
- `simulations/vaca.py` — `VacaTriangle` (App. C.1 bimodal Gaussian L1/L2
|
|
48
|
+
benchmark vs CNF).
|
|
49
|
+
- `simulations/carefl.py` — `Carefl4` (App. C.2 Laplace SCM; **analytic**
|
|
50
|
+
counterfactuals via `abduct_noise`/`true_counterfactual`).
|
|
51
|
+
- `experiments/paper_{triangle,triangle_mixed,vaca,carefl}.py` (+ `paper_common.py`)
|
|
52
|
+
— replicate the paper's figures: coefficient trajectories (Fig. 14/15/19), CS-curve
|
|
53
|
+
recovery (Fig. 7), L1/L2 distribution overlays (Fig. 4/5/9/16/20), counterfactual
|
|
54
|
+
curves at the paper's x_obs (Fig. 6), and the C.4 odds-ratio check (OR ≈ 7.4).
|
|
55
|
+
- `tests/test_paper_dgps.py` — generator pinning (KS TRAM-identities, frozen-CSV
|
|
56
|
+
contract, analytic ground truth) + flow recovery (coefficients with the ordinal
|
|
57
|
+
sign-flip, CS curve, VACA do-moments, CAREFL counterfactual MAE).
|
|
58
|
+
|
|
59
|
+
### Changed (behavior)
|
|
60
|
+
|
|
61
|
+
- **`CausalFlowDAG.fit(..., restore_best=False)` is now the default.** Training keeps
|
|
62
|
+
the **final converged weights** instead of restoring per-node best-validation
|
|
63
|
+
weights. Rationale:
|
|
64
|
+
- *Least surprise* — `fit()` returns the model you trained, not a silently
|
|
65
|
+
swapped earlier epoch.
|
|
66
|
+
- *Exact classical comparison* — an all-`ls` model trained to convergence is now
|
|
67
|
+
exactly the maximum-likelihood (proportional-odds) estimate, matching
|
|
68
|
+
`statsmodels` `OrderedModel` and R `MASS::polr` to ~1e-3 (see
|
|
69
|
+
`experiments/validate_ls.py`, `tests/test_simulations.py::test_all_ls_flow_is_exact_mle`).
|
|
70
|
+
This was **not achievable before**: best-validation restoration pinned the fit
|
|
71
|
+
off the training optimum.
|
|
72
|
+
- Early stopping is now an explicit, opt-in regularization choice.
|
|
73
|
+
|
|
74
|
+
To restore the previous behavior, pass `restore_best=True`.
|
|
75
|
+
|
|
76
|
+
**Note for flexible (`ci`/`cs`) models:** their MLE *overfits the observational
|
|
77
|
+
confounding*, so they need `restore_best=True` to recover the causal effect (lower
|
|
78
|
+
validation NLL confirms it generalizes better). The experiments' `run_experiment` helper
|
|
79
|
+
therefore defaults `restore_best` per style — off for all-`ls`, on for flexible.
|
|
80
|
+
|
|
81
|
+
### Added (synthetic stroke cohort & classical references)
|
|
82
|
+
|
|
83
|
+
- `src/tramdag/simulations/magic_mrclean.py` — synthetic stroke cohort (SCM with
|
|
84
|
+
known ground truth); `ls`/`nl` variants; CLI to (re)generate frozen CSVs.
|
|
85
|
+
- `data/magic-mrclean/` — frozen public CSVs + `fit_ls.R` classical R reference and
|
|
86
|
+
committed `ref_ls/` outputs. The public, reproducible substitute for the private
|
|
87
|
+
clinical data.
|
|
88
|
+
- `experiments/common.py::load_data(source)` — switch between `"magic"` (private) and
|
|
89
|
+
`"magic-mrclean/{ls,nl}"` (synthetic, default).
|
|
90
|
+
- `experiments/sim_flow.py` — known-truth recovery storyline; `validate_ls.py`
|
|
91
|
+
rewritten as a spot-on flow-vs-MLE-vs-R comparison.
|
|
92
|
+
- `tests/test_simulations.py` — generator, known-truth recovery, the all-`ls`
|
|
93
|
+
spot-on MLE check, and the Python-vs-R regression.
|
tramdag-0.2.0/CLAUDE.md
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
# CLAUDE.md — working context for tramdag
|
|
2
|
+
|
|
3
|
+
## What this is
|
|
4
|
+
|
|
5
|
+
A causal normalizing-flow implementation of **TRAM-DAG** (transformation models on a
|
|
6
|
+
DAG) built on [zuko](https://zuko.readthedocs.io/stable/). One triangular flow from iid
|
|
7
|
+
standard-logistic latents to the observed variables; Jacobian sparsity = the DAG.
|
|
8
|
+
Supports the do-operator, Pearl abduction (counterfactuals), analytic interventional
|
|
9
|
+
PMFs, and per-node configurable monotone transforms (Bernstein / RQ-spline / affine).
|
|
10
|
+
|
|
11
|
+
Origin: extracted from the private `tensorchiefs/tram-dag-stroke` paper repo (as
|
|
12
|
+
`zuko_dag`; renamed to `tramdag` in June 2026, repo `tensorchiefs/tramdag`). The paper analyzed the MAGIC stroke cohort against the MR CLEAN RCT;
|
|
13
|
+
that **clinical data is NOT in this repo** and never should be. The synthetic
|
|
14
|
+
`data/magic-mrclean/` cohort is the public stand-in (same schema, known ground truth).
|
|
15
|
+
|
|
16
|
+
## Commands
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
uv sync # install (uv.lock pinned: zuko, torch, statsmodels, ...)
|
|
20
|
+
uv run pytest tests/ -q # full suite ~11 min; tests/test_flow.py alone ~20 s
|
|
21
|
+
cd experiments
|
|
22
|
+
uv run python sim_flow.py nl # headline storyline (all-ls vs flexible vs known truth)
|
|
23
|
+
uv run python validate_ls.py # spot-on flow == statsmodels == R polr check
|
|
24
|
+
uv run python paper_triangle.py atan cs # TRAM-DAG paper replications (paper_*.py)
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
Experiments default to the synthetic data (`magic-mrclean/nl`). The `magic` source
|
|
28
|
+
(private clinical data) only works inside the original paper monorepo.
|
|
29
|
+
|
|
30
|
+
## Architecture (src/tramdag/)
|
|
31
|
+
|
|
32
|
+
- `spec.py` — user-facing DAG spec: `{name: ContinuousNode|OrdinalNode}`, each node
|
|
33
|
+
declares `parents={parent: term}` with term ∈ `ls` (linear shift), `cs` (complex
|
|
34
|
+
shift MLP), `ci` (complex intercept — transform params from parents; multiple ci
|
|
35
|
+
parents feed ONE joint network).
|
|
36
|
+
- `transforms.py` — monotone 1-D transforms wrapping zuko (`BernsteinUT`, `SplineUT`,
|
|
37
|
+
`AffineUT`; pre-scaled from train 5%/95% quantiles to [-5,5], expanding-bracket
|
|
38
|
+
bisection inverse) + the ordinal ordered-logit transform
|
|
39
|
+
(`P(Y<=k) = sigmoid(theta_k - shift)`, cutpoints `[t0, t0+cumsum(exp(...))]`).
|
|
40
|
+
- `conditioners.py` — ls/cs/ci networks (widths replicate the original Keras/TF implementation).
|
|
41
|
+
- `flow.py` — `CausalFlowDAG`: `fit`, `sample(n, do=, u=)`, `abduct`, `pmf`,
|
|
42
|
+
`log_prob`, `save/load`. NLL decomposes per node → one Adam fits all nodes jointly.
|
|
43
|
+
- `simulations/` — numpy-only SCM generators with known ground truth, looked up via
|
|
44
|
+
`REGISTRY`; each module has a CLI that regenerates its frozen `data/<name>/` CSVs:
|
|
45
|
+
`magic_mrclean.py` (stroke SCM, `ls`/`nl`), `triangle.py` (paper §6 continuous +
|
|
46
|
+
ordinal triangles, f variants linear/cubic/exp/atan/sin), `vaca.py` (App. C.1
|
|
47
|
+
bimodal L1/L2 benchmark), `carefl.py` (App. C.2 Laplace SCM, **analytic**
|
|
48
|
+
counterfactuals).
|
|
49
|
+
|
|
50
|
+
## Conventions that matter (easy to get wrong)
|
|
51
|
+
|
|
52
|
+
- **Latent scale**: continuous `z = h(x) + shift` (shifts ADDED); ordinal
|
|
53
|
+
`P(Y<=k) = sigmoid(theta_k − shift)` (shift SUBTRACTED). Both follow the original TRAM-DAG
|
|
54
|
+
conventions; tests pin them.
|
|
55
|
+
- **Parent encoding**: continuous parents enter RAW (no standardization); ordinal
|
|
56
|
+
parents one-hot (all levels). With cutpoints, only shift *differences* between
|
|
57
|
+
one-hot levels are identified — compare `w[k] − w[0]` against classical references.
|
|
58
|
+
- **Ordinal log-prob is computed in log-space** (`logsigmoid` + stable `log1mexp`,
|
|
59
|
+
better-conditioned side chosen per element). The naive sigmoid difference saturates
|
|
60
|
+
in float32 → *exactly zero* gradients → a node can freeze at init forever. Do not
|
|
61
|
+
"simplify" it back.
|
|
62
|
+
- **Seeding**: weight init happens at construction — call `torch.manual_seed` BEFORE
|
|
63
|
+
`CausalFlowDAG(spec)`, not just in `fit`.
|
|
64
|
+
- **`fit(restore_best=False)` is the default** (keeps final converged weights = exact
|
|
65
|
+
MLE; an all-`ls` model then matches statsmodels/R-polr to ~1e-3). `restore_best=True`
|
|
66
|
+
= per-node best-validation restoration (early stopping). Key empirical finding:
|
|
67
|
+
**flexible (ci/cs) models overfit observational confounding at the MLE and need
|
|
68
|
+
`restore_best=True` to recover the causal effect; all-`ls` models don't.**
|
|
69
|
+
`run_experiment` defaults per style. See CHANGELOG.md.
|
|
70
|
+
|
|
71
|
+
## Ground truth & reference numbers (seed 7 synthetic data)
|
|
72
|
+
|
|
73
|
+
- `data/magic-mrclean/{ls,nl}/truth.json` — true ATE from the SCM: `ls` +0.132,
|
|
74
|
+
`nl` +0.104; naive confounded contrast +0.26/+0.30.
|
|
75
|
+
- `nl` storyline: all-`ls` flow ≈ +0.076 (biased — can't extrapolate the age-fading
|
|
76
|
+
treatment effect to the younger RCT population), flexible flow ≈ +0.10 (recovers).
|
|
77
|
+
- Spot-on check (`ls` variant, full-data, restore_best=False): flow = statsmodels =
|
|
78
|
+
R polr at Age 0.0526, NIHSSa 0.1630, T −0.9424; ATE +0.1429 vs +0.1428.
|
|
79
|
+
- R reference: `data/magic-mrclean/fit_ls.R` (needs `tram`, `MASS`); its committed
|
|
80
|
+
`ref_ls/` outputs let tests run without R.
|
|
81
|
+
- Original clinical-data numbers (context only, not reproducible here): TRAM-DAG
|
|
82
|
+
nihss6 +0.108, md_dag_ls +0.054, MR CLEAN RCT +0.135 [0.057, 0.213].
|
|
83
|
+
- **Paper DGPs** (seed 42, arXiv:2503.16206): `triangle` true coefficients β12=+2,
|
|
84
|
+
β13=−0.2 (+0.3 on x2 for `linear`); a fitted `cs` learns −f(x2)+const.
|
|
85
|
+
`triangle-mixed` cutpoints θ=(−2, 0.42, 1.02); **ordinal sign flip**: the paper
|
|
86
|
+
ADDS the ordinal shift, the flow SUBTRACTS → fitted weights −0.2 / +0.3; the C.4
|
|
87
|
+
odds-ratio check gives OR ≈ e² ≈ 7.4. `vaca`: E[x3|do(x2=a)] = −0.25 + 0.25a
|
|
88
|
+
(do(x2=−3) is off-manifold extrapolation — looser tolerance). `carefl`:
|
|
89
|
+
counterfactuals are analytic (`Carefl4.true_counterfactual`); the paper's x_obs has
|
|
90
|
+
a ~4σ abducted noise, so tests score 300 typical rows instead of that single point.
|
|
91
|
+
|
|
92
|
+
## Testing policy
|
|
93
|
+
|
|
94
|
+
- Frozen CSVs in `data/` (`magic-mrclean`, `triangle*`, `vaca`, `carefl`) are a
|
|
95
|
+
contract — **never regenerate silently**; a new seed/equations → new folder
|
|
96
|
+
(sim2-style), regenerate `ref_ls/` with R where applicable, update
|
|
97
|
+
truth-dependent tests. `test_paper_dgps.py::test_frozen_csv_contract` pins the
|
|
98
|
+
paper-DGP CSVs to their generators bit-exactly.
|
|
99
|
+
- Fit tests for the paper DGPs train on **regenerated n=20k** (deterministic
|
|
100
|
+
`observational(n, seed_offset=100)`), not the frozen n=5k — β13 multiplies the
|
|
101
|
+
low-variance x1 ∈ [0.25, 0.73] and is too weakly identified at n=5k.
|
|
102
|
+
- New causal features should be validated against the simulator's known truth
|
|
103
|
+
(`MagicMrClean.true_ate`, `counterfactual_pair` gives true individual
|
|
104
|
+
counterfactuals via shared latents).
|
|
105
|
+
|
|
106
|
+
## Roadmap notes
|
|
107
|
+
|
|
108
|
+
- ~~Generalize `simulations/` registry beyond the stroke DAG~~ — done for the
|
|
109
|
+
TRAM-DAG paper's DGPs (triangle/triangle-mixed/vaca/carefl, June 2026). Still
|
|
110
|
+
open: hidden confounding à la DeCaFlow.
|
|
111
|
+
- ~~Package for PyPI~~ — published as `tramdag` 0.2.0 (June 2026); release flow:
|
|
112
|
+
bump version in pyproject + `__init__`, `uv build`, `uv publish` (Oliver's
|
|
113
|
+
PyPI token), CHANGELOG section.
|
tramdag-0.2.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Oliver Dürr, Beate Sick (tensorchiefs)
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
tramdag-0.2.0/PKG-INFO
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: tramdag
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Interpretable Neural Causal Models (TRAM-DAGs) in PyTorch: one causal normalizing flow for observational, interventional and counterfactual queries
|
|
5
|
+
Project-URL: Homepage, https://github.com/tensorchiefs/tramdag
|
|
6
|
+
Project-URL: Method paper (CLeaR 2025), https://arxiv.org/abs/2503.16206
|
|
7
|
+
Project-URL: Stroke case study, https://arxiv.org/abs/2606.12623
|
|
8
|
+
Author: Beate Sick
|
|
9
|
+
Author-email: Oliver Dürr <oliver.duerr@gmail.com>
|
|
10
|
+
License-Expression: MIT
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Keywords: causal-inference,counterfactuals,interpretability,normalizing-flows,structural-causal-models,transformation-models
|
|
13
|
+
Classifier: Development Status :: 4 - Beta
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Requires-Python: >=3.10
|
|
18
|
+
Requires-Dist: numpy
|
|
19
|
+
Requires-Dist: pandas
|
|
20
|
+
Requires-Dist: torch>=2.0
|
|
21
|
+
Requires-Dist: zuko>=1.3
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
|
|
24
|
+
# tramdag — Interpretable Neural Causal Models (TRAM-DAGs) in PyTorch
|
|
25
|
+
|
|
26
|
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tensorchiefs/tramdag/blob/main/notebooks/demo_tram_dag_colab.ipynb)
|
|
27
|
+
[![PyPI](https://img.shields.io/pypi/v/tramdag.svg)](https://pypi.org/project/tramdag/)
|
|
28
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
|
|
29
|
+
|
|
30
|
+
**TRAM-DAGs** model each variable of a structural causal model with a
|
|
31
|
+
transformation model; together they form one triangular normalizing flow from iid
|
|
32
|
+
standard-logistic latents to the observed variables, whose Jacobian sparsity is
|
|
33
|
+
exactly your causal DAG. Fit it **once** on observational data and answer all
|
|
34
|
+
three rungs of Pearl's causal hierarchy — observational (L1), interventional
|
|
35
|
+
(L2, the do-operator), and counterfactual (L3, Pearl abduction) — while keeping
|
|
36
|
+
**interpretable effects**: every linear-shift coefficient is a log-odds ratio,
|
|
37
|
+
exactly as in classical proportional-odds models.
|
|
38
|
+
|
|
39
|
+
> Beate Sick & Oliver Dürr, *Interpretable Neural Causal Models with TRAM-DAGs*,
|
|
40
|
+
> CLeaR 2025 ([arXiv:2503.16206](https://arxiv.org/abs/2503.16206)).
|
|
41
|
+
> This repo is the reference implementation (PyTorch, built on
|
|
42
|
+
> [zuko](https://zuko.readthedocs.io/stable/)); all of the paper's experiments are
|
|
43
|
+
> replicated here with pinned tests.
|
|
44
|
+
|
|
45
|
+
**5-minute showcase**: the Colab badge above fits the paper's bimodal benchmark
|
|
46
|
+
live (GPU-ready) and walks L1 → L2 → L3, every answer checked against analytic
|
|
47
|
+
ground truth. Didactic walkthrough of the model:
|
|
48
|
+
[`notebooks/intro_tram_dag.py`](notebooks/intro_tram_dag.py).
|
|
49
|
+
|
|
50
|
+
## Install
|
|
51
|
+
|
|
52
|
+
```bash
|
|
53
|
+
pip install tramdag # PyPI
|
|
54
|
+
uv sync # or: dev setup from a clone (tests, experiments)
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
## 30 seconds of API
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
import tramdag as td
|
|
61
|
+
from tramdag import CausalFlowDAG, ContinuousNode, OrdinalNode
|
|
62
|
+
|
|
63
|
+
spec = { # the spec IS the labelled DAG
|
|
64
|
+
"Age": ContinuousNode(),
|
|
65
|
+
"mRS_pre": OrdinalNode(levels=6, parents={"Age": "ci"}),
|
|
66
|
+
"NIHSSa": ContinuousNode(parents={"Age": "ci", "mRS_pre": "ls"}),
|
|
67
|
+
"T": OrdinalNode(levels=2,
|
|
68
|
+
parents={"Age": "ci", "mRS_pre": "ls", "NIHSSa": "cs"}),
|
|
69
|
+
"mRS_3m": OrdinalNode(levels=7,
|
|
70
|
+
parents={"Age": "ci", "mRS_pre": "ls",
|
|
71
|
+
"NIHSSa": "cs", "T": "ls"}),
|
|
72
|
+
}
|
|
73
|
+
flow = CausalFlowDAG(spec) # validates acyclicity, builds the flow
|
|
74
|
+
|
|
75
|
+
# self-stopping training: per-node plateau lr decay + freezing of converged
|
|
76
|
+
# nodes (exact, since the per-node NLLs have independent gradients);
|
|
77
|
+
# see docs/training-speed.md for benchmarks and the classic two-phase recipe
|
|
78
|
+
flow.fit(train_df, val_df, epochs=4000, learning_rate=1e-2,
|
|
79
|
+
schedule="plateau", plateau_patience=30, freeze_patience=120)
|
|
80
|
+
|
|
81
|
+
flow.log_prob(df) # L1: joint log-likelihood per row
|
|
82
|
+
flow.sample(1000) # L1: observational sampling
|
|
83
|
+
flow.sample(1000, do={"T": 1}) # L2: interventional (graph mutilation)
|
|
84
|
+
flow.pmf(df, node="mRS_3m", do={"T": 1}) # L2: analytic interventional PMF
|
|
85
|
+
|
|
86
|
+
u = flow.abduct(df) # L3 step 1: latents from observations
|
|
87
|
+
cf = flow.sample(do={"T": 1}, u=u) # L3 steps 2+3: counterfactuals
|
|
88
|
+
|
|
89
|
+
flow.save("flow.pt"); flow = CausalFlowDAG.load("flow.pt")
|
|
90
|
+
|
|
91
|
+
td.simulations.REGISTRY # synthetic DGPs with known ground truth
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
## The model in one table
|
|
95
|
+
|
|
96
|
+
Per node, the transformation is additive on the latent (log-odds) scale —
|
|
97
|
+
`u = h(x; θ) + Σ β·x_pa + Σ g(x_pa)` — and each parent edge declares how it enters:
|
|
98
|
+
|
|
99
|
+
| edge term | meaning | interpretability |
|
|
100
|
+
|---|---|---|
|
|
101
|
+
| `ls` | linear shift `β·x_pa` | `exp(β)` is an odds ratio — one number per edge |
|
|
102
|
+
| `cs` | complex shift `g(x_pa)` (MLP), still additive | odds-ratio *function*, plot `g` |
|
|
103
|
+
| `ci` | complex intercept: the transform's parameters depend on the parents (several `ci` parents feed one joint network) | maximal flexibility, interactions |
|
|
104
|
+
|
|
105
|
+
Continuous nodes carry a monotone 1-D transform (`bernstein` — TRAM-faithful
|
|
106
|
+
default, `spline`, `affine`; `ContinuousNode(transform=..., transform_kwargs=...)`);
|
|
107
|
+
ordinal nodes an ordered-logit head `P(x ≤ k) = σ(θ_k − shift)`. Abduction is exact
|
|
108
|
+
for continuous nodes and truncated-logistic for ordinal ones, so
|
|
109
|
+
`flow.sample(u=flow.abduct(df))` reproduces `df` exactly (continuous values) and level-exactly (ordinal levels).
|
|
110
|
+
|
|
111
|
+
## Validation (all pinned by tests)
|
|
112
|
+
|
|
113
|
+
- **Paper replication** — every experiment of the CLeaR paper is a registry
|
|
114
|
+
family (numpy-only SCM + frozen CSVs + replication script):
|
|
115
|
+
|
|
116
|
+
| family | paper | demonstrates |
|
|
117
|
+
|---|---|---|
|
|
118
|
+
| `triangle` (`linear`,`atan`,`sin`) | §6.1 | LS coefficient recovery (β = 2, −0.2, +0.3), CS curve ≡ −f(x₂), non-monotone f |
|
|
119
|
+
| `triangle-mixed` (`linear`,`exp`) | §6.2 | mixed data L1/L2 + the C.4 odds-ratio check (OR ≈ 7.4) |
|
|
120
|
+
| `vaca` | §5.1–5.2 | the bimodal L1 case a default CNF misses; L2 `p(x₃ \| do(x₂))` |
|
|
121
|
+
| `carefl` | §5.3 | L3 counterfactual curves vs **analytic** truth |
|
|
122
|
+
|
|
123
|
+
```bash
|
|
124
|
+
cd experiments && uv run python paper_triangle.py atan cs # etc., see paper_*.py
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
Sign note: ordinal shifts are *subtracted* here but *added* in the paper, so
|
|
128
|
+
fitted ordinal weights are the paper's with flipped sign (`truth.json` records
|
|
129
|
+
both conventions per family).
|
|
130
|
+
|
|
131
|
+
- **Exact classical equivalence** — an all-`ls` flow trained to convergence *is*
|
|
132
|
+
the proportional-odds MLE: coefficients match `statsmodels` **and** R
|
|
133
|
+
`MASS::polr` to ~1e-3 (`experiments/validate_ls.py`, R reference committed
|
|
134
|
+
under `data/magic-mrclean/*/ref_ls/`).
|
|
135
|
+
|
|
136
|
+
- **Training speed** — schedules, per-node freezing, LBFGS and device benchmarks:
|
|
137
|
+
[`docs/training-speed.md`](docs/training-speed.md).
|
|
138
|
+
|
|
139
|
+
## Case study: individualized treatment effects in stroke
|
|
140
|
+
|
|
141
|
+
The method's flagship application estimates individualized thrombectomy effects
|
|
142
|
+
from the observational MAGIC cohort with external validation against the
|
|
143
|
+
MR CLEAN trial:
|
|
144
|
+
|
|
145
|
+
> Dürr, Herzog, Bühler, Wegener & Sick, *Estimating Individualized Treatment
|
|
146
|
+
> Effects in Acute Ischemic Stroke with Causal Transformation Models (TRAM-DAG)*
|
|
147
|
+
> ([arXiv:2606.12623](https://arxiv.org/abs/2606.12623)).
|
|
148
|
+
|
|
149
|
+
The clinical data is private and **never** part of this repo. Its public
|
|
150
|
+
stand-in is `data/magic-mrclean/` — a fully synthetic cohort with the same
|
|
151
|
+
schema and **known ground truth** (true ATE, true individual counterfactuals),
|
|
152
|
+
including an `nl` variant where an all-`ls` model is provably misspecified:
|
|
153
|
+
|
|
154
|
+
| `nl` variant | ATE | vs true **+0.104** |
|
|
155
|
+
|---|---|---|
|
|
156
|
+
| naive observational contrast | +0.303 | confounded (overstates 2.9×) |
|
|
157
|
+
| all-`ls` flow | +0.076 | undershoots (misses the age-varying effect) |
|
|
158
|
+
| flexible (`ci`/`cs`) flow | +0.101 | **recovers the truth** |
|
|
159
|
+
|
|
160
|
+
Full storyline, clinical-data context, R cross-check and reading notes:
|
|
161
|
+
[`docs/stroke-case-study.md`](docs/stroke-case-study.md).
|
|
162
|
+
|
|
163
|
+
## Layout
|
|
164
|
+
|
|
165
|
+
```
|
|
166
|
+
src/tramdag/ spec.py transforms.py conditioners.py flow.py
|
|
167
|
+
simulations/ (magic_mrclean, triangle, vaca, carefl + CLIs)
|
|
168
|
+
data/ frozen synthetic CSVs + truth.json — a test contract
|
|
169
|
+
experiments/ stroke pipeline, paper replications, training benchmark
|
|
170
|
+
notebooks/ intro (didactic) + Colab demo (jupytext .py — see README there)
|
|
171
|
+
tests/ 66 tests: unit, known-truth recovery, R regression
|
|
172
|
+
docs/ training-speed.md, stroke-case-study.md
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
Implementation conventions (latent-scale signs, raw/one-hot parent encoding,
|
|
176
|
+
log-space ordinal likelihood, seeding) are documented in
|
|
177
|
+
[`CLAUDE.md`](CLAUDE.md) and pinned by tests.
|
|
178
|
+
|
|
179
|
+
## Citation
|
|
180
|
+
|
|
181
|
+
If you use `tramdag`, please cite the method paper:
|
|
182
|
+
|
|
183
|
+
```bibtex
|
|
184
|
+
@inproceedings{sick2025tramdag,
|
|
185
|
+
title = {Interpretable Neural Causal Models with TRAM-DAGs},
|
|
186
|
+
author = {Sick, Beate and D{\"u}rr, Oliver},
|
|
187
|
+
booktitle = {Proceedings of the 4th Conference on Causal Learning and Reasoning (CLeaR)},
|
|
188
|
+
series = {Proceedings of Machine Learning Research},
|
|
189
|
+
volume = {275},
|
|
190
|
+
year = {2025},
|
|
191
|
+
}
|
|
192
|
+
```
|
|
193
|
+
|
|
194
|
+
For the stroke application (and the `magic-mrclean` cohort design) additionally:
|
|
195
|
+
|
|
196
|
+
```bibtex
|
|
197
|
+
@article{duerr2026stroke,
|
|
198
|
+
title = {Estimating Individualized Treatment Effects in Acute Ischemic Stroke
|
|
199
|
+
with Causal Transformation Models (TRAM-DAG): A Multi-Centre
|
|
200
|
+
Observational Study with External RCT Validation},
|
|
201
|
+
author = {D{\"u}rr, Oliver and Herzog, Lisa and B{\"u}hler, Pascal and
|
|
202
|
+
Wegener, Susanne and Sick, Beate},
|
|
203
|
+
journal = {arXiv preprint arXiv:2606.12623},
|
|
204
|
+
year = {2026},
|
|
205
|
+
}
|
|
206
|
+
```
|