openstat-cli 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstat/__init__.py +3 -0
- openstat/__main__.py +4 -0
- openstat/backends/__init__.py +16 -0
- openstat/backends/duckdb_backend.py +70 -0
- openstat/backends/polars_backend.py +52 -0
- openstat/cli.py +92 -0
- openstat/commands/__init__.py +82 -0
- openstat/commands/adv_stat_cmds.py +1255 -0
- openstat/commands/advanced_ml_cmds.py +576 -0
- openstat/commands/advreg_cmds.py +207 -0
- openstat/commands/alias_cmds.py +135 -0
- openstat/commands/arch_cmds.py +82 -0
- openstat/commands/arules_cmds.py +111 -0
- openstat/commands/automodel_cmds.py +212 -0
- openstat/commands/backend_cmds.py +82 -0
- openstat/commands/base.py +170 -0
- openstat/commands/bayes_cmds.py +71 -0
- openstat/commands/causal_cmds.py +269 -0
- openstat/commands/cluster_cmds.py +152 -0
- openstat/commands/data_cmds.py +996 -0
- openstat/commands/datamanip_cmds.py +672 -0
- openstat/commands/dataquality_cmds.py +174 -0
- openstat/commands/datetime_cmds.py +176 -0
- openstat/commands/dimreduce_cmds.py +184 -0
- openstat/commands/discrete_cmds.py +149 -0
- openstat/commands/dsl_cmds.py +143 -0
- openstat/commands/epi_cmds.py +93 -0
- openstat/commands/equiv_tobit_cmds.py +94 -0
- openstat/commands/esttab_cmds.py +196 -0
- openstat/commands/export_beamer_cmds.py +142 -0
- openstat/commands/export_cmds.py +201 -0
- openstat/commands/export_extra_cmds.py +240 -0
- openstat/commands/factor_cmds.py +180 -0
- openstat/commands/groupby_cmds.py +155 -0
- openstat/commands/help_cmds.py +237 -0
- openstat/commands/i18n_cmds.py +43 -0
- openstat/commands/import_extra_cmds.py +561 -0
- openstat/commands/influence_cmds.py +134 -0
- openstat/commands/iv_cmds.py +106 -0
- openstat/commands/manova_cmds.py +105 -0
- openstat/commands/mediate_cmds.py +233 -0
- openstat/commands/meta_cmds.py +284 -0
- openstat/commands/mi_cmds.py +228 -0
- openstat/commands/mixed_cmds.py +79 -0
- openstat/commands/mixture_changepoint_cmds.py +166 -0
- openstat/commands/ml_adv_cmds.py +147 -0
- openstat/commands/ml_cmds.py +178 -0
- openstat/commands/model_eval_cmds.py +142 -0
- openstat/commands/network_cmds.py +288 -0
- openstat/commands/nlquery_cmds.py +161 -0
- openstat/commands/nonparam_cmds.py +149 -0
- openstat/commands/outreg_cmds.py +247 -0
- openstat/commands/panel_cmds.py +141 -0
- openstat/commands/pdf_cmds.py +226 -0
- openstat/commands/pipeline_cmds.py +319 -0
- openstat/commands/plot_cmds.py +189 -0
- openstat/commands/plugin_cmds.py +79 -0
- openstat/commands/posthoc_cmds.py +153 -0
- openstat/commands/power_cmds.py +172 -0
- openstat/commands/profile_cmds.py +246 -0
- openstat/commands/rbridge_cmds.py +81 -0
- openstat/commands/regex_cmds.py +104 -0
- openstat/commands/report_cmds.py +48 -0
- openstat/commands/repro_cmds.py +129 -0
- openstat/commands/resampling_cmds.py +109 -0
- openstat/commands/reshape_cmds.py +223 -0
- openstat/commands/sem_cmds.py +177 -0
- openstat/commands/stat_cmds.py +1040 -0
- openstat/commands/stata_import_cmds.py +215 -0
- openstat/commands/string_cmds.py +124 -0
- openstat/commands/surv_cmds.py +145 -0
- openstat/commands/survey_cmds.py +153 -0
- openstat/commands/textanalysis_cmds.py +192 -0
- openstat/commands/ts_adv_cmds.py +136 -0
- openstat/commands/ts_cmds.py +195 -0
- openstat/commands/tui_cmds.py +111 -0
- openstat/commands/ux_cmds.py +191 -0
- openstat/commands/validate_cmds.py +270 -0
- openstat/commands/viz_adv_cmds.py +312 -0
- openstat/commands/viz_extra_cmds.py +251 -0
- openstat/commands/watch_cmds.py +69 -0
- openstat/config.py +106 -0
- openstat/dsl/__init__.py +0 -0
- openstat/dsl/parser.py +332 -0
- openstat/dsl/tokenizer.py +105 -0
- openstat/i18n.py +120 -0
- openstat/io/__init__.py +0 -0
- openstat/io/loader.py +187 -0
- openstat/jupyter/__init__.py +18 -0
- openstat/jupyter/display.py +18 -0
- openstat/jupyter/magic.py +60 -0
- openstat/logging_config.py +59 -0
- openstat/plots/__init__.py +0 -0
- openstat/plots/plotter.py +437 -0
- openstat/plots/surv_plots.py +32 -0
- openstat/plots/ts_plots.py +59 -0
- openstat/plugins/__init__.py +5 -0
- openstat/plugins/manager.py +69 -0
- openstat/repl.py +457 -0
- openstat/reporting/__init__.py +0 -0
- openstat/reporting/eda.py +208 -0
- openstat/reporting/report.py +67 -0
- openstat/script_runner.py +319 -0
- openstat/session.py +133 -0
- openstat/stats/__init__.py +0 -0
- openstat/stats/advanced_regression.py +269 -0
- openstat/stats/arch_garch.py +84 -0
- openstat/stats/bayesian.py +103 -0
- openstat/stats/causal.py +258 -0
- openstat/stats/clustering.py +206 -0
- openstat/stats/discrete.py +311 -0
- openstat/stats/epidemiology.py +119 -0
- openstat/stats/equiv_tobit.py +163 -0
- openstat/stats/factor.py +174 -0
- openstat/stats/imputation.py +282 -0
- openstat/stats/influence.py +78 -0
- openstat/stats/iv.py +131 -0
- openstat/stats/manova.py +124 -0
- openstat/stats/mixed.py +128 -0
- openstat/stats/ml.py +275 -0
- openstat/stats/ml_advanced.py +117 -0
- openstat/stats/model_eval.py +183 -0
- openstat/stats/models.py +1342 -0
- openstat/stats/nonparametric.py +130 -0
- openstat/stats/panel.py +179 -0
- openstat/stats/power.py +295 -0
- openstat/stats/resampling.py +203 -0
- openstat/stats/survey.py +213 -0
- openstat/stats/survival.py +196 -0
- openstat/stats/timeseries.py +142 -0
- openstat/stats/ts_advanced.py +114 -0
- openstat/types.py +11 -0
- openstat/web/__init__.py +1 -0
- openstat/web/app.py +117 -0
- openstat/web/session_manager.py +73 -0
- openstat/web/static/app.js +117 -0
- openstat/web/static/index.html +38 -0
- openstat/web/static/style.css +103 -0
- openstat_cli-1.0.0.dist-info/METADATA +748 -0
- openstat_cli-1.0.0.dist-info/RECORD +143 -0
- openstat_cli-1.0.0.dist-info/WHEEL +4 -0
- openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
- openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1255 @@
|
|
|
1
|
+
"""Advanced statistical commands: irt, competing, cate, joinpoint, spline."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from openstat.commands.base import command, CommandArgs, friendly_error
|
|
6
|
+
from openstat.session import Session
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# ---------------------------------------------------------------------------
|
|
10
|
+
# Shared helpers
|
|
11
|
+
# ---------------------------------------------------------------------------
|
|
12
|
+
|
|
13
|
+
def _sep(width: int = 60) -> str:
|
|
14
|
+
return "=" * width
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _sig_stars(p: float) -> str:
|
|
18
|
+
if p < 0.001:
|
|
19
|
+
return "***"
|
|
20
|
+
if p < 0.01:
|
|
21
|
+
return "**"
|
|
22
|
+
if p < 0.05:
|
|
23
|
+
return "*"
|
|
24
|
+
if p < 0.10:
|
|
25
|
+
return "."
|
|
26
|
+
return ""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _coef_table(rows: list[tuple], headers: list[str]) -> str:
|
|
30
|
+
"""Render a simple fixed-width table.
|
|
31
|
+
|
|
32
|
+
rows : list of tuples whose elements are already formatted strings.
|
|
33
|
+
headers: list of column header strings.
|
|
34
|
+
"""
|
|
35
|
+
col_widths = [max(len(h), max((len(str(r[i])) for r in rows), default=0))
|
|
36
|
+
for i, h in enumerate(headers)]
|
|
37
|
+
fmt = " " + " ".join(f"{{:<{w}}}" for w in col_widths)
|
|
38
|
+
lines = [fmt.format(*headers), " " + "-" * (sum(col_widths) + 2 * len(col_widths))]
|
|
39
|
+
for row in rows:
|
|
40
|
+
lines.append(fmt.format(*[str(x) for x in row]))
|
|
41
|
+
return "\n".join(lines)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
# 1. IRT — Item Response Theory
|
|
46
|
+
# ---------------------------------------------------------------------------
|
|
47
|
+
|
|
48
|
+
@command("irt", usage="irt <item1> [item2 ...] [--model=1pl|2pl|3pl]")
def cmd_irt(session: Session, args: str) -> str:
    """Item Response Theory: estimate discrimination and difficulty parameters.

    Ability (theta) is approximated by the standardised sum score. Each item
    is then fitted separately against that proxy:

    * with scipy: the item negative log-likelihood is minimised by
      Nelder-Mead;
    * without scipy: an sklearn logistic regression on theta is used;
    * without sklearn either: a method-of-moments difficulty is reported
      with discrimination fixed at 1.0.

    Models:
        1pl : difficulty only, discrimination fixed at 1.0.
        2pl : discrimination and difficulty (default).
        3pl : accepted, but the guessing parameter is not estimated; 2PL
              results are shown with a note.

    Rows with any missing item are dropped (listwise). Returns a formatted
    report, or a short message when the request cannot be analysed (unknown
    model, missing columns, fewer than 10 complete rows, non-binary items,
    or items whose responses never vary).

    Examples:
        irt q1 q2 q3 q4 q5
        irt q1 q2 q3 --model=1pl
        irt q1 q2 q3 q4 q5 --model=2pl
    """
    import numpy as np

    ca = CommandArgs(args)
    items = [p for p in ca.positional if not p.startswith("--")]
    model = ca.options.get("model", "2pl").lower()

    if not items:
        return "Usage: irt <item1> [item2 ...] [--model=1pl|2pl|3pl]"
    if model not in ("1pl", "2pl", "3pl"):
        return f"Unknown model '{model}'. Choose from: 1pl, 2pl, 3pl"

    df = session.require_data()

    missing = [c for c in items if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    try:
        sub = df.select(items).drop_nulls()
        data = sub.to_numpy().astype(float)
        n_persons, n_items = data.shape

        if n_persons < 10:
            return "IRT requires at least 10 complete observations."

        # Check all items are binary (0/1)
        for j, item in enumerate(items):
            vals = np.unique(data[:, j])
            non_binary = [v for v in vals if v not in (0.0, 1.0)]
            if non_binary:
                return (
                    f"Column '{item}' contains non-binary values {non_binary[:3]}. "
                    "IRT expects binary (0/1) item responses."
                )

        # An item answered identically by everyone has no identifiable
        # parameters: logit(0) / logit(1) are infinite and a logistic fit
        # sees only one class.
        constant_mask = data.min(axis=0) == data.max(axis=0)
        constant = [item for item, flag in zip(items, constant_mask) if flag]
        if constant:
            return (
                f"Item(s) with no variance (all 0 or all 1): {', '.join(constant)}. "
                "Remove them before fitting IRT."
            )

        try:
            from scipy.optimize import minimize
            _has_scipy = True
        except ImportError:
            _has_scipy = False

        # Ability proxy: standardised sum score (epsilon guards zero spread).
        raw_scores = data.sum(axis=1)
        theta = (raw_scores - raw_scores.mean()) / (raw_scores.std() + 1e-12)

        def _2pl_loglik(params, y, theta_vals):
            """Negative 2PL log-likelihood; a <= 0 is rejected by a hard penalty."""
            a, b = params
            if a <= 0:
                return 1e9
            p = 1.0 / (1.0 + np.exp(-a * (theta_vals - b)))
            p = np.clip(p, 1e-9, 1 - 1e-9)
            return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

        def _1pl_loglik(params, y, theta_vals):
            """Negative 1PL log-likelihood with discrimination fixed at 1."""
            b = params[0]
            p = 1.0 / (1.0 + np.exp(-(theta_vals - b)))
            p = np.clip(p, 1e-9, 1 - 1e-9)
            return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

        item_params = []

        for j, item in enumerate(items):
            y_j = data[:, j]
            p_bar = float(y_j.mean())
            # Starting difficulty is -logit(p): easy items (high p) have low b.
            # p_bar lies strictly inside (0, 1) thanks to the variance check.
            b0 = -float(np.log(p_bar / (1.0 - p_bar)))

            if _has_scipy:
                if model == "1pl":
                    res = minimize(_1pl_loglik, x0=[b0], args=(y_j, theta),
                                   method="Nelder-Mead",
                                   options={"maxiter": 2000, "xatol": 1e-5})
                    a_hat = 1.0
                    b_hat = float(res.x[0])
                else:
                    # 2PL (also used for 3pl, see note in output)
                    res = minimize(_2pl_loglik, x0=[1.0, b0], args=(y_j, theta),
                                   method="Nelder-Mead",
                                   options={"maxiter": 3000, "xatol": 1e-5, "fatol": 1e-5})
                    a_hat = max(float(res.x[0]), 0.01)
                    b_hat = float(res.x[1])
            else:
                # Fallback: logistic regression approximation
                try:
                    from sklearn.linear_model import LogisticRegression
                    lr = LogisticRegression(max_iter=500, C=1e6)
                    lr.fit(theta.reshape(-1, 1), y_j.astype(int))
                    a_hat = float(lr.coef_[0][0])
                    b_hat = -float(lr.intercept_[0]) / (a_hat + 1e-12)
                    if model == "1pl":
                        a_hat = 1.0
                except ImportError:
                    # Last resort: method of moments
                    a_hat = 1.0
                    b_hat = b0

            # Item information at theta = b: I = a^2 * P * (1 - P) with P = 0.5.
            info = a_hat ** 2 * 0.25

            item_params.append((item, a_hat, b_hat, p_bar, info))

        # ------------------------------------------------------------------
        # Build output
        # ------------------------------------------------------------------
        model_label = model.upper()
        rows = [
            (item, f"{a:.4f}", f"{b:.4f}", f"{p_obs:.4f}", f"{info:.4f}")
            for item, a, b, p_obs, info in item_params
        ]

        header_line = (
            f"\nIRT {model_label} — {n_items} items, {n_persons} persons\n"
            + _sep()
        )

        tbl = _coef_table(
            rows,
            ["Item", "Discrim (a)", "Difficulty (b)", "P(correct)", "Info@b"],
        )

        # Test information function: sum of item information over a theta grid
        theta_grid = np.linspace(-3, 3, 61)
        tif = np.zeros_like(theta_grid)
        for _, a, b, _, _ in item_params:
            p = 1.0 / (1.0 + np.exp(-a * (theta_grid - b)))
            tif += a ** 2 * p * (1 - p)

        tif_max_idx = int(np.argmax(tif))
        tif_max_theta = float(theta_grid[tif_max_idx])
        tif_max_val = float(tif[tif_max_idx])

        reliability_approx = float(np.mean(tif) / (np.mean(tif) + 1.0))

        summary_lines = [
            "",
            "Test Information Summary:",
            f" Peak information : {tif_max_val:.4f} at theta = {tif_max_theta:.2f}",
            f" Mean information : {np.mean(tif):.4f}",
            f" Marginal reliability (approx) : {reliability_approx:.4f}",
            "",
            "Note: Ability (theta) estimated from standardised sum score.",
        ]
        if not _has_scipy:
            summary_lines.append(
                "Note: scipy not found; used logistic regression / moments approximation."
            )
        if model == "3pl":
            summary_lines.append(
                "Note: 3PL guessing parameter not estimated; showing 2PL results."
            )

        return header_line + "\n" + tbl + "\n" + _sep() + "\n" + "\n".join(summary_lines)

    except Exception as e:
        return friendly_error(e, "irt")
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# ---------------------------------------------------------------------------
|
|
236
|
+
# 2. Competing Risks Regression
|
|
237
|
+
# ---------------------------------------------------------------------------
|
|
238
|
+
|
|
239
|
+
@command("competing", usage="competing <time> <event> <cause> [covars...]")
def cmd_competing(session: Session, args: str) -> str:
    """Competing risks: cumulative incidence curves per cause.

    Expects a time column, an event indicator (1 = event observed, anything
    else = censored) and a cause column saying which event occurred. Only
    cause values that appear on rows with event == 1 are treated as causes.

    Cumulative incidence functions (CIF) are estimated with lifelines'
    AalenJohansenFitter when lifelines is installed, otherwise with a manual
    Aalen-Johansen computation (events ordered before censorings at tied
    times).

    Optional covariates are screened with a Pearson correlation against each
    cause's event indicator. This is a descriptive association only; no
    Fine-Gray subdistribution hazard model is fitted.

    A CIF plot is written to ``session.output_dir / "competing_risks_cif.png"``
    and its path appended to ``session.plot_paths``.

    Examples:
        competing time status cause
        competing time status cause age gender
    """
    import numpy as np

    ca = CommandArgs(args)
    pos = [p for p in ca.positional if not p.startswith("--")]

    if len(pos) < 3:
        return "Usage: competing <time> <event> <cause> [covars...]"

    time_col, event_col, cause_col = pos[0], pos[1], pos[2]
    covars = pos[3:]

    df = session.require_data()

    needed = [time_col, event_col, cause_col] + covars
    missing = [c for c in needed if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    try:
        sub = df.select(needed).drop_nulls()
        T = sub[time_col].to_numpy().astype(float)
        E = sub[event_col].to_numpy().astype(float)
        C = sub[cause_col].to_numpy()

        observed = E == 1
        # Cause codes carried by censored rows are not competing events.
        causes = sorted(set(C[observed].tolist()))
        n_total = len(T)

        lines = [
            f"\nCompeting Risks Analysis",
            _sep(),
            f" Time var : {time_col}",
            f" Event var : {event_col}",
            f" Cause var : {cause_col}",
            f" N : {n_total}",
            f" Causes : {causes}",
            "",
        ]

        if not causes:
            lines.append(" No observed events (event == 1); nothing to estimate.")
            lines.append(_sep())
            return "\n".join(lines)

        # Integer event-type codes: 0 = censored, k + 1 = k-th entry of causes.
        # This is the encoding lifelines' AalenJohansenFitter expects, and it
        # also gives an unambiguous per-cause event indicator below.
        event_codes = np.zeros(n_total, dtype=int)
        for k, cause in enumerate(causes):
            event_codes[observed & (C == cause)] = k + 1
        n_events = {
            cause: int((event_codes == k + 1).sum())
            for k, cause in enumerate(causes)
        }

        # ------------------------------------------------------------------
        # Cumulative incidence per cause
        # ------------------------------------------------------------------
        try:
            from lifelines import AalenJohansenFitter
            _has_lifelines = True
        except ImportError:
            _has_lifelines = False

        # cause -> (event times, CIF values), used for plotting
        cif_results = {}

        if _has_lifelines:
            t_max = float(T.max())
            for k, cause in enumerate(causes):
                try:
                    ajf = AalenJohansenFitter(calculate_variance=True)
                    ajf.fit(T, event_codes, event_of_interest=k + 1)
                    cif_df = ajf.cumulative_density_
                    cif_at_max = float(cif_df.iloc[-1, 0])
                    cif_results[cause] = (
                        cif_df.index.to_numpy(),
                        cif_df.iloc[:, 0].to_numpy(),
                    )
                    lines.append(
                        f" Cause {cause}: {n_events[cause]} events, "
                        f"CIF at t={t_max:.1f} = {cif_at_max:.4f}"
                    )
                except Exception as fit_err:
                    lines.append(f" Cause {cause}: CIF fit failed — {fit_err}")
        else:
            # Manual Aalen-Johansen CIF estimate
            lines.append(" lifelines not found; computing manual CIF estimates.")
            # Sort by time, events before censorings at tied times (standard
            # convention: a subject censored at t was still at risk at t).
            order = np.lexsort((-E, T))
            T_s, E_s, C_s = T[order], E[order], C[order]

            for cause in causes:
                cif_times = [0.0]
                cif_vals = [0.0]
                surv = 1.0  # overall event-free survival just before current time
                n_r = n_total
                for t_i, e_i, c_i in zip(T_s, E_s, C_s):
                    if e_i == 1:
                        hit = 1.0 if c_i == cause else 0.0
                        cif_vals.append(float(cif_vals[-1] + surv * hit / n_r))
                        surv *= 1.0 - 1.0 / n_r
                        cif_times.append(float(t_i))
                    # Every subject (event or censored) leaves the risk set.
                    n_r = max(n_r - 1, 1)

                lines.append(
                    f" Cause {cause}: {n_events[cause]} events, "
                    f"CIF at t_max = {cif_vals[-1]:.4f}"
                )
                cif_results[cause] = (cif_times, cif_vals)

        # ------------------------------------------------------------------
        # Covariate association (descriptive screening only)
        # ------------------------------------------------------------------
        if covars:
            lines.append("")
            lines.append(
                "Covariate Association (Pearson r with cause-specific event indicator):"
            )
            try:
                from scipy.stats import pearsonr
                X_cov = sub.select(covars).to_numpy().astype(float)
                for k, cause in enumerate(causes):
                    indicator = (event_codes == k + 1).astype(float)
                    lines.append(f" Cause {cause}:")
                    if X_cov.shape[0] > 2:
                        for j, cv in enumerate(covars):
                            r, p = pearsonr(X_cov[:, j], indicator)
                            lines.append(
                                f" {cv:<20} r={r:.4f} p={p:.4f}{_sig_stars(p)}"
                            )
            except ImportError:
                lines.append(" scipy not found; skipping covariate associations.")
            except Exception as cov_err:
                lines.append(f" Covariate analysis error: {cov_err}")

        # ------------------------------------------------------------------
        # Plot CIF curves
        # ------------------------------------------------------------------
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots(figsize=(8, 5))
            try:
                colors = plt.cm.tab10.colors
                for i, (cause, (ct, cy)) in enumerate(cif_results.items()):
                    ax.step(
                        ct,
                        cy,
                        where="post",
                        label=f"Cause {cause}",
                        color=colors[i % len(colors)],
                        linewidth=2,
                    )

                ax.set_xlabel(f"Time ({time_col})")
                ax.set_ylabel("Cumulative Incidence")
                ax.set_title("Cumulative Incidence Functions")
                if cif_results:
                    ax.legend()
                ax.set_ylim(0, 1)
                ax.grid(alpha=0.3)
                fig.tight_layout()

                session.output_dir.mkdir(parents=True, exist_ok=True)
                plot_path = session.output_dir / "competing_risks_cif.png"
                fig.savefig(plot_path, dpi=150)
            finally:
                # Release the figure even if drawing or saving fails.
                plt.close(fig)
            session.plot_paths.append(str(plot_path))
            lines.append(f"\nPlot saved: {plot_path}")

        except Exception as plot_err:
            lines.append(f"\nPlot error: {plot_err}")

        lines.append(_sep())
        return "\n".join(lines)

    except Exception as e:
        return friendly_error(e, "competing")
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
# ---------------------------------------------------------------------------
|
|
436
|
+
# 3. CATE — Conditional Average Treatment Effects
|
|
437
|
+
# ---------------------------------------------------------------------------
|
|
438
|
+
|
|
439
|
+
# Supported meta-learners and their report labels.
_CATE_METHODS = {
    "tlearner": "T-Learner (separate outcome models)",
    "xlearner": "X-Learner (cross-fitted imputed outcomes)",
    "drlearner": "DR-Learner (doubly robust pseudo-outcomes)",
}


def _fit_cate(X, Y, T, method: str):
    """Fit one meta-learner and return per-row CATE estimates for ``X``.

    X is a 2-D float array, Y the outcome and T a 0/1 treatment array of the
    same length. Outcome models are ``Ridge(alpha=1.0)``; propensity models
    are ``LogisticRegression(max_iter=500, C=1.0)`` with scores clipped to
    [0.01, 0.99]. Raises ValueError for an unknown method (and propagates
    sklearn's ValueError, e.g. when a treatment arm is empty).
    """
    import numpy as np
    from sklearn.linear_model import LogisticRegression, Ridge

    idx_t = T == 1
    idx_c = T == 0

    mu1_model = Ridge(alpha=1.0).fit(X[idx_t], Y[idx_t])
    mu0_model = Ridge(alpha=1.0).fit(X[idx_c], Y[idx_c])

    if method == "tlearner":
        # CATE = mu1(x) - mu0(x)
        return mu1_model.predict(X) - mu0_model.predict(X)

    ps_model = LogisticRegression(max_iter=500, C=1.0).fit(X, T.astype(int))
    e_hat = np.clip(ps_model.predict_proba(X)[:, 1], 0.01, 0.99)

    if method == "xlearner":
        # Imputed individual effects in each arm, then effect models per arm.
        D1 = Y[idx_t] - mu0_model.predict(X[idx_t])  # treated: Y1 - mu0(X)
        D0 = mu1_model.predict(X[idx_c]) - Y[idx_c]  # control: mu1(X) - Y0
        tau1_hat = Ridge(alpha=1.0).fit(X[idx_t], D1).predict(X)
        tau0_hat = Ridge(alpha=1.0).fit(X[idx_c], D0).predict(X)
        # Propensity-weighted combination (weight g(x) = e(x)).
        return e_hat * tau0_hat + (1 - e_hat) * tau1_hat

    if method == "drlearner":
        mu1_hat = mu1_model.predict(X)
        mu0_hat = mu0_model.predict(X)
        # Doubly robust (AIPW) pseudo-outcome, then a second-stage regression.
        psi = (
            (T * (Y - mu1_hat)) / e_hat
            - ((1 - T) * (Y - mu0_hat)) / (1 - e_hat)
            + mu1_hat - mu0_hat
        )
        return Ridge(alpha=1.0).fit(X, psi).predict(X)

    raise ValueError(f"Unknown CATE method: {method}")


@command(
    "cate",
    usage="cate <y> <treat> <x1> [x2 ...] [--method=tlearner|xlearner|drlearner] [--reps=N]",
)
def cmd_cate(session: Session, args: str) -> str:
    """Conditional Average Treatment Effects via meta-learners.

    Implements T-Learner, X-Learner, and DR-Learner using sklearn (Ridge
    outcome models, logistic propensity model).

    Methods:
        tlearner : Fit E[Y|T=1,X] and E[Y|T=0,X] separately; CATE = mu1 - mu0.
        xlearner : Cross-fitting with imputed potential outcomes.
        drlearner : Doubly robust estimation combining outcome and propensity models.

    Options:
        --method=... : meta-learner (default tlearner).
        --reps=N     : bootstrap replicates for the ATE standard error
                       (default 200). Each replicate resamples rows and
                       refits all models, so the SE reflects estimation
                       uncertainty, not just the spread of fixed predictions.

    The treatment column must be coded 0/1. A two-panel plot is written to
    ``session.output_dir / "cate.png"`` and its path appended to
    ``session.plot_paths``.

    Examples:
        cate outcome treatment age educ income --method=tlearner
        cate outcome treatment age educ --method=xlearner
        cate outcome treatment age educ income --method=drlearner
    """
    import numpy as np

    ca = CommandArgs(args)
    pos = [p for p in ca.positional if not p.startswith("--")]
    method = ca.options.get("method", "tlearner").lower()

    if len(pos) < 3:
        return "Usage: cate <y> <treat> <x1> [x2 ...] [--method=tlearner|xlearner|drlearner]"
    if method not in _CATE_METHODS:
        return (
            f"Unknown method '{method}'. "
            "Choose from: tlearner, xlearner, drlearner"
        )

    y_col = pos[0]
    treat_col = pos[1]
    x_cols = pos[2:]

    df = session.require_data()

    needed = [y_col, treat_col] + x_cols
    missing = [c for c in needed if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    try:
        try:
            import sklearn  # noqa: F401  (availability check; used in _fit_cate)
        except ImportError:
            return "sklearn not installed. Run: pip install scikit-learn"

        n_reps = max(int(ca.options.get("reps", 200)), 0)

        sub = df.select(needed).drop_nulls()
        Y = sub[y_col].to_numpy().astype(float)
        T = sub[treat_col].to_numpy().astype(float)
        X = sub.select(x_cols).to_numpy().astype(float)

        # Other codes would be ignored by the outcome models but still leak
        # into the propensity model and the ATE.
        if not np.isin(T, (0.0, 1.0)).all():
            return f"Treatment column '{treat_col}' must be coded 0/1."

        n = len(Y)
        idx_t = T == 1
        idx_c = T == 0
        n_treated = int(idx_t.sum())
        n_control = int(idx_c.sum())

        if n_treated < 5 or n_control < 5:
            return "Need at least 5 treated and 5 control observations."

        cate_estimates = _fit_cate(X, Y, T, method)
        method_desc = _CATE_METHODS[method]

        # ------------------------------------------------------------------
        # Summary statistics
        # ------------------------------------------------------------------
        ate = float(np.mean(cate_estimates))
        att = float(np.mean(cate_estimates[idx_t]))
        atc = float(np.mean(cate_estimates[idx_c]))
        cate_std = float(np.std(cate_estimates))
        cate_min = float(np.min(cate_estimates))
        cate_max = float(np.max(cate_estimates))
        q25, q50, q75 = np.percentile(cate_estimates, [25, 50, 75])

        # Nonparametric bootstrap: resample rows and refit the learners so
        # the SE includes model-estimation uncertainty. Replicates lacking
        # at least two units per arm (or whose fit fails) are skipped.
        rng = np.random.default_rng(42)
        boot_ate = []
        for _ in range(n_reps):
            idx_b = rng.integers(0, n, size=n)
            T_b = T[idx_b]
            n_t_b = int(T_b.sum())
            if n_t_b < 2 or n - n_t_b < 2:
                continue
            try:
                tau_b = _fit_cate(X[idx_b], Y[idx_b], T_b, method)
            except ValueError:
                continue
            boot_ate.append(float(np.mean(tau_b)))
        ate_se = float(np.std(boot_ate, ddof=1)) if len(boot_ate) > 1 else float("nan")
        ate_ci_lo = ate - 1.96 * ate_se
        ate_ci_hi = ate + 1.96 * ate_se

        lines = [
            f"\nConditional Average Treatment Effects (CATE)",
            _sep(),
            f" Method : {method_desc}",
            f" Outcome : {y_col}",
            f" Treatment : {treat_col}",
            f" Covariates : {', '.join(x_cols)}",
            f" N total : {n} (treated={n_treated}, control={n_control})",
            "",
            "CATE Distribution:",
            f" {'ATE (avg treatment effect)':<35} {ate:>10.4f}",
            f" {'ATT (avg on treated)':<35} {att:>10.4f}",
            f" {'ATC (avg on controls)':<35} {atc:>10.4f}",
            f" {'SD of individual CATE':<35} {cate_std:>10.4f}",
            f" {'Min':<35} {cate_min:>10.4f}",
            f" {'Q25':<35} {q25:>10.4f}",
            f" {'Median':<35} {q50:>10.4f}",
            f" {'Q75':<35} {q75:>10.4f}",
            f" {'Max':<35} {cate_max:>10.4f}",
            "",
            f"ATE Inference (bootstrap refit, {len(boot_ate)} reps):",
            f" {'ATE':<15} {ate:>10.4f}",
            f" {'Bootstrap SE':<15} {ate_se:>10.4f}",
            f" {'95% CI':<15} [{ate_ci_lo:.4f}, {ate_ci_hi:.4f}]",
        ]

        # The SD of fitted CATEs is descriptive; comparing it to the ATE SE
        # is not a valid heterogeneity test, so no such claim is made.
        lines.append(
            "\n Note: CATE SD describes the spread of model-based individual "
            "estimates; it is not a formal test of effect heterogeneity."
        )

        # ------------------------------------------------------------------
        # Plot CATE distribution
        # ------------------------------------------------------------------
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig, axes = plt.subplots(1, 2, figsize=(11, 4))
            try:
                # Histogram of CATE
                axes[0].hist(cate_estimates, bins=30, color="#4C72B0", alpha=0.75,
                             edgecolor="white")
                axes[0].axvline(ate, color="red", linestyle="--", linewidth=1.5,
                                label=f"ATE={ate:.3f}")
                axes[0].axvline(0, color="black", linestyle=":", linewidth=1.0,
                                alpha=0.6)
                axes[0].set_xlabel("CATE")
                axes[0].set_ylabel("Frequency")
                axes[0].set_title(f"CATE Distribution ({method.upper()})")
                axes[0].legend()

                # CATE vs first covariate (sorted)
                sort_idx = np.argsort(X[:, 0])
                axes[1].scatter(X[sort_idx, 0], cate_estimates[sort_idx],
                                alpha=0.4, s=20, color="#4C72B0")
                axes[1].axhline(ate, color="red", linestyle="--", linewidth=1.5,
                                label=f"ATE={ate:.3f}")
                axes[1].axhline(0, color="black", linestyle=":", linewidth=1.0, alpha=0.6)
                axes[1].set_xlabel(x_cols[0])
                axes[1].set_ylabel("CATE")
                axes[1].set_title(f"CATE vs {x_cols[0]}")
                axes[1].legend()

                fig.tight_layout()
                session.output_dir.mkdir(parents=True, exist_ok=True)
                plot_path = session.output_dir / "cate.png"
                fig.savefig(plot_path, dpi=150)
            finally:
                # Release the figure even if drawing or saving fails.
                plt.close(fig)
            session.plot_paths.append(str(plot_path))
            lines.append(f"\nPlot saved: {plot_path}")

        except Exception as plot_err:
            lines.append(f"\nPlot error: {plot_err}")

        lines.append(_sep())
        return "\n".join(lines)

    except Exception as e:
        return friendly_error(e, "cate")
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
# ---------------------------------------------------------------------------
|
|
686
|
+
# 4. Joinpoint Trend Analysis
|
|
687
|
+
# ---------------------------------------------------------------------------
|
|
688
|
+
|
|
689
|
+
@command("joinpoint", usage="joinpoint <y> <x> [--max_points=3] [--permutations=100]")
def cmd_joinpoint(session: Session, args: str) -> str:
    """Joinpoint (piecewise linear) trend analysis with BIC-based model selection.

    Fits continuous piecewise-linear models
    ``y = b0 + b1*x + sum_j d_j * max(x - jp_j, 0)`` with 0..max_points
    joinpoints. Joinpoint locations are searched over a grid of interior X
    values, and the model with the lowest BIC is kept.

    Reports the selected joinpoints and, for each segment, its slope and a
    percent change per unit of X relative to the segment's mean Y. When at
    least one joinpoint is selected, it also reports a permutation p-value
    comparing the selected model against a single linear trend.

    Options:
        --max_points=N     Maximum number of joinpoints considered (default 3).
        --permutations=N   Permutations for the p-value (default 100; 0 disables).

    Examples:
        joinpoint cancer_rate year
        joinpoint incidence year --max_points=3
        joinpoint rate year --max_points=2 --permutations=200
    """
    import itertools
    import math

    import numpy as np

    usage = "Usage: joinpoint <y> <x> [--max_points=3] [--permutations=100]"
    ca = CommandArgs(args)
    pos = [p for p in ca.positional if not p.startswith("--")]

    if len(pos) < 2:
        return usage

    y_col, x_col = pos[0], pos[1]
    try:
        max_jp = int(ca.options.get("max_points", 3))
        n_perm = int(ca.options.get("permutations", 100))
    except (TypeError, ValueError):
        return "--max_points and --permutations must be integers.\n" + usage

    df = session.require_data()

    if y_col not in df.columns:
        return f"Column not found: {y_col}"
    if x_col not in df.columns:
        return f"Column not found: {x_col}"

    # Cap on the number of joinpoint combinations fitted per model size, so the
    # exhaustive grid search stays tractable when --max_points is large.
    max_fits_per_size = 50_000

    try:
        sub = df.select([y_col, x_col]).drop_nulls()
        X = sub[x_col].to_numpy().astype(float)
        Y = sub[y_col].to_numpy().astype(float)
        n = len(X)

        if n < 6:
            return "Joinpoint analysis requires at least 6 data points."

        order = np.argsort(X)
        X = X[order]
        Y = Y[order]

        # ------------------------------------------------------------------
        # Piecewise linear model helpers
        # ------------------------------------------------------------------
        def _piecewise_design(x_vals, joinpoints):
            """Build design matrix: [1, x, (x-jp1)+, (x-jp2)+, ...]."""
            cols = [np.ones(len(x_vals)), x_vals]
            for jp in joinpoints:
                cols.append(np.maximum(x_vals - jp, 0.0))
            return np.column_stack(cols)

        def _fit_piecewise(x_vals, y_vals, joinpoints):
            """OLS fit of a piecewise linear model; return (params, rss, bic).

            BIC = n * ln(RSS/n) + p * ln(n), with p = 2 + 2k for k joinpoints.
            The 2 covers the intercept and base slope. Each joinpoint adds two
            parameters: its slope change and its estimated location, matching
            the NCI Joinpoint BIC. Returns (None, inf, inf) if the
            least-squares solve fails.
            """
            A = _piecewise_design(x_vals, joinpoints)
            try:
                params = np.linalg.lstsq(A, y_vals, rcond=None)[0]
            except np.linalg.LinAlgError:
                return None, np.inf, np.inf
            rss = float(np.sum((y_vals - A @ params) ** 2))
            nn = len(y_vals)
            n_params = A.shape[1] + len(joinpoints)
            bic = nn * np.log(max(rss / nn, 1e-15)) + n_params * np.log(nn)
            return params, rss, float(bic)

        # ------------------------------------------------------------------
        # Grid search over candidate joinpoints
        # ------------------------------------------------------------------
        # Candidates are the unique X values left after dropping a margin of
        # max(2, 15% of N) sorted observations at each end.
        margin = max(2, int(0.15 * n))
        candidate_x = np.unique(X[margin: n - margin])

        # 0 joinpoints (simple linear trend) is the baseline model.
        params0, rss0, bic0 = _fit_piecewise(X, Y, [])
        best_bic, best_jps, best_params, best_n_jp = bic0, [], params0, 0

        for n_jp in range(1, max_jp + 1):
            if len(candidate_x) < n_jp:
                break

            # Subsample candidates for speed (about 30 per slot).
            step = max(1, len(candidate_x) // 30)
            sampled = candidate_x[::step]

            # For large n_jp, thin the grid evenly until the number of
            # combinations fits the budget.
            m = len(sampled)
            while m > n_jp and math.comb(m, n_jp) > max_fits_per_size:
                m -= 1
            if m < len(sampled):
                keep = np.unique(np.linspace(0, len(sampled) - 1, m).round().astype(int))
                sampled = sampled[keep]

            # `sampled` is sorted, so every combination is in ascending order.
            # The fitted delta coefficients therefore line up with the sorted
            # joinpoints.
            for combo in itertools.combinations(sampled, n_jp):
                jps = list(combo)
                p, _, b = _fit_piecewise(X, Y, jps)
                if b < best_bic:
                    best_bic, best_jps, best_params, best_n_jp = b, jps, p, n_jp

        # ------------------------------------------------------------------
        # Extract segment slopes
        # ------------------------------------------------------------------
        # params = [intercept, slope, delta1, delta2, ...]
        # Slope in segment i = slope + sum(delta_j for j <= i).
        # Plain floats, so they print cleanly under NumPy 2.
        breakpoints = sorted(float(jp) for jp in best_jps)
        seg_bounds = [float(X[0])] + breakpoints + [float(X[-1])]
        n_segments = len(seg_bounds) - 1

        segments = []
        if best_params is not None:
            seg_slope = float(best_params[1])
            for seg_i, (x_lo, x_hi) in enumerate(zip(seg_bounds[:-1], seg_bounds[1:])):
                if seg_i > 0:
                    seg_slope += float(best_params[seg_i + 1])
                # Half-open intervals count every observation exactly once.
                # The last segment also includes its upper bound.
                if seg_i == n_segments - 1:
                    mask = (X >= x_lo) & (X <= x_hi)
                else:
                    mask = (X >= x_lo) & (X < x_hi)
                n_pts = int(mask.sum())
                # Percent change per unit X, relative to the segment's mean Y.
                if n_pts:
                    apc = 100.0 * seg_slope / (float(Y[mask].mean()) + 1e-12)
                else:
                    apc = float("nan")
                segments.append({
                    "seg": seg_i + 1,
                    "from": x_lo,
                    "to": x_hi,
                    "slope": seg_slope,
                    "apc": apc,
                    "n_pts": n_pts,
                })

        # ------------------------------------------------------------------
        # Permutation test: selected model vs single linear trend
        # ------------------------------------------------------------------
        perm_p = None
        if n_perm > 0 and best_n_jp > 0 and params0 is not None:
            _, rss_alt, _ = _fit_piecewise(X, Y, best_jps)
            obs_stat = rss0 / (rss_alt + 1e-15)

            # Build each permuted dataset from the linear (H0) fitted values
            # plus permuted H0 residuals, so the permuted data satisfy H0.
            # NOTE: joinpoint locations stay fixed at the selected values
            # instead of being re-searched for each permutation. The p-value
            # therefore does not account for the location search and is
            # likely anti-conservative.
            fitted0 = _piecewise_design(X, []) @ params0
            resid0 = Y - fitted0

            rng = np.random.default_rng(42)
            n_exceed = 0
            for _ in range(n_perm):
                Y_perm = fitted0 + rng.permutation(resid0)
                _, rss_null_p, _ = _fit_piecewise(X, Y_perm, [])
                _, rss_alt_p, _ = _fit_piecewise(X, Y_perm, best_jps)
                if rss_null_p / (rss_alt_p + 1e-15) >= obs_stat:
                    n_exceed += 1
            # Add-one correction: a Monte Carlo p-value is never exactly zero.
            perm_p = (n_exceed + 1) / (n_perm + 1)

        # ------------------------------------------------------------------
        # Format output
        # ------------------------------------------------------------------
        lines = [
            "\nJoinpoint Trend Analysis",
            _sep(),
            f" Y variable : {y_col}",
            f" X variable : {x_col}",
            f" N points : {n}",
            f" Max JP : {max_jp}",
            f" Best model : {best_n_jp} joinpoint(s) BIC = {best_bic:.4f}",
        ]
        if breakpoints:
            lines.append(f" Joinpoints : {[round(jp, 4) for jp in breakpoints]}")
        if perm_p is not None:
            lines.append(
                f" Permutation p-value ({n_perm} perms): {perm_p:.4f}{_sig_stars(perm_p)}"
            )

        lines.append("")
        lines.append("Trend Segments:")
        rows = [
            (
                str(seg["seg"]),
                f"{seg['from']:.2f}",
                f"{seg['to']:.2f}",
                f"{seg['slope']:.6f}",
                f"{seg['apc']:.2f}%",
                str(seg["n_pts"]),
            )
            for seg in segments
        ]
        lines.append(_coef_table(
            rows,
            ["Seg", "From", "To", "Slope", "APC", "N pts"],
        ))

        # ------------------------------------------------------------------
        # Plot
        # ------------------------------------------------------------------
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots(figsize=(9, 5))
            ax.scatter(X, Y, color="#4C72B0", alpha=0.7, s=40, zorder=3,
                       label="Observed")

            # Fitted piecewise line
            if best_params is not None:
                x_fine = np.linspace(X.min(), X.max(), 400)
                y_fine = _piecewise_design(x_fine, breakpoints) @ best_params
                ax.plot(x_fine, y_fine, color="red", linewidth=2.0,
                        label=f"Joinpoint fit ({best_n_jp} JP)")

            for jp in breakpoints:
                ax.axvline(jp, color="grey", linestyle="--", linewidth=1.0, alpha=0.7)

            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)
            ax.set_title(f"Joinpoint Trend: {y_col} vs {x_col}")
            ax.legend()
            ax.grid(alpha=0.3)
            fig.tight_layout()

            session.output_dir.mkdir(parents=True, exist_ok=True)
            plot_path = session.output_dir / "joinpoint.png"
            fig.savefig(plot_path, dpi=150)
            plt.close(fig)
            session.plot_paths.append(str(plot_path))
            lines.append(f"\nPlot saved: {plot_path}")

        except Exception as plot_err:
            lines.append(f"\nPlot error: {plot_err}")

        lines.append(_sep())
        return "\n".join(lines)

    except Exception as e:
        return friendly_error(e, "joinpoint")
|
|
957
|
+
|
|
958
|
+
|
|
959
|
+
# ---------------------------------------------------------------------------
|
|
960
|
+
# 5. Spline and LOESS Regression
|
|
961
|
+
# ---------------------------------------------------------------------------
|
|
962
|
+
|
|
963
|
+
@command("spline", usage="spline <y> <x1> [x2 ...] [--knots=N|<list>] [--type=natural|bs|loess] [--frac=F]")
def cmd_spline(session: Session, args: str) -> str:
    """Spline and LOESS smoothing regression.

    Types:
      natural : Natural cubic regression splines (patsy ``cr``) fitted by OLS.
      bs      : B-splines (patsy ``bs``) fitted by OLS.
      loess   : LOWESS smoothing of y on the first x only (no parametric assumptions).

    The --knots option accepts either of:
      * an integer N: N interior knots at equally spaced percentiles, computed
        separately for each x column;
      * a comma-separated list of knot positions (e.g. --knots=25,50,75),
        applied to every x column.

    For loess, --frac=F sets the bandwidth (default 0.3, clamped to [0.1, 1.0]).

    The shaded plot band is a 95% band for the fitted mean. Spline types use
    the OLS confidence interval; loess uses 100 bootstrap resamples.

    Examples:
        spline y x --knots=4 --type=natural
        spline y x --type=loess
        spline y x1 x2 --knots=3 --type=bs
        spline y x --knots=20,40,60 --type=natural
    """
    import numpy as np

    usage = "Usage: spline <y> <x1> [x2 ...] [--knots=N|<list>] [--type=natural|bs|loess] [--frac=F]"
    ca = CommandArgs(args)
    pos = [p for p in ca.positional if not p.startswith("--")]
    spline_type = ca.options.get("type", "natural").lower()
    knots_opt = ca.options.get("knots", "4")

    if len(pos) < 2:
        return usage

    y_col = pos[0]
    x_cols = pos[1:]

    df = session.require_data()

    needed = [y_col] + x_cols
    missing = [c for c in needed if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    if spline_type not in ("natural", "bs", "loess"):
        return (
            f"Unknown type '{spline_type}'. "
            "Choose from: natural, bs, loess"
        )

    try:
        sub = df.select(needed).drop_nulls()
        Y = sub[y_col].to_numpy().astype(float)
        n = len(Y)

        if n < 6:
            return "Spline regression requires at least 6 observations."

        # Parse --knots: an integer count, or a comma-separated list of positions.
        explicit_knots = None
        try:
            n_knots = int(knots_opt)
        except (TypeError, ValueError):
            try:
                explicit_knots = [float(k.strip()) for k in str(knots_opt).split(",")]
                n_knots = len(explicit_knots)
            except ValueError:
                n_knots = 4  # unparseable specification: fall back to the default
        if explicit_knots is None and n_knots < 0:
            return "--knots must be a non-negative integer or a list of positions."

        x_data = {xc: sub[xc].to_numpy().astype(float) for xc in x_cols}
        # The first x column drives LOESS, the plots and the plotted knots.
        x_primary = x_data[x_cols[0]]

        def _knots_for(values):
            """Return the interior knots for one predictor column.

            Explicit positions are used as given. Otherwise n_knots knots are
            placed at equally spaced percentiles of *values*. Duplicates and
            knots on the data boundary are dropped (they arise with tied or
            discrete data, and patsy can reject such knot sets).
            """
            if explicit_knots is not None:
                return list(explicit_knots)
            pct_step = 100.0 / (n_knots + 1)
            lo, hi = float(values.min()), float(values.max())
            quantiles = (
                float(np.percentile(values, pct_step * (i + 1))) for i in range(n_knots)
            )
            return sorted({q for q in quantiles if lo < q < hi})

        lines = [
            "\nSpline / LOESS Regression",
            _sep(),
            f" Type : {spline_type}",
            f" Y : {y_col}",
            f" X : {', '.join(x_cols)}",
            f" N : {n}",
        ]

        y_hat = None
        y_hat_lo = None
        y_hat_hi = None
        knot_positions = []

        # ------------------------------------------------------------------
        # Natural cubic / B-splines fitted by OLS
        # ------------------------------------------------------------------
        if spline_type in ("natural", "bs"):
            try:
                import statsmodels.api as sm
                from patsy import dmatrix
            except ImportError:
                return (
                    "statsmodels and patsy are required for spline regression. "
                    "Run: pip install statsmodels patsy"
                )

            knots_by_col = {xc: _knots_for(x_data[xc]) for xc in x_cols}
            knot_positions = knots_by_col[x_cols[0]]

            # Columns are passed to patsy under safe aliases (x0, x1, ...).
            # Arbitrary column names then neither break the formula nor get
            # evaluated as code. repr() keeps full knot precision.
            data_dict = {}
            spline_terms = []
            for i, xc in enumerate(x_cols):
                alias = f"x{i}"
                data_dict[alias] = x_data[xc]
                knot_str = ", ".join(repr(float(k)) for k in knots_by_col[xc])
                if spline_type == "natural":
                    # constraints='center' removes the constant spanned by the
                    # cr basis, which would be collinear with the intercept.
                    spline_terms.append(
                        f"cr({alias}, knots=[{knot_str}], constraints='center')"
                    )
                else:
                    spline_terms.append(
                        f"bs({alias}, knots=[{knot_str}], include_intercept=False)"
                    )
            formula_str = " + ".join(spline_terms)

            try:
                # dmatrix already adds an "Intercept" column; no extra constant.
                X_sm = np.asarray(dmatrix(formula_str, data_dict, return_type="matrix"))
                result = sm.OLS(Y, X_sm).fit()
                y_hat = np.asarray(result.fittedvalues)

                # mean_ci_* is the 95% confidence band for the fitted mean.
                # obs_ci_* would be the wider prediction interval for new points.
                pred_df = result.get_prediction(X_sm).summary_frame(alpha=0.05)
                y_hat_lo = pred_df["mean_ci_lower"].values
                y_hat_hi = pred_df["mean_ci_upper"].values

                rss = float(np.sum((Y - y_hat) ** 2))
                tss = float(np.sum((Y - Y.mean()) ** 2))
                r2 = 1.0 - rss / (tss + 1e-15)
                aic = float(result.aic)
                bic = float(result.bic)

                if len(x_cols) == 1:
                    lines.append(f" Knots : {[round(k, 4) for k in knot_positions]}")
                else:
                    lines += [
                        f" Knots ({xc}) : {[round(k, 4) for k in knots_by_col[xc]]}"
                        for xc in x_cols
                    ]
                lines += [
                    f" R-squared : {r2:.4f}",
                    f" AIC : {aic:.4f}",
                    f" BIC : {bic:.4f}",
                    f" N params : {result.df_model:.0f} (df resid={result.df_resid:.0f})",
                ]

            except Exception as sm_err:
                lines.append(f"\nSpline fit error: {sm_err}")

        # ------------------------------------------------------------------
        # LOESS / LOWESS
        # ------------------------------------------------------------------
        else:
            try:
                import statsmodels.api as sm
            except ImportError:
                return (
                    "statsmodels is required for LOESS. "
                    "Run: pip install statsmodels"
                )

            # LOWESS operates on a single predictor.
            frac = float(ca.options.get("frac", 0.3))
            frac = max(0.1, min(frac, 1.0))

            # lowess returns columns [x_sorted, y_smoothed]; interpolate back to
            # the original observation order to get fitted values.
            lowess_result = sm.nonparametric.lowess(Y, x_primary, frac=frac, it=3)
            y_hat = np.interp(x_primary, lowess_result[:, 0], lowess_result[:, 1])

            # Bootstrap confidence band (100 resamples).
            n_boot = 100
            rng = np.random.default_rng(42)
            boot_fits = []
            for _ in range(n_boot):
                idx_b = rng.integers(0, n, n)
                x_b = x_primary[idx_b]
                y_b = Y[idx_b]
                sort_b = np.argsort(x_b)
                try:
                    lw_b = sm.nonparametric.lowess(
                        y_b[sort_b], x_b[sort_b], frac=frac, it=1
                    )
                except Exception:
                    # Best effort: a degenerate resample is skipped (and the
                    # number of usable resamples is reported below).
                    continue
                boot_fits.append(np.interp(x_primary, lw_b[:, 0], lw_b[:, 1]))

            if boot_fits:
                boot_arr = np.array(boot_fits)
                y_hat_lo = np.nanpercentile(boot_arr, 2.5, axis=0)
                y_hat_hi = np.nanpercentile(boot_arr, 97.5, axis=0)

            rss = float(np.sum((Y - y_hat) ** 2))
            tss = float(np.sum((Y - Y.mean()) ** 2))
            r2 = 1.0 - rss / (tss + 1e-15)

            lines += [
                f" Bandwidth (frac) : {frac:.2f}",
                f" R-squared (approx): {r2:.4f}",
                " (LOESS uses single predictor; additional X ignored.)",
            ]
            if len(boot_fits) < n_boot:
                lines.append(
                    f" Bootstrap band : {len(boot_fits)}/{n_boot} resamples usable"
                )

        # ------------------------------------------------------------------
        # Residual summary
        # ------------------------------------------------------------------
        if y_hat is not None:
            resid = Y - y_hat
            lines += [
                "",
                "Residual Summary:",
                f" {'Mean residual':<30} {float(resid.mean()):>10.4f}",
                f" {'SD residual':<30} {float(resid.std()):>10.4f}",
                f" {'Min':<30} {float(resid.min()):>10.4f}",
                f" {'Max':<30} {float(resid.max()):>10.4f}",
                f" {'RMSE':<30} {float(np.sqrt(np.mean(resid**2))):>10.4f}",
            ]

        # ------------------------------------------------------------------
        # Plot: fitted curve with band, plus residuals vs fitted
        # ------------------------------------------------------------------
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig, axes = plt.subplots(1, 2, figsize=(12, 5))

            # Left: data + fitted + band
            sort_idx = np.argsort(x_primary)
            x_s = x_primary[sort_idx]

            axes[0].scatter(x_primary, Y, color="#4C72B0", alpha=0.5, s=20,
                            label="Observed", zorder=3)
            if y_hat is not None:
                axes[0].plot(x_s, y_hat[sort_idx], color="red", linewidth=2.0,
                             label="Fitted")
            if y_hat_lo is not None and y_hat_hi is not None:
                axes[0].fill_between(
                    x_s, y_hat_lo[sort_idx], y_hat_hi[sort_idx],
                    alpha=0.20, color="red", label="95% CI/band"
                )

            if spline_type in ("natural", "bs") and knot_positions:
                for kp in knot_positions:
                    axes[0].axvline(kp, color="grey", linestyle=":",
                                    linewidth=0.8, alpha=0.7)

            if spline_type == "loess":
                title = f"LOESS: {y_col} ~ {x_cols[0]}"
            else:
                title = f"{spline_type.capitalize()} Spline: {y_col} ~ {x_cols[0]}"
            axes[0].set_xlabel(x_cols[0])
            axes[0].set_ylabel(y_col)
            axes[0].set_title(title)
            axes[0].legend()
            axes[0].grid(alpha=0.3)

            # Right: residual plot
            if y_hat is not None:
                resid = Y - y_hat
                axes[1].scatter(y_hat, resid, alpha=0.5, s=20, color="#2ca02c")
                axes[1].axhline(0, color="black", linewidth=1.0, linestyle="--")
                axes[1].set_xlabel("Fitted values")
                axes[1].set_ylabel("Residuals")
                axes[1].set_title("Residuals vs Fitted")
                axes[1].grid(alpha=0.3)

            fig.tight_layout()
            session.output_dir.mkdir(parents=True, exist_ok=True)
            plot_path = session.output_dir / f"spline_{spline_type}.png"
            fig.savefig(plot_path, dpi=150)
            plt.close(fig)
            session.plot_paths.append(str(plot_path))
            lines.append(f"\nPlot saved: {plot_path}")

        except Exception as plot_err:
            lines.append(f"\nPlot error: {plot_err}")

        lines.append(_sep())
        return "\n".join(lines)

    except Exception as e:
        return friendly_error(e, "spline")
|