tramdag 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tramdag/__init__.py +22 -0
- tramdag/conditioners.py +74 -0
- tramdag/flow.py +427 -0
- tramdag/simulations/__init__.py +23 -0
- tramdag/simulations/carefl.py +125 -0
- tramdag/simulations/magic_mrclean.py +259 -0
- tramdag/simulations/triangle.py +252 -0
- tramdag/simulations/vaca.py +130 -0
- tramdag/spec.py +99 -0
- tramdag/transforms.py +271 -0
- tramdag-0.2.0.dist-info/METADATA +206 -0
- tramdag-0.2.0.dist-info/RECORD +14 -0
- tramdag-0.2.0.dist-info/WHEEL +4 -0
- tramdag-0.2.0.dist-info/licenses/LICENSE +21 -0
tramdag/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""tramdag — Interpretable Neural Causal Models (TRAM-DAGs) in PyTorch.
|
|
2
|
+
|
|
3
|
+
One triangular normalizing flow (built on `zuko <https://zuko.readthedocs.io>`_)
|
|
4
|
+
whose Jacobian sparsity is the causal DAG: fit once on observational data,
|
|
5
|
+
then answer observational (L1), interventional (L2) and counterfactual (L3)
|
|
6
|
+
queries. Reference: Sick & Dürr, *Interpretable Neural Causal Models with
|
|
7
|
+
TRAM-DAGs*, CLeaR 2025 (arXiv:2503.16206).
|
|
8
|
+
|
|
9
|
+
Conventional import alias::
|
|
10
|
+
|
|
11
|
+
import tramdag as td
|
|
12
|
+
|
|
13
|
+
flow = td.CausalFlowDAG(spec)
|
|
14
|
+
td.simulations.REGISTRY # synthetic DGPs with known ground truth
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from . import simulations
|
|
18
|
+
from .flow import CausalFlowDAG
|
|
19
|
+
from .spec import ContinuousNode, OrdinalNode
|
|
20
|
+
|
|
21
|
+
__all__ = ["CausalFlowDAG", "ContinuousNode", "OrdinalNode", "simulations"]
|
|
22
|
+
__version__ = "0.2.0"
|
tramdag/conditioners.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Conditioner networks for the per-edge term types.
|
|
2
|
+
|
|
3
|
+
Architectures replicate the original Keras/TF implementation's defaults
|
|
4
|
+
(``tram_models.py`` in https://github.com/tensorchiefs/tram-dag)
|
|
5
|
+
so that fitted models are directly comparable:
|
|
6
|
+
|
|
7
|
+
- ``LinearShift`` — Linear(n, 1, bias=False) (term "ls")
|
|
8
|
+
- ``ComplexShift`` — 64-128-64 ReLU MLP -> 1, no bias (term "cs",
|
|
9
|
+
original ``ComplexShiftDefaultTabular``)
|
|
10
|
+
- ``ComplexIntercept`` — 8-8 ReLU MLP -> n_params, no bias (term "ci",
|
|
11
|
+
original ``ComplexInterceptDefaultTabular``)
|
|
12
|
+
- ``SimpleIntercept`` — free parameter vector (no parent dependence)
|
|
13
|
+
|
|
14
|
+
Parent features follow the original implementation's encoding: continuous parents enter raw
|
|
15
|
+
(one column), ordinal parents are one-hot encoded (``levels`` columns).
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
from torch import Tensor, nn
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SimpleIntercept(nn.Module):
    """Data-independent transform parameters shared by every row.

    Holds one zero-initialised parameter vector ``theta`` of length
    ``n_params`` with no parent dependence; ``forward`` repeats it over the
    batch.
    """

    def __init__(self, n_params: int):
        super().__init__()
        initial = torch.zeros(n_params)
        self.theta = nn.Parameter(initial)

    def forward(self, n: int) -> Tensor:
        """Return ``theta`` broadcast to shape ``(n, n_params)`` (an expanded view)."""
        row = self.theta[None, :]
        return row.expand(n, -1)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ComplexIntercept(nn.Module):
    """Transform parameters as a function of the (joint) ci-parent features.

    Network: Linear(n_features, 8) -> ReLU -> Linear(8, 8) -> ReLU ->
    Linear(8, n_params, bias=False).
    """

    def __init__(self, n_features: int, n_params: int):
        super().__init__()
        layers: list[nn.Module] = []
        fan_in = n_features
        for width in (8, 8):
            layers.extend([nn.Linear(fan_in, width), nn.ReLU()])
            fan_in = width
        # the output layer carries no bias term
        layers.append(nn.Linear(fan_in, n_params, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Map features ``(..., n_features)`` to parameters ``(..., n_params)``."""
        return self.net(x)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class LinearShift(nn.Module):
    """Linear shift term ("ls"): a bias-free linear map to one scalar per row."""

    def __init__(self, n_features: int):
        super().__init__()
        self.fc = nn.Linear(n_features, 1, bias=False)

    @property
    def weight(self) -> Tensor:
        """The coefficient vector, shape ``(n_features,)`` (a view of ``fc.weight``)."""
        return self.fc.weight[0]

    def forward(self, x: Tensor) -> Tensor:
        """Map ``(..., n_features)`` to the shift ``(...)``."""
        projected = self.fc(x)
        return projected[..., 0]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ComplexShift(nn.Module):
    """Complex shift term ("cs"): a 64-128-64 ReLU MLP to one scalar per row.

    The final Linear(64, 1) has no bias.
    """

    def __init__(self, n_features: int):
        super().__init__()
        layers: list[nn.Module] = []
        fan_in = n_features
        for width in (64, 128, 64):
            layers.extend([nn.Linear(fan_in, width), nn.ReLU()])
            fan_in = width
        layers.append(nn.Linear(fan_in, 1, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``(..., n_features)`` to the shift ``(...)``."""
        projected = self.net(x)
        return projected[..., 0]
|
tramdag/flow.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
1
|
+
"""CausalFlowDAG — a single triangular normalizing flow on a user-defined DAG.
|
|
2
|
+
|
|
3
|
+
The flow maps iid standard-logistic latents ``U`` to the observed variables ``X``
|
|
4
|
+
in topological order; its Jacobian sparsity is exactly the DAG adjacency. The
|
|
5
|
+
joint log-likelihood decomposes per node, so one optimizer fits all nodes at once.
|
|
6
|
+
|
|
7
|
+
Causal queries:
|
|
8
|
+
flow.sample(n) observational sampling
|
|
9
|
+
flow.sample(n, do={"T": 1}) interventional sampling (graph mutilation)
|
|
10
|
+
u = flow.abduct(df) Pearl step 1 (latents from observations)
|
|
11
|
+
flow.sample(do={"T": 1}, u=u) Pearl steps 2+3 (counterfactuals)
|
|
12
|
+
flow.pmf(df, node, do=...) analytic per-row interventional PMF
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import copy
|
|
18
|
+
import json
|
|
19
|
+
import time
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
import pandas as pd
|
|
24
|
+
import torch
|
|
25
|
+
from torch import Tensor, nn
|
|
26
|
+
|
|
27
|
+
from .conditioners import ComplexIntercept, ComplexShift, LinearShift, SimpleIntercept
|
|
28
|
+
from .spec import (ContinuousNode, NodeSpec, OrdinalNode, spec_from_dict,
|
|
29
|
+
spec_to_dict, validate_and_sort)
|
|
30
|
+
from .transforms import (StandardLogistic, make_univariate_transform,
|
|
31
|
+
ordinal_abduct, ordinal_log_prob, ordinal_pmf,
|
|
32
|
+
ordinal_sample)
|
|
33
|
+
|
|
34
|
+
__all__ = ["CausalFlowDAG"]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _Node(nn.Module):
    """One dimension of the flow: intercept (transform params) + additive shifts.

    Parents with term ``"ci"`` feed the intercept network jointly; parents with
    ``"ls"``/``"cs"`` each get their own shift module. Continuous nodes carry a
    univariate transform ``ut``; ordinal nodes use ``levels - 1`` cut-point
    parameters instead.
    """

    def __init__(self, name: str, node: NodeSpec, spec: dict[str, NodeSpec]):
        super().__init__()
        self.name = name
        self.kind = node.kind
        self.parents = dict(node.parents)
        self.ci_parents = [parent for parent, term in node.parents.items() if term == "ci"]

        if not isinstance(node, ContinuousNode):
            self.ut = None
            self.levels = node.levels
            n_params = node.levels - 1
        else:
            self.ut = make_univariate_transform(node.transform, **node.transform_kwargs)
            n_params = self.ut.n_params
            self.levels = None

        def encoded_width(parent: str) -> int:
            # ordinal parents are one-hot encoded, continuous parents enter raw
            parent_spec = spec[parent]
            if isinstance(parent_spec, OrdinalNode):
                return parent_spec.levels
            return 1

        if not self.ci_parents:
            self.intercept = SimpleIntercept(n_params)
        else:
            ci_width = sum(encoded_width(parent) for parent in self.ci_parents)
            self.intercept = ComplexIntercept(ci_width, n_params)

        # term -> shift module class; "ci" parents are handled by the intercept
        shift_classes = {"ls": LinearShift, "cs": ComplexShift}
        self.shifts = nn.ModuleDict()
        for parent, term in node.parents.items():
            shift_cls = shift_classes.get(term)
            if shift_cls is not None:
                self.shifts[parent] = shift_cls(encoded_width(parent))

    def theta_shift(self, feats: dict[str, Tensor], n: int) -> tuple[Tensor, Tensor]:
        """Transform parameters (n, P) and total shift (n,) from parent features.

        ``feats`` maps node names to their encoded parent features; only the
        entries for this node's ci and shift parents are read.
        """
        if not self.ci_parents:
            theta = self.intercept(n)
        else:
            ci_feats = [feats[parent] for parent in self.ci_parents]
            theta = self.intercept(torch.cat(ci_feats, dim=1))
        # shifts are summed in registration order, starting from zero
        total = torch.zeros(n, dtype=theta.dtype, device=theta.device)
        for parent, shift_net in self.shifts.items():
            total = total + shift_net(feats[parent])
        return theta, total
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class CausalFlowDAG(nn.Module):
    """A causal normalizing flow defined by ``spec = {name: NodeSpec}``.

    One :class:`_Node` is built per node, in the topological order returned by
    ``validate_and_sort(spec)`` (stored as ``self.order``). All tensors are
    created on ``self.device``.
    """

    def __init__(self, spec: dict[str, NodeSpec], device: str = "cpu"):
        super().__init__()
        self.spec = spec
        self.order = validate_and_sort(spec)
        self.nodes = nn.ModuleDict({name: _Node(name, spec[name], spec) for name in self.order})
        self.device = torch.device(device)
        # per-epoch training record, appended to by fit() and persisted by save()
        self.history: dict = {"train": [], "val": [], "lr": [], "time": []}
        self.to(self.device)

    # ------------------------------------------------------------------ data
    def _column_tensor(self, df: pd.DataFrame, name: str) -> Tensor:
        """Return column ``name`` of ``df`` as a float32 tensor on ``self.device``.

        ``copy=True`` hands torch its own writable buffer, so the tensor never
        aliases the DataFrame's memory (and torch does not warn about
        read-only NumPy arrays).

        Raises:
            KeyError: if ``df`` has no column ``name``.
        """
        if name not in df.columns:
            raise KeyError(f"DataFrame has no column {name!r}")
        return torch.as_tensor(df[name].to_numpy(dtype=np.float32, copy=True),
                               device=self.device)

    def _encode_parent(self, name: str, values: Tensor) -> Tensor:
        """Encode a node's values for use as a parent feature (original TRAM-DAG convention:
        continuous raw (n, 1); ordinal one-hot (n, levels))."""
        node = self.spec[name]
        if isinstance(node, OrdinalNode):
            # one_hot needs integer class indices; cast the result to a floating
            # dtype so it can feed the float conditioner layers even when
            # ``values`` arrives as an integer tensor
            float_dtype = (values.dtype if values.is_floating_point()
                           else torch.get_default_dtype())
            return torch.nn.functional.one_hot(
                values.long(), num_classes=node.levels).to(float_dtype)
        return values.reshape(-1, 1)

    def _tensorize(self, df: pd.DataFrame) -> dict[str, Tensor]:
        """One float32 tensor per node (in ``self.order``) from the matching ``df`` column."""
        return {name: self._column_tensor(df, name) for name in self.order}

    def _features(self, values: dict[str, Tensor]) -> dict[str, Tensor]:
        """Parent-feature encoding of every entry of ``values``."""
        return {name: self._encode_parent(name, vals) for name, vals in values.items()}

    # ------------------------------------------------------------- likelihood
    def node_log_prob(self, values: dict[str, Tensor],
                      nodes: list[str] | None = None) -> dict[str, Tensor]:
        """Per-node log-likelihood contributions, each (n,).

        ``nodes`` restricts computation to a subset (used to skip frozen nodes
        during training — valid because the per-node losses are independent)."""
        feats = self._features(values)
        n = next(iter(values.values())).shape[0]
        out = {}
        for name in (self.order if nodes is None else nodes):
            node = self.nodes[name]
            theta, shift = node.theta_shift(feats, n)
            x = values[name]
            if node.kind == "continuous":
                # change of variables: log p(x) = log f_logistic(z) + log|dz/dx|
                z0, ladj = node.ut.forward(theta, x)
                z = z0 + shift
                out[name] = StandardLogistic.log_prob(z) + ladj
            else:
                out[name] = ordinal_log_prob(theta, shift, x)
        return out

    def log_prob(self, df: pd.DataFrame) -> Tensor:
        """Joint log-likelihood log p(x) per row, shape (n,)."""
        per_node = self.node_log_prob(self._tensorize(df))
        return torch.stack(list(per_node.values()), dim=0).sum(dim=0)

    def nll(self, df: pd.DataFrame) -> dict[str, float]:
        """Mean negative log-likelihood per node (diagnostic)."""
        with torch.no_grad():
            per_node = self.node_log_prob(self._tensorize(df))
        return {k: float(-v.mean()) for k, v in per_node.items()}

    # ------------------------------------------------------------------- fit
    def _set_ranges(self, train_df: pd.DataFrame) -> None:
        """Train 5%/95% quantiles -> transform domain (the original implementation's min_max scaling).

        Only transforms whose ``_fitted`` flag is still False are updated.
        """
        for name in self.order:
            node = self.nodes[name]
            if node.kind == "continuous" and not node.ut._fitted:
                q = train_df[name].quantile([0.05, 0.95])
                node.ut.set_range(q.iloc[0], q.iloc[1])

    def fit(self, train_df: pd.DataFrame, val_df: pd.DataFrame | None = None,
            epochs: int = 500, learning_rate: float = 1e-2, batch_size: int = 512,
            verbose: int = 50, seed: int | None = None,
            restore_best: bool = False, schedule: str | None = None,
            plateau_patience: int = 15, freeze_patience: int | None = None,
            min_delta: float = 1e-4) -> "CausalFlowDAG":
        """Jointly fit all nodes by maximum likelihood.

        By default training keeps the **final** (converged) weights, so an
        all-``ls`` model trained to convergence reproduces the classical maximum
        likelihood estimate exactly (e.g. matches ``statsmodels``/``polr``).

        The optimizer holds one parameter group per node. Because the joint NLL
        decomposes per node with independent gradients, per-node learning rates
        and freezing are exactly equivalent to independent per-node training.

        Args:
            val_df: optional held-out set, used only for monitoring (and for
                ``restore_best``, ``schedule="plateau"`` and ``freeze_patience``).
                If omitted, the training set is used for the validation metric.
            restore_best: if True, snapshot each node's best-validation weights
                during training and restore them at the end (mild early-stopping
                regularization, the original implementation's convention). This makes the fit
                *not* the training-data MLE, so leave it False for an exact
                classical comparison. Default False. The snapshots are kept on
                ``self._best`` and so persist across ``fit`` calls.
            schedule: learning-rate schedule. ``None`` = constant (the classic
                behavior); ``"onecycle"`` = ``OneCycleLR`` (warmup to
                ``learning_rate``, then anneal; stepped per batch);
                ``"cosine"`` = ``CosineAnnealingLR`` over ``epochs``;
                ``"plateau"`` = **per-node** decay: a node's lr is multiplied by
                0.3 whenever its own validation NLL hasn't improved by
                ``min_delta`` for ``plateau_patience`` epochs (floor 1e-3 ×
                ``learning_rate``).
            freeze_patience: if set, a node whose validation NLL hasn't improved
                by ``min_delta`` for this many epochs is **frozen** — excluded
                from the loss and backward pass (a real compute saving, since
                per-node losses are independent). When every node is frozen the
                fit returns early. Freeze epochs are recorded in
                ``history["frozen"]``.

        Calling ``fit`` again continues training (e.g. a second phase with a
        lower learning rate); freezing state does not carry across calls.

        Raises:
            ValueError: for an unknown ``schedule``, ``batch_size < 1``,
                ``plateau_patience < 1`` with ``schedule="plateau"``, or an
                empty ``train_df``.
        """
        if schedule not in (None, "onecycle", "cosine", "plateau"):
            raise ValueError(f"unknown schedule {schedule!r}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if schedule == "plateau" and plateau_patience < 1:
            # the plateau test below takes node_bad % plateau_patience
            raise ValueError(f"plateau_patience must be >= 1, got {plateau_patience}")
        if len(train_df) == 0:
            # an empty frame would pass NaN quantiles to set_range()
            raise ValueError("train_df is empty")
        if seed is not None:
            torch.manual_seed(seed)
        self._set_ranges(train_df)

        train_vals = self._tensorize(train_df)
        val_vals = self._tensorize(val_df) if val_df is not None else train_vals
        n = len(train_df)
        steps_per_epoch = (n + batch_size - 1) // batch_size  # ceil(n / batch_size)

        # one param group per node, tagged with its name for per-node lr decay
        opt = torch.optim.Adam(
            [{"params": list(self.nodes[name].parameters()), "lr": learning_rate,
              "node": name} for name in self.order])
        sched = None
        if schedule == "onecycle":
            sched = torch.optim.lr_scheduler.OneCycleLR(
                opt, max_lr=learning_rate, total_steps=epochs * steps_per_epoch)
        elif schedule == "cosine":
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(
                opt, T_max=epochs, eta_min=learning_rate * 1e-3)

        if restore_best and not hasattr(self, "_best"):
            self._best = {name: (float("inf"), None) for name in self.order}
        best = self._best if restore_best else None
        # per-node plateau/freeze bookkeeping (local to this fit call)
        node_best = {name: float("inf") for name in self.order}
        node_bad = {name: 0 for name in self.order}
        frozen: set[str] = set()
        t0 = time.perf_counter()
        t_offset = self.history["time"][-1] if self.history.get("time") else 0.0
        prev_train: dict[str, float] = {}

        for epoch in range(epochs):
            self.train()
            active = [name for name in self.order if name not in frozen]
            perm = torch.randperm(n, device=self.device)
            # frozen nodes report their last measured train NLL
            train_acc = {name: prev_train.get(name, float("nan"))
                         for name in frozen}
            train_acc.update({name: 0.0 for name in active})
            for start in range(0, n, batch_size):
                idx = perm[start:start + batch_size]
                batch = {k: v[idx] for k, v in train_vals.items()}
                per_node = self.node_log_prob(batch, nodes=active)
                node_nlls = {k: -v.mean() for k, v in per_node.items()}
                loss = torch.stack(list(node_nlls.values())).sum()
                opt.zero_grad()
                loss.backward()
                opt.step()
                if schedule == "onecycle":
                    sched.step()
                # batch-size-weighted running mean of each node's train NLL
                w = len(idx) / n
                for k, v in node_nlls.items():
                    train_acc[k] += float(v.detach()) * w
            if schedule == "cosine":
                sched.step()
            prev_train = train_acc

            self.eval()
            with torch.no_grad():
                val_per_node = {k: float(-v.mean())
                                for k, v in self.node_log_prob(val_vals).items()}
            self.history["train"].append(train_acc)
            self.history["val"].append(val_per_node)
            self.history.setdefault("lr", []).append(
                max(g["lr"] for g in opt.param_groups))
            self.history.setdefault("time", []).append(
                t_offset + time.perf_counter() - t0)

            # per-node improvement tracking (plateau decay + freezing)
            for g in opt.param_groups:
                name = g["node"]
                if name in frozen:
                    continue
                if val_per_node[name] < node_best[name] - min_delta:
                    node_best[name] = val_per_node[name]
                    node_bad[name] = 0
                else:
                    node_bad[name] += 1
                if (schedule == "plateau" and node_bad[name] > 0
                        and node_bad[name] % plateau_patience == 0):
                    g["lr"] = max(g["lr"] * 0.3, learning_rate * 1e-3)
                # under "plateau", only freeze nodes whose lr has already been
                # decayed substantially — otherwise a node can freeze while a
                # smaller lr would still make progress toward the optimum
                lr_decayed = (schedule != "plateau"
                              or g["lr"] <= learning_rate * 1e-2 * (1 + 1e-9))
                if (freeze_patience is not None and lr_decayed
                        and node_bad[name] >= freeze_patience):
                    frozen.add(name)
                    self.history.setdefault("frozen", {}).setdefault(
                        name, len(self.history["val"]))  # 1-based global epoch

            if restore_best:
                for name in self.order:
                    if val_per_node[name] < best[name][0]:
                        best[name] = (val_per_node[name],
                                      copy.deepcopy(self.nodes[name].state_dict()))

            if verbose and (epoch % verbose == 0 or epoch == epochs - 1):
                tot_t = sum(train_acc.values())
                tot_v = sum(val_per_node.values())
                print(f"[epoch {epoch + 1:5d}/{epochs}] train NLL {tot_t:.4f} "
                      f"val NLL {tot_v:.4f}"
                      + (f"  frozen {sorted(frozen)}" if frozen else ""))

            if len(frozen) == len(self.order):  # everything converged
                if verbose:
                    print(f"[epoch {epoch + 1:5d}] all nodes frozen — stopping.")
                break

        if restore_best:  # restore per-node best-validation weights
            for name, (_, state) in best.items():
                if state is not None:
                    self.nodes[name].load_state_dict(state)
        self.eval()
        return self

    # ------------------------------------------------------- causal queries
    @torch.no_grad()
    def sample(self, n: int | None = None, *, do: dict[str, float] | None = None,
               u: pd.DataFrame | None = None, seed: int | None = None) -> pd.DataFrame:
        """Sample from the (optionally mutilated) flow.

        Args:
            n: number of samples (ignored if ``u`` is given).
            do: interventions {node: value}; intervened nodes are clamped and
                their parent dependence removed (graph mutilation).
            u: latent variables (as returned by :meth:`abduct`). If given, they
                are pushed through the flow — together with ``do`` this yields
                counterfactuals (Pearl's abduction -> action -> prediction).
            seed: if given, latents are drawn from a dedicated seeded generator.

        Raises:
            ValueError: if ``do`` names a node that is not in the spec, or if
                neither ``n`` nor ``u`` is given.
            KeyError: if ``u`` lacks a column for some node.
        """
        do = dict(do or {})
        unknown = sorted(set(do) - set(self.order))
        if unknown:
            # a misspelled node would otherwise be ignored silently, turning an
            # interventional query into an observational one
            raise ValueError(f"do= names unknown node(s) {unknown}; "
                             f"known nodes are {list(self.order)}")
        gen = None
        if seed is not None:
            gen = torch.Generator(device=self.device).manual_seed(seed)

        if u is not None:
            n = len(u)
            u_vals = {name: self._column_tensor(u, name) for name in self.order}
        elif n is not None:
            u_vals = {name: StandardLogistic.sample((n,), device=self.device)
                      if gen is None else
                      StandardLogistic.icdf(torch.rand((n,), device=self.device, generator=gen))
                      for name in self.order}
        else:
            raise ValueError("Provide either n or u.")

        values: dict[str, Tensor] = {}
        for name in self.order:  # topological order: parents are filled first
            if name in do:
                values[name] = torch.full((n,), float(do[name]), device=self.device)
                continue
            node = self.nodes[name]
            feats = self._features({p: values[p] for p in node.parents})
            theta, shift = node.theta_shift(feats, n)
            z = u_vals[name]
            if node.kind == "continuous":
                values[name] = node.ut.inverse(theta, z - shift)
            else:
                values[name] = ordinal_sample(theta, shift, z)
        return pd.DataFrame({k: v.cpu().numpy() for k, v in values.items()})

    @torch.no_grad()
    def abduct(self, df: pd.DataFrame, seed: int | None = None) -> pd.DataFrame:
        """Pearl abduction: recover the latent variables ``u`` from observations.

        Continuous nodes are inverted exactly (``u = h(x) + shift``); for ordinal
        nodes the latent is only interval-identified, so it is sampled from the
        standard logistic truncated to the observed level's interval.
        """
        gen = None
        if seed is not None:
            gen = torch.Generator(device=self.device).manual_seed(seed)
        values = self._tensorize(df)
        feats = self._features(values)
        n = len(df)
        u = {}
        for name in self.order:
            node = self.nodes[name]
            theta, shift = node.theta_shift(feats, n)
            x = values[name]
            if node.kind == "continuous":
                z0, _ = node.ut.forward(theta, x)
                u[name] = z0 + shift
            else:
                u[name] = ordinal_abduct(theta, shift, x, generator=gen)
        return pd.DataFrame({k: v.cpu().numpy() for k, v in u.items()})

    @torch.no_grad()
    def pmf(self, df: pd.DataFrame, node: str, do: dict[str, float] | None = None) -> np.ndarray:
        """Analytic class probabilities (n, levels) for an ordinal node, with the
        node's parents taken from ``df`` after applying ``do`` overrides.

        Only the node's direct parents are read; ``do`` entries for other
        columns are applied to a copy of ``df`` but do not propagate.
        """
        if not isinstance(self.spec[node], OrdinalNode):
            raise ValueError(f"pmf() requires an ordinal node, '{node}' is continuous.")
        df_local = df.copy()
        for col, val in (do or {}).items():
            df_local[col] = val
        nd = self.nodes[node]
        values = {p: self._column_tensor(df_local, p) for p in nd.parents}
        feats = self._features(values)
        theta, shift = nd.theta_shift(feats, len(df_local))
        return ordinal_pmf(theta, shift).cpu().numpy()

    # ------------------------------------------------------------------- io
    def save(self, path: str | Path) -> None:
        """Write spec, weights/buffers and training history to ``path`` via ``torch.save``."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({"spec": spec_to_dict(self.spec),
                    "state_dict": self.state_dict(),
                    "history": self.history}, path)

    @classmethod
    def load(cls, path: str | Path, device: str = "cpu") -> "CausalFlowDAG":
        """Rebuild a flow from a checkpoint written by :meth:`save`."""
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints you trust.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        flow = cls(spec_from_dict(ckpt["spec"]), device=device)
        for name in flow.order:  # mark transforms as fitted before loading buffers
            node = flow.nodes[name]
            if node.kind == "continuous":
                node.ut._fitted = True
        flow.load_state_dict(ckpt["state_dict"])
        # start from the same keys __init__ creates, then overlay the saved history
        history = {"train": [], "val": [], "lr": [], "time": []}
        history.update(ckpt.get("history") or {})
        flow.history = history
        flow.eval()
        return flow
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Synthetic-cohort generators for tramdag.
|
|
2
|
+
|
|
3
|
+
Each scenario is one module exposing a numpy-only SCM generator class with known
|
|
4
|
+
causal ground truth. New scenarios register here so experiments/tests can look
|
|
5
|
+
them up by name. Frozen CSVs live under ``data/<name>/`` and are a contract —
|
|
6
|
+
regenerate only deliberately via each module's CLI.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from .carefl import Carefl4
|
|
10
|
+
from .magic_mrclean import MagicMrClean
|
|
11
|
+
from .triangle import TriangleContinuous, TriangleMixed
|
|
12
|
+
from .vaca import VacaTriangle
|
|
13
|
+
|
|
14
|
+
# Scenario name -> SCM generator class, for lookup by experiments and tests.
REGISTRY = {
    "magic-mrclean": MagicMrClean,
    "triangle": TriangleContinuous,
    "triangle-mixed": TriangleMixed,
    "vaca": VacaTriangle,
    "carefl": Carefl4,
}

__all__ = [
    "MagicMrClean",
    "TriangleContinuous",
    "TriangleMixed",
    "VacaTriangle",
    "Carefl4",
    "REGISTRY",
]
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""The CAREFL counterfactual benchmark DGP (TRAM-DAG paper App. C.2).
|
|
2
|
+
|
|
3
|
+
Originally from Khemakhem et al. (2021, CAREFL Fig. 5), used in the paper to
|
|
4
|
+
benchmark TRAM-DAG's L3 (counterfactual) queries (Fig. 6)::
|
|
5
|
+
|
|
6
|
+
x1, x2 ~ Laplace(0, 1/sqrt(2))
|
|
7
|
+
x3 = x1 + 0.5 x2^3 + Laplace(0, 1/sqrt(2))
|
|
8
|
+
x4 = -x2 + 0.5 x1^2 + Laplace(0, 1/sqrt(2))
|
|
9
|
+
|
|
10
|
+
All-continuous additive-noise SCM, so individual counterfactuals are **analytic**
|
|
11
|
+
via noise abduction (no Monte Carlo): eps3 = x3 - x1 - 0.5 x2^3 and
|
|
12
|
+
eps4 = x4 + x2 - 0.5 x1^2 are recovered exactly, then the mutilated SCM is
|
|
13
|
+
re-evaluated. The paper picks the observation ``X_OBS = (2.00, 1.50, 0.81, -0.28)``
|
|
14
|
+
and sweeps two queries for alpha in [-3, 3]:
|
|
15
|
+
|
|
16
|
+
(i) x3^cf given do(x2 = alpha): x1 + 0.5 alpha^3 + eps3
|
|
17
|
+
(ii) x4^cf given do(x1 = alpha): -x2 + 0.5 alpha^2 + eps4
|
|
18
|
+
|
|
19
|
+
CLI::
|
|
20
|
+
|
|
21
|
+
uv run python -m tramdag.simulations.carefl --out data/carefl --seed 42
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import argparse
|
|
27
|
+
import json
|
|
28
|
+
from dataclasses import dataclass
|
|
29
|
+
from pathlib import Path
|
|
30
|
+
|
|
31
|
+
import numpy as np
|
|
32
|
+
import pandas as pd
|
|
33
|
+
|
|
34
|
+
X_OBS = {"x1": 2.00, "x2": 1.50, "x3": 0.81, "x4": -0.28}  # the paper's observation
ALPHA_GRID = np.round(np.linspace(-3.0, 3.0, 61), 4)  # alpha in [-3, 3], step 0.1
_SCALE = 1.0 / np.sqrt(2.0)  # Laplace scale b; variance 2*b**2 = 1
_VARIABLES = ("x1", "x2", "x3", "x4")


@dataclass
class Carefl4:
    """SCM generator for the 4-variable CAREFL benchmark.

    Attributes:
        seed: base seed. ``simulate`` uses it directly when no ``rng`` is
            given; ``observational`` uses ``seed + 1 + seed_offset``.
    """

    seed: int = 42

    def draw_latents(self, n: int, rng: np.random.Generator) -> dict[str, np.ndarray]:
        """Draw ``n`` iid Laplace(0, 1/sqrt(2)) noise values for each variable.

        Variables are drawn in the fixed order x1, x2, x3, x4, so a given
        ``rng`` state always yields the same latents.
        """
        return {k: rng.laplace(loc=0.0, scale=_SCALE, size=n) for k in _VARIABLES}

    def simulate(self, n: int | None = None, *, rng: np.random.Generator | None = None,
                 do: dict[str, float] | None = None,
                 latents: dict[str, np.ndarray] | None = None) -> pd.DataFrame:
        """Push latents through the (optionally mutilated) SCM.

        Args:
            n: number of rows to draw; ignored when ``latents`` is given.
            rng: generator for fresh latents (default: seeded with ``self.seed``).
            do: interventions ``{variable: value}``; a clamped variable is set
                to the constant and its structural equation is ignored.
            latents: noise per variable (e.g. from :meth:`abduct_noise`); for
                the root variables x1, x2 these are the values themselves.

        Raises:
            ValueError: if neither ``n`` nor ``latents`` is given, or if ``do``
                names a variable other than x1..x4.
        """
        do = do or {}
        unknown = sorted(set(do) - set(_VARIABLES))
        if unknown:
            # a typo would otherwise be ignored and silently give the factual world
            raise ValueError(f"do= names unknown variable(s) {unknown}")
        if latents is None:
            if n is None:
                raise ValueError("provide either n or latents")
            if rng is None:
                rng = np.random.default_rng(self.seed)
            latents = self.draw_latents(n, rng)
        n = len(latents["x1"])

        def clamp_or(name, value):
            return np.full(n, float(do[name])) if name in do else value

        x1 = clamp_or("x1", latents["x1"])
        x2 = clamp_or("x2", latents["x2"])
        x3 = clamp_or("x3", x1 + 0.5 * x2**3 + latents["x3"])
        x4 = clamp_or("x4", -x2 + 0.5 * x1**2 + latents["x4"])
        return pd.DataFrame({"x1": x1, "x2": x2, "x3": x3, "x4": x4})

    # ----------------------------------------------------------------- datasets
    def observational(self, n: int, seed_offset: int = 0) -> pd.DataFrame:
        """``n`` observational rows drawn with seed ``self.seed + 1 + seed_offset``."""
        rng = np.random.default_rng(self.seed + 1 + seed_offset)
        return self.simulate(n, rng=rng)

    # -------------------------------------------------------------- ground truth
    @staticmethod
    def abduct_noise(obs: dict[str, float] | pd.DataFrame) -> dict[str, np.ndarray]:
        """Exact noise values consistent with an observation (vectorized)."""
        x1, x2 = np.asarray(obs["x1"], float), np.asarray(obs["x2"], float)
        x3, x4 = np.asarray(obs["x3"], float), np.asarray(obs["x4"], float)
        return {"x1": x1, "x2": x2,
                "x3": x3 - x1 - 0.5 * x2**3,
                "x4": x4 + x2 - 0.5 * x1**2}

    def true_counterfactual(self, obs: dict[str, float],
                            do: dict[str, float]) -> dict[str, float]:
        """Analytic counterfactual of a single observation under ``do``."""
        eps = self.abduct_noise({k: np.atleast_1d(v) for k, v in obs.items()})
        cf = self.simulate(do=do, latents=eps)
        return {k: float(cf[k].iloc[0]) for k in cf}

    def true_cf_curves(self, obs: dict[str, float] | None = None,
                       alphas: np.ndarray | None = None) -> dict:
        """The paper's two Fig.-6 curves, analytic.

        ``obs`` defaults to :data:`X_OBS` and ``alphas`` to :data:`ALPHA_GRID`
        (``None`` sentinels avoid sharing the module-level objects as defaults).
        """
        obs = X_OBS if obs is None else obs
        alphas = ALPHA_GRID if alphas is None else alphas
        x3_cf = [self.true_counterfactual(obs, {"x2": a})["x3"] for a in alphas]
        x4_cf = [self.true_counterfactual(obs, {"x1": a})["x4"] for a in alphas]
        return {"x_obs": dict(obs), "alphas": [float(a) for a in alphas],
                "x3_cf_do_x2": x3_cf, "x4_cf_do_x1": x4_cf}
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# --------------------------------------------------------------------------- CLI
|
|
101
|
+
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: write ``obs.csv`` and ``truth.json`` into ``--out``."""
    parser = argparse.ArgumentParser(description="Generate the CAREFL benchmark data.")
    parser.add_argument("--out", type=Path, default=Path("data/carefl"))
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n-obs", type=int, default=5000)
    opts = parser.parse_args(argv)

    generator = Carefl4(seed=opts.seed)
    out_dir: Path = opts.out
    out_dir.mkdir(parents=True, exist_ok=True)

    observed = generator.observational(opts.n_obs)
    observed.to_csv(out_dir / "obs.csv", index=False)

    scm_description = {
        "x1": "Laplace(0, 1/sqrt(2))",
        "x2": "Laplace(0, 1/sqrt(2))",
        "x3": "x1 + 0.5*x2^3 + Laplace",
        "x4": "-x2 + 0.5*x1^2 + Laplace",
    }
    truth = {
        "source": "arXiv:2503.16206 App. C.2 (orig. Khemakhem 2021 Fig. 5)",
        "seed": opts.seed,
        "n_obs": opts.n_obs,
        "scm": scm_description,
    }
    # analytic counterfactual curves are appended after the metadata keys
    truth.update(generator.true_cf_curves())
    payload = json.dumps(truth, indent=2) + "\n"
    (out_dir / "truth.json").write_text(payload)

    # sanity check: intervening with the factual x2 must reproduce factual x3
    sanity = generator.true_counterfactual(X_OBS, {"x2": 1.5})["x3"]
    print(f"[carefl] n={len(observed)}  cf sanity: "
          f"x3_cf(do x2=1.5)={sanity:.3f} "
          f"(factual {X_OBS['x3']})")


if __name__ == "__main__":
    main()
|