openstat-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. openstat/__init__.py +3 -0
  2. openstat/__main__.py +4 -0
  3. openstat/backends/__init__.py +16 -0
  4. openstat/backends/duckdb_backend.py +70 -0
  5. openstat/backends/polars_backend.py +52 -0
  6. openstat/cli.py +92 -0
  7. openstat/commands/__init__.py +82 -0
  8. openstat/commands/adv_stat_cmds.py +1255 -0
  9. openstat/commands/advanced_ml_cmds.py +576 -0
  10. openstat/commands/advreg_cmds.py +207 -0
  11. openstat/commands/alias_cmds.py +135 -0
  12. openstat/commands/arch_cmds.py +82 -0
  13. openstat/commands/arules_cmds.py +111 -0
  14. openstat/commands/automodel_cmds.py +212 -0
  15. openstat/commands/backend_cmds.py +82 -0
  16. openstat/commands/base.py +170 -0
  17. openstat/commands/bayes_cmds.py +71 -0
  18. openstat/commands/causal_cmds.py +269 -0
  19. openstat/commands/cluster_cmds.py +152 -0
  20. openstat/commands/data_cmds.py +996 -0
  21. openstat/commands/datamanip_cmds.py +672 -0
  22. openstat/commands/dataquality_cmds.py +174 -0
  23. openstat/commands/datetime_cmds.py +176 -0
  24. openstat/commands/dimreduce_cmds.py +184 -0
  25. openstat/commands/discrete_cmds.py +149 -0
  26. openstat/commands/dsl_cmds.py +143 -0
  27. openstat/commands/epi_cmds.py +93 -0
  28. openstat/commands/equiv_tobit_cmds.py +94 -0
  29. openstat/commands/esttab_cmds.py +196 -0
  30. openstat/commands/export_beamer_cmds.py +142 -0
  31. openstat/commands/export_cmds.py +201 -0
  32. openstat/commands/export_extra_cmds.py +240 -0
  33. openstat/commands/factor_cmds.py +180 -0
  34. openstat/commands/groupby_cmds.py +155 -0
  35. openstat/commands/help_cmds.py +237 -0
  36. openstat/commands/i18n_cmds.py +43 -0
  37. openstat/commands/import_extra_cmds.py +561 -0
  38. openstat/commands/influence_cmds.py +134 -0
  39. openstat/commands/iv_cmds.py +106 -0
  40. openstat/commands/manova_cmds.py +105 -0
  41. openstat/commands/mediate_cmds.py +233 -0
  42. openstat/commands/meta_cmds.py +284 -0
  43. openstat/commands/mi_cmds.py +228 -0
  44. openstat/commands/mixed_cmds.py +79 -0
  45. openstat/commands/mixture_changepoint_cmds.py +166 -0
  46. openstat/commands/ml_adv_cmds.py +147 -0
  47. openstat/commands/ml_cmds.py +178 -0
  48. openstat/commands/model_eval_cmds.py +142 -0
  49. openstat/commands/network_cmds.py +288 -0
  50. openstat/commands/nlquery_cmds.py +161 -0
  51. openstat/commands/nonparam_cmds.py +149 -0
  52. openstat/commands/outreg_cmds.py +247 -0
  53. openstat/commands/panel_cmds.py +141 -0
  54. openstat/commands/pdf_cmds.py +226 -0
  55. openstat/commands/pipeline_cmds.py +319 -0
  56. openstat/commands/plot_cmds.py +189 -0
  57. openstat/commands/plugin_cmds.py +79 -0
  58. openstat/commands/posthoc_cmds.py +153 -0
  59. openstat/commands/power_cmds.py +172 -0
  60. openstat/commands/profile_cmds.py +246 -0
  61. openstat/commands/rbridge_cmds.py +81 -0
  62. openstat/commands/regex_cmds.py +104 -0
  63. openstat/commands/report_cmds.py +48 -0
  64. openstat/commands/repro_cmds.py +129 -0
  65. openstat/commands/resampling_cmds.py +109 -0
  66. openstat/commands/reshape_cmds.py +223 -0
  67. openstat/commands/sem_cmds.py +177 -0
  68. openstat/commands/stat_cmds.py +1040 -0
  69. openstat/commands/stata_import_cmds.py +215 -0
  70. openstat/commands/string_cmds.py +124 -0
  71. openstat/commands/surv_cmds.py +145 -0
  72. openstat/commands/survey_cmds.py +153 -0
  73. openstat/commands/textanalysis_cmds.py +192 -0
  74. openstat/commands/ts_adv_cmds.py +136 -0
  75. openstat/commands/ts_cmds.py +195 -0
  76. openstat/commands/tui_cmds.py +111 -0
  77. openstat/commands/ux_cmds.py +191 -0
  78. openstat/commands/validate_cmds.py +270 -0
  79. openstat/commands/viz_adv_cmds.py +312 -0
  80. openstat/commands/viz_extra_cmds.py +251 -0
  81. openstat/commands/watch_cmds.py +69 -0
  82. openstat/config.py +106 -0
  83. openstat/dsl/__init__.py +0 -0
  84. openstat/dsl/parser.py +332 -0
  85. openstat/dsl/tokenizer.py +105 -0
  86. openstat/i18n.py +120 -0
  87. openstat/io/__init__.py +0 -0
  88. openstat/io/loader.py +187 -0
  89. openstat/jupyter/__init__.py +18 -0
  90. openstat/jupyter/display.py +18 -0
  91. openstat/jupyter/magic.py +60 -0
  92. openstat/logging_config.py +59 -0
  93. openstat/plots/__init__.py +0 -0
  94. openstat/plots/plotter.py +437 -0
  95. openstat/plots/surv_plots.py +32 -0
  96. openstat/plots/ts_plots.py +59 -0
  97. openstat/plugins/__init__.py +5 -0
  98. openstat/plugins/manager.py +69 -0
  99. openstat/repl.py +457 -0
  100. openstat/reporting/__init__.py +0 -0
  101. openstat/reporting/eda.py +208 -0
  102. openstat/reporting/report.py +67 -0
  103. openstat/script_runner.py +319 -0
  104. openstat/session.py +133 -0
  105. openstat/stats/__init__.py +0 -0
  106. openstat/stats/advanced_regression.py +269 -0
  107. openstat/stats/arch_garch.py +84 -0
  108. openstat/stats/bayesian.py +103 -0
  109. openstat/stats/causal.py +258 -0
  110. openstat/stats/clustering.py +206 -0
  111. openstat/stats/discrete.py +311 -0
  112. openstat/stats/epidemiology.py +119 -0
  113. openstat/stats/equiv_tobit.py +163 -0
  114. openstat/stats/factor.py +174 -0
  115. openstat/stats/imputation.py +282 -0
  116. openstat/stats/influence.py +78 -0
  117. openstat/stats/iv.py +131 -0
  118. openstat/stats/manova.py +124 -0
  119. openstat/stats/mixed.py +128 -0
  120. openstat/stats/ml.py +275 -0
  121. openstat/stats/ml_advanced.py +117 -0
  122. openstat/stats/model_eval.py +183 -0
  123. openstat/stats/models.py +1342 -0
  124. openstat/stats/nonparametric.py +130 -0
  125. openstat/stats/panel.py +179 -0
  126. openstat/stats/power.py +295 -0
  127. openstat/stats/resampling.py +203 -0
  128. openstat/stats/survey.py +213 -0
  129. openstat/stats/survival.py +196 -0
  130. openstat/stats/timeseries.py +142 -0
  131. openstat/stats/ts_advanced.py +114 -0
  132. openstat/types.py +11 -0
  133. openstat/web/__init__.py +1 -0
  134. openstat/web/app.py +117 -0
  135. openstat/web/session_manager.py +73 -0
  136. openstat/web/static/app.js +117 -0
  137. openstat/web/static/index.html +38 -0
  138. openstat/web/static/style.css +103 -0
  139. openstat_cli-1.0.0.dist-info/METADATA +748 -0
  140. openstat_cli-1.0.0.dist-info/RECORD +143 -0
  141. openstat_cli-1.0.0.dist-info/WHEEL +4 -0
  142. openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
  143. openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1255 @@
1
+ """Advanced statistical commands: irt, competing, cate, joinpoint, spline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from openstat.commands.base import command, CommandArgs, friendly_error
6
+ from openstat.session import Session
7
+
8
+
9
+ # ---------------------------------------------------------------------------
10
+ # Shared helpers
11
+ # ---------------------------------------------------------------------------
12
+
13
+ def _sep(width: int = 60) -> str:
14
+ return "=" * width
15
+
16
+
17
+ def _sig_stars(p: float) -> str:
18
+ if p < 0.001:
19
+ return "***"
20
+ if p < 0.01:
21
+ return "**"
22
+ if p < 0.05:
23
+ return "*"
24
+ if p < 0.10:
25
+ return "."
26
+ return ""
27
+
28
+
29
+ def _coef_table(rows: list[tuple], headers: list[str]) -> str:
30
+ """Render a simple fixed-width table.
31
+
32
+ rows : list of tuples whose elements are already formatted strings.
33
+ headers: list of column header strings.
34
+ """
35
+ col_widths = [max(len(h), max((len(str(r[i])) for r in rows), default=0))
36
+ for i, h in enumerate(headers)]
37
+ fmt = " " + " ".join(f"{{:<{w}}}" for w in col_widths)
38
+ lines = [fmt.format(*headers), " " + "-" * (sum(col_widths) + 2 * len(col_widths))]
39
+ for row in rows:
40
+ lines.append(fmt.format(*[str(x) for x in row]))
41
+ return "\n".join(lines)
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # 1. IRT — Item Response Theory
46
+ # ---------------------------------------------------------------------------
47
+
48
@command("irt", usage="irt <item1> [item2 ...] [--model=1pl|2pl|3pl]")
def cmd_irt(session: Session, args: str) -> str:
    """Item Response Theory: estimate discrimination and difficulty parameters.

    Every item is fitted on its own as a logistic curve in an ability proxy
    (the standardised sum score over the selected items).  With scipy the
    1PL/2PL negative log-likelihood is minimised by Nelder-Mead; without
    scipy a sklearn logistic regression is used, and without sklearn a
    moment-based difficulty with a = 1.  ``--model=3pl`` is accepted, but no
    guessing parameter is estimated: the 2PL fit is shown with a note.

    Args:
        session: active session; ``require_data()`` supplies the data frame
            (presumably a polars DataFrame - confirm against Session).
        args: raw argument string: item column names plus optional ``--model``.

    Returns:
        The formatted result table and test-information summary, or a short
        message when the input is unusable.  Exceptions raised while fitting
        are turned into text via ``friendly_error``.

    Examples:
        irt q1 q2 q3 q4 q5
        irt q1 q2 q3 --model=1pl
        irt q1 q2 q3 q4 q5 --model=2pl
    """
    import numpy as np

    ca = CommandArgs(args)
    items = [p for p in ca.positional if not p.startswith("--")]
    model = ca.options.get("model", "2pl").lower()

    if not items:
        return "Usage: irt <item1> [item2 ...] [--model=1pl|2pl|3pl]"

    # Reject typos up front instead of silently fitting 2PL under a wrong label.
    if model not in ("1pl", "2pl", "3pl"):
        return f"Unknown model '{model}'. Choose from: 1pl, 2pl, 3pl"

    df = session.require_data()

    missing = [c for c in items if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    try:
        sub = df.select(items).drop_nulls()
        data = sub.to_numpy().astype(float)
        n_persons, n_items = data.shape

        if n_persons < 10:
            return "IRT requires at least 10 complete observations."

        # Check all items are binary (0/1)
        for j, item in enumerate(items):
            vals = np.unique(data[:, j])
            non_binary = [v for v in vals if v not in (0.0, 1.0)]
            if non_binary:
                return (
                    f"Column '{item}' contains non-binary values {non_binary[:3]}. "
                    "IRT expects binary (0/1) item responses."
                )

        try:
            from scipy.optimize import minimize
            _has_scipy = True
        except ImportError:
            _has_scipy = False

        # ------------------------------------------------------------------
        # Simplified estimation: the standardised sum score stands in for
        # ability theta, and a logistic curve is fitted per item.
        # ------------------------------------------------------------------
        raw_scores = data.sum(axis=1)
        theta = (raw_scores - raw_scores.mean()) / (raw_scores.std() + 1e-12)

        def _start_difficulty(p_correct):
            # Negative logit of the proportion correct: easy items (high p)
            # get a low difficulty.  Clipping keeps the value finite when an
            # item was answered identically by everyone.
            p_safe = min(max(float(p_correct), 1e-6), 1 - 1e-6)
            return -float(np.log(p_safe / (1 - p_safe)))

        def _2pl_loglik(params, y, theta_vals):
            a, b = params
            # constrain a > 0 via a hard penalty
            if a <= 0:
                return 1e9
            p = 1.0 / (1.0 + np.exp(-a * (theta_vals - b)))
            p = np.clip(p, 1e-9, 1 - 1e-9)
            return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

        def _1pl_loglik(params, y, theta_vals):
            b = params[0]
            p = 1.0 / (1.0 + np.exp(-(theta_vals - b)))
            p = np.clip(p, 1e-9, 1 - 1e-9)
            return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

        item_params = []

        for j, item in enumerate(items):
            y_j = data[:, j]
            p_bar = y_j.mean()
            b0 = _start_difficulty(p_bar)

            if _has_scipy:
                if model == "1pl":
                    # Difficulty only (Rasch-like, a fixed at 1.0)
                    res = minimize(_1pl_loglik, x0=[b0], args=(y_j, theta),
                                   method="Nelder-Mead",
                                   options={"maxiter": 2000, "xatol": 1e-5})
                    a_hat = 1.0
                    b_hat = float(res.x[0])
                else:
                    # 2PL (also used for --model=3pl)
                    res = minimize(_2pl_loglik, x0=[1.0, b0], args=(y_j, theta),
                                   method="Nelder-Mead",
                                   options={"maxiter": 3000, "xatol": 1e-5, "fatol": 1e-5})
                    a_hat = max(float(res.x[0]), 0.01)
                    b_hat = float(res.x[1])
            else:
                # Fallback: logistic regression approximation
                try:
                    from sklearn.linear_model import LogisticRegression
                    lr = LogisticRegression(max_iter=500, C=1e6)
                    lr.fit(theta.reshape(-1, 1), y_j.astype(int))
                    a_hat = float(lr.coef_[0][0])
                    # logit = a*theta + c  =>  b = -c / a
                    b_hat = -float(lr.intercept_[0]) / (a_hat + 1e-12)
                    if model == "1pl":
                        a_hat = 1.0
                except ImportError:
                    # Last resort: method of moments
                    a_hat = 1.0
                    b_hat = b0

            # Item information I(theta) = a^2 * P * (1 - P); at theta = b the
            # probability is 0.5, so P * (1 - P) = 0.25.
            p_at_b = 0.25
            info = a_hat ** 2 * p_at_b

            # Empirical fit: proportion correct
            p_obs = float(y_j.mean())

            item_params.append((item, a_hat, b_hat, p_obs, info))

        # ------------------------------------------------------------------
        # Build output
        # ------------------------------------------------------------------
        model_label = model.upper()
        rows = [
            (item, f"{a:.4f}", f"{b:.4f}", f"{p_obs:.4f}", f"{info:.4f}")
            for item, a, b, p_obs, info in item_params
        ]

        header_line = (
            f"\nIRT {model_label} — {n_items} items, {n_persons} persons\n"
            + _sep()
        )

        tbl = _coef_table(
            rows,
            ["Item", "Discrim (a)", "Difficulty (b)", "P(correct)", "Info@b"],
        )

        # Test information function: sum of item information over an ability grid
        theta_grid = np.linspace(-3, 3, 61)
        tif = np.zeros_like(theta_grid)
        for _, a, b, _, _ in item_params:
            p = 1.0 / (1.0 + np.exp(-a * (theta_grid - b)))
            tif += a ** 2 * p * (1 - p)

        tif_max_idx = int(np.argmax(tif))
        tif_max_theta = float(theta_grid[tif_max_idx])
        tif_max_val = float(tif[tif_max_idx])

        reliability_approx = float(np.mean(tif) / (np.mean(tif) + 1.0))

        summary_lines = [
            "",
            "Test Information Summary:",
            f" Peak information : {tif_max_val:.4f} at theta = {tif_max_theta:.2f}",
            f" Mean information : {np.mean(tif):.4f}",
            f" Marginal reliability (approx) : {reliability_approx:.4f}",
            "",
            "Note: Ability (theta) estimated from standardised sum score.",
        ]
        if not _has_scipy:
            summary_lines.append(
                "Note: scipy not found; used logistic regression / moments approximation."
            )
        if model == "3pl":
            summary_lines.append(
                "Note: 3PL guessing parameter not estimated; showing 2PL results."
            )

        return header_line + "\n" + tbl + "\n" + _sep() + "\n" + "\n".join(summary_lines)

    except Exception as e:
        return friendly_error(e, "irt")
233
+
234
+
235
+ # ---------------------------------------------------------------------------
236
+ # 2. Competing Risks Regression
237
+ # ---------------------------------------------------------------------------
238
+
239
@command("competing", usage="competing <time> <event> <cause> [covars...]")
def cmd_competing(session: Session, args: str) -> str:
    """Competing-risks cumulative incidence curves with covariate screening.

    Estimates a cumulative incidence function (CIF) per cause with the
    lifelines Aalen-Johansen fitter, or with a hand-rolled sequential
    estimate when lifelines is not installed.  If covariates are given,
    each one is correlated (Pearson) with the per-cause event indicator as
    a rough association screen; no Fine-Gray regression model is fitted.
    The CIF curves are also plotted to ``competing_risks_cif.png`` in the
    session output directory.

    Args:
        session: active session; provides the data, ``output_dir`` and
            ``plot_paths``.
        args: raw argument string ``<time> <event> <cause> [covars...]``.
            Rows count as events where the event column equals 1; the cause
            column names which cause occurred on those rows.

    Returns:
        The formatted report, or a short message when the input is
        unusable.  Unexpected exceptions are turned into text via
        ``friendly_error``.

    Examples:
        competing time status cause
        competing time status cause age gender
    """
    import numpy as np

    ca = CommandArgs(args)
    pos = [p for p in ca.positional if not p.startswith("--")]

    if len(pos) < 3:
        return "Usage: competing <time> <event> <cause> [covars...]"

    time_col, event_col, cause_col = pos[0], pos[1], pos[2]
    covars = pos[3:]

    df = session.require_data()

    needed = [time_col, event_col, cause_col] + covars
    missing = [c for c in needed if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    try:
        sub = df.select(needed).drop_nulls()
        T = sub[time_col].to_numpy().astype(float)
        E = sub[event_col].to_numpy().astype(float)
        C = sub[cause_col].to_numpy()

        is_event = E == 1
        n_total = len(T)

        # Only rows with an observed event define a cause; whatever the cause
        # column holds on censored rows is not a competing risk.
        causes = sorted(set(C[is_event].tolist()))
        if not causes:
            return f"No events found: column '{event_col}' has no rows equal to 1."

        def _cause_mask(cause):
            # Rows on which this particular cause was actually observed.
            return is_event & (C == cause)

        lines = [
            "\nCompeting Risks Analysis",
            _sep(),
            f" Time var : {time_col}",
            f" Event var : {event_col}",
            f" Cause var : {cause_col}",
            f" N : {n_total}",
            f" Causes : {causes}",
            "",
        ]

        # ------------------------------------------------------------------
        # Try lifelines for the CIF
        # ------------------------------------------------------------------
        try:
            from lifelines import AalenJohansenFitter
            _has_lifelines = True
        except ImportError:
            _has_lifelines = False

        cif_results = {}

        if _has_lifelines:
            # AalenJohansenFitter.fit(durations, event_observed, event_of_interest)
            # expects ONE integer-coded event column (0 = censored, k = cause k)
            # plus the integer code of the cause of interest.
            cause_codes = {cause: code for code, cause in enumerate(causes, start=1)}
            event_codes = np.zeros(n_total, dtype=int)
            for cause, code in cause_codes.items():
                event_codes[_cause_mask(cause)] = code

            t_max = float(T.max())
            for cause in causes:
                ajf = AalenJohansenFitter(calculate_variance=True)
                try:
                    ajf.fit(T, event_codes, event_of_interest=cause_codes[cause])
                    cif_t = ajf.cumulative_density_
                    cif_at_max = float(cif_t.iloc[-1, 0])
                    cif_results[cause] = ajf
                    n_events = int(_cause_mask(cause).sum())
                    lines.append(
                        f" Cause {cause}: {n_events} events, "
                        f"CIF at t={t_max:.1f} = {cif_at_max:.4f}"
                    )
                except Exception as fit_err:
                    lines.append(f" Cause {cause}: CIF fit failed — {fit_err}")
        else:
            # Manual Aalen-Johansen-style CIF estimate
            lines.append(" lifelines not found; computing manual CIF estimates.")
            order = np.argsort(T)
            T_s = T[order]
            E_s = E[order]
            C_s = C[order]

            for cause in causes:
                cif_times = [0.0]
                cif_running = [0.0]
                S_running = 1.0  # overall event-free survival just before t
                n_r = n_total    # current size of the risk set

                for t_i, e_i, c_i in zip(T_s, E_s, C_s):
                    if e_i == 1:
                        hazard_cause = int(c_i == cause) / n_r
                        hazard_all = 1 / n_r
                        cif_running.append(
                            float(cif_running[-1] + S_running * hazard_cause)
                        )
                        S_running = S_running * (1 - hazard_all)
                        cif_times.append(float(t_i))
                    # Every subject, event or censored, leaves the risk set
                    # once its time has passed.
                    n_r = max(n_r - 1, 1)

                n_events = int(_cause_mask(cause).sum())
                cif_final = cif_running[-1]
                lines.append(
                    f" Cause {cause}: {n_events} events, "
                    f"CIF at t_max = {cif_final:.4f}"
                )
                cif_results[cause] = (cif_times, cif_running)

        # ------------------------------------------------------------------
        # Covariate association (correlation with the cause indicator)
        # ------------------------------------------------------------------
        if covars:
            lines.append("")
            lines.append("Covariate Association (cause-specific logrank proxy):")
            try:
                from scipy.stats import pearsonr
                X_cov = sub.select(covars).to_numpy().astype(float)
                enough_rows = X_cov.shape[0] > 2
                for cause in causes:
                    indicator = _cause_mask(cause).astype(float)
                    lines.append(f" Cause {cause}:")
                    if not enough_rows:
                        continue
                    for j, cv in enumerate(covars):
                        r, p = pearsonr(X_cov[:, j], indicator)
                        lines.append(
                            f" {cv:<20} r={r:.4f} p={p:.4f}{_sig_stars(p)}"
                        )
            except ImportError:
                lines.append(" scipy not found; skipping covariate associations.")
            except Exception as cov_err:
                lines.append(f" Covariate analysis error: {cov_err}")

        # ------------------------------------------------------------------
        # Plot CIF curves
        # ------------------------------------------------------------------
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots(figsize=(8, 5))
            try:
                colors = plt.cm.tab10.colors

                for i, (cause, fitted) in enumerate(cif_results.items()):
                    if _has_lifelines:
                        cif_df = fitted.cumulative_density_
                        xs, ys = cif_df.index, cif_df.values[:, 0]
                    else:
                        xs, ys = fitted
                    ax.step(
                        xs, ys,
                        where="post",
                        label=f"Cause {cause}",
                        color=colors[i % len(colors)],
                        linewidth=2,
                    )

                ax.set_xlabel(f"Time ({time_col})")
                ax.set_ylabel("Cumulative Incidence")
                ax.set_title("Cumulative Incidence Functions")
                ax.legend()
                ax.set_ylim(0, 1)
                ax.grid(alpha=0.3)
                fig.tight_layout()

                session.output_dir.mkdir(parents=True, exist_ok=True)
                plot_path = session.output_dir / "competing_risks_cif.png"
                fig.savefig(plot_path, dpi=150)
            finally:
                # Release the figure even when drawing or saving fails.
                plt.close(fig)
            session.plot_paths.append(str(plot_path))
            lines.append(f"\nPlot saved: {plot_path}")

        except Exception as plot_err:
            lines.append(f"\nPlot error: {plot_err}")

        lines.append(_sep())
        return "\n".join(lines)

    except Exception as e:
        return friendly_error(e, "competing")
433
+
434
+
435
+ # ---------------------------------------------------------------------------
436
+ # 3. CATE — Conditional Average Treatment Effects
437
+ # ---------------------------------------------------------------------------
438
+
439
@command("cate", usage="cate <y> <treat> <x1> [x2 ...] [--method=xlearner|drlearner|tlearner]")
def cmd_cate(session: Session, args: str) -> str:
    """Conditional Average Treatment Effects via meta-learners.

    Implements T-Learner, X-Learner, and DR-Learner using sklearn, with
    Ridge regression as the outcome learner and logistic regression for the
    propensity score.

    Methods:
        tlearner  : Fit E[Y|T=1,X] and E[Y|T=0,X] separately; CATE = mu1 - mu0.
        xlearner  : Impute individual effects from the opposite arm's model,
                    regress them on X, blend with the propensity score.
        drlearner : Doubly robust pseudo-outcomes regressed on X.

    Args:
        session: active session; provides the data, ``output_dir`` and
            ``plot_paths``.
        args: raw argument string ``<y> <treat> <x1> [x2 ...]`` plus an
            optional ``--method``.  The treatment column must be coded 0/1.

    Returns:
        The formatted report (summary statistics, bootstrap interval for the
        ATE, plot location) or a short message when the input is unusable.
        Unexpected exceptions are turned into text via ``friendly_error``.

    Examples:
        cate outcome treatment age educ income --method=tlearner
        cate outcome treatment age educ --method=xlearner
        cate outcome treatment age educ income --method=drlearner
    """
    import numpy as np

    ca = CommandArgs(args)
    pos = [p for p in ca.positional if not p.startswith("--")]
    method = ca.options.get("method", "tlearner").lower()

    if len(pos) < 3:
        return "Usage: cate <y> <treat> <x1> [x2 ...] [--method=tlearner|xlearner|drlearner]"

    y_col = pos[0]
    treat_col = pos[1]
    x_cols = pos[2:]

    df = session.require_data()

    needed = [y_col, treat_col] + x_cols
    missing = [c for c in needed if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    method_labels = {
        "tlearner": "T-Learner (separate outcome models)",
        "xlearner": "X-Learner (cross-fitted imputed outcomes)",
        "drlearner": "DR-Learner (doubly robust pseudo-outcomes)",
    }

    try:
        try:
            from sklearn.linear_model import LogisticRegression, Ridge
        except ImportError:
            return "sklearn not installed. Run: pip install scikit-learn"

        sub = df.select(needed).drop_nulls()
        Y = sub[y_col].to_numpy().astype(float)
        T = sub[treat_col].to_numpy().astype(float)
        X = sub.select(x_cols).to_numpy().astype(float)

        # The arm masks and group counts below are only meaningful for a 0/1
        # indicator; any other coding would silently drop rows from fitting.
        if not np.isin(T, (0.0, 1.0)).all():
            return f"Treatment column '{treat_col}' must be binary (0/1)."

        idx_t = T == 1
        idx_c = T == 0

        n = len(Y)
        n_treated = int(idx_t.sum())
        n_control = int(idx_c.sum())

        if n_treated < 5 or n_control < 5:
            return "Need at least 5 treated and 5 control observations."

        if method not in method_labels:
            return (
                f"Unknown method '{method}'. "
                "Choose from: tlearner, xlearner, drlearner"
            )
        method_desc = method_labels[method]

        # Base learner: Ridge regression (fast, works on small data)
        def _base_learner():
            return Ridge(alpha=1.0)

        def _propensity_scores():
            # P(T=1 | X), clipped away from 0/1 so inverse weights stay bounded.
            ps_model = LogisticRegression(max_iter=500, C=1.0)
            ps_model.fit(X, T.astype(int))
            return np.clip(ps_model.predict_proba(X)[:, 1], 0.01, 0.99)

        # All three meta-learners start from one outcome model per arm.
        mu1_model = _base_learner()
        mu0_model = _base_learner()
        mu1_model.fit(X[idx_t], Y[idx_t])
        mu0_model.fit(X[idx_c], Y[idx_c])

        if method == "tlearner":
            cate_estimates = mu1_model.predict(X) - mu0_model.predict(X)

        elif method == "xlearner":
            # Imputed individual effects
            D1 = Y[idx_t] - mu0_model.predict(X[idx_t])  # treated: Y1 - mu0(X)
            D0 = mu1_model.predict(X[idx_c]) - Y[idx_c]  # control: mu1(X) - Y0

            tau1_model = _base_learner()
            tau0_model = _base_learner()
            tau1_model.fit(X[idx_t], D1)
            tau0_model.fit(X[idx_c], D0)

            e_hat = _propensity_scores()
            tau1_hat = tau1_model.predict(X)
            tau0_hat = tau0_model.predict(X)
            # Propensity-weighted combination
            cate_estimates = e_hat * tau0_hat + (1 - e_hat) * tau1_hat

        else:  # drlearner
            mu1_hat = mu1_model.predict(X)
            mu0_hat = mu0_model.predict(X)
            e_hat = _propensity_scores()

            # DR pseudo-outcome: outcome-model contrast plus inverse-propensity
            # weighted residual corrections for each arm.
            psi = (
                (T * (Y - mu1_hat)) / e_hat
                - ((1 - T) * (Y - mu0_hat)) / (1 - e_hat)
                + mu1_hat - mu0_hat
            )
            # Second-stage regression on pseudo-outcomes
            tau_model = _base_learner()
            tau_model.fit(X, psi)
            cate_estimates = tau_model.predict(X)

        # ------------------------------------------------------------------
        # Summary statistics
        # ------------------------------------------------------------------
        ate = float(np.mean(cate_estimates))
        att = float(np.mean(cate_estimates[idx_t]))
        atc = float(np.mean(cate_estimates[idx_c]))
        cate_std = float(np.std(cate_estimates))
        cate_min = float(np.min(cate_estimates))
        cate_max = float(np.max(cate_estimates))
        q25, q50, q75 = np.percentile(cate_estimates, [25, 50, 75])

        # Bootstrap SE for ATE (200 replicates, fast).  Only the fitted CATE
        # values are resampled; the models are not refitted, so uncertainty
        # from model estimation is not reflected.
        rng = np.random.default_rng(42)
        boot_ate = [
            float(np.mean(cate_estimates[rng.integers(0, n, size=n)]))
            for _ in range(200)
        ]
        ate_se = float(np.std(boot_ate))
        ate_ci_lo = ate - 1.96 * ate_se
        ate_ci_hi = ate + 1.96 * ate_se

        lines = [
            "\nConditional Average Treatment Effects (CATE)",
            _sep(),
            f" Method : {method_desc}",
            f" Outcome : {y_col}",
            f" Treatment : {treat_col}",
            f" Covariates : {', '.join(x_cols)}",
            f" N total : {n} (treated={n_treated}, control={n_control})",
            "",
            "CATE Distribution:",
            f" {'ATE (avg treatment effect)':<35} {ate:>10.4f}",
            f" {'ATT (avg on treated)':<35} {att:>10.4f}",
            f" {'ATC (avg on controls)':<35} {atc:>10.4f}",
            f" {'SD of individual CATE':<35} {cate_std:>10.4f}",
            f" {'Min':<35} {cate_min:>10.4f}",
            f" {'Q25':<35} {q25:>10.4f}",
            f" {'Median':<35} {q50:>10.4f}",
            f" {'Q75':<35} {q75:>10.4f}",
            f" {'Max':<35} {cate_max:>10.4f}",
            "",
            "ATE Inference (bootstrap, 200 reps):",
            f" {'ATE':<15} {ate:>10.4f}",
            f" {'Bootstrap SE':<15} {ate_se:>10.4f}",
            f" {'95% CI':<15} [{ate_ci_lo:.4f}, {ate_ci_hi:.4f}]",
        ]

        # Heuristic flag only: compares the spread of individual CATEs with the
        # bootstrap SE of their mean.
        # NOTE(review): the SE shrinks roughly with sqrt(n), so this fires for
        # almost any non-constant CATE; it is not a formal heterogeneity test.
        if cate_std > ate_se:
            lines.append(
                "\n Heterogeneity: CATE SD > bootstrap SE, "
                "suggesting effect heterogeneity."
            )

        # ------------------------------------------------------------------
        # Plot CATE distribution
        # ------------------------------------------------------------------
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig, axes = plt.subplots(1, 2, figsize=(11, 4))
            try:
                # Histogram of CATE
                axes[0].hist(cate_estimates, bins=30, color="#4C72B0", alpha=0.75,
                             edgecolor="white")
                axes[0].axvline(ate, color="red", linestyle="--", linewidth=1.5,
                                label=f"ATE={ate:.3f}")
                axes[0].axvline(0, color="black", linestyle=":", linewidth=1.0,
                                alpha=0.6)
                axes[0].set_xlabel("CATE")
                axes[0].set_ylabel("Frequency")
                axes[0].set_title(f"CATE Distribution ({method.upper()})")
                axes[0].legend()

                # CATE vs first covariate (sorted)
                sort_idx = np.argsort(X[:, 0])
                axes[1].scatter(X[sort_idx, 0], cate_estimates[sort_idx],
                                alpha=0.4, s=20, color="#4C72B0")
                axes[1].axhline(ate, color="red", linestyle="--", linewidth=1.5,
                                label=f"ATE={ate:.3f}")
                axes[1].axhline(0, color="black", linestyle=":", linewidth=1.0, alpha=0.6)
                axes[1].set_xlabel(x_cols[0])
                axes[1].set_ylabel("CATE")
                axes[1].set_title(f"CATE vs {x_cols[0]}")
                axes[1].legend()

                fig.tight_layout()
                session.output_dir.mkdir(parents=True, exist_ok=True)
                plot_path = session.output_dir / "cate.png"
                fig.savefig(plot_path, dpi=150)
            finally:
                # Release the figure even when drawing or saving fails.
                plt.close(fig)
            session.plot_paths.append(str(plot_path))
            lines.append(f"\nPlot saved: {plot_path}")

        except Exception as plot_err:
            lines.append(f"\nPlot error: {plot_err}")

        lines.append(_sep())
        return "\n".join(lines)

    except Exception as e:
        return friendly_error(e, "cate")
683
+
684
+
685
+ # ---------------------------------------------------------------------------
686
+ # 4. Joinpoint Trend Analysis
687
+ # ---------------------------------------------------------------------------
688
+
689
@command("joinpoint", usage="joinpoint <y> <x> [--max_points=3] [--permutations=100]")
def cmd_joinpoint(session: Session, args: str) -> str:
    """Joinpoint (piecewise linear) trend analysis with BIC-based model selection.

    Fits continuous piecewise-linear models with 0..max_points joinpoints,
    keeps the model with the lowest BIC, and reports the slope and percent
    change per unit of x for every segment. When at least one joinpoint is
    selected, a permutation p-value compares the chosen model against a
    single straight line.

    Options:
      --max_points=N     Largest number of joinpoints to try. The search
                         covers 0-3 joinpoints; larger values are capped at 3.
      --permutations=N   Number of permutations for the test (0 disables it).

    Examples:
      joinpoint cancer_rate year
      joinpoint incidence year --max_points=3
      joinpoint rate year --max_points=2 --permutations=200
    """
    from itertools import combinations

    import numpy as np

    # The exhaustive grid search is only implemented for up to three joinpoints.
    search_limit = 3

    ca = CommandArgs(args)
    pos = [p for p in ca.positional if not p.startswith("--")]

    if len(pos) < 2:
        return "Usage: joinpoint <y> <x> [--max_points=3] [--permutations=100]"

    y_col, x_col = pos[0], pos[1]

    # Report malformed numeric options as a message instead of an uncaught error.
    try:
        requested_jp = int(ca.options.get("max_points", 3))
        n_perm = int(ca.options.get("permutations", 100))
    except (TypeError, ValueError):
        return "Options --max_points and --permutations must be integers."
    max_jp = min(max(requested_jp, 0), search_limit)

    df = session.require_data()

    if y_col not in df.columns:
        return f"Column not found: {y_col}"
    if x_col not in df.columns:
        return f"Column not found: {x_col}"

    try:
        sub = df.select([y_col, x_col]).drop_nulls()
        X = sub[x_col].to_numpy().astype(float)
        Y = sub[y_col].to_numpy().astype(float)

        # drop_nulls() does not remove floating-point NaN/inf, which would
        # make every least-squares fit below meaningless.
        finite = np.isfinite(X) & np.isfinite(Y)
        X = X[finite]
        Y = Y[finite]
        n = len(X)

        if n < 6:
            return "Joinpoint analysis requires at least 6 data points."

        sort_idx = np.argsort(X)
        X = X[sort_idx]
        Y = Y[sort_idx]

        # ------------------------------------------------------------------
        # Piecewise linear model helpers
        # ------------------------------------------------------------------
        def _piecewise_design(x_vals, joinpoints):
            """Build design matrix: [1, x, (x-jp1)+, (x-jp2)+, ...]."""
            cols = [np.ones(len(x_vals)), x_vals]
            cols.extend(np.maximum(x_vals - jp, 0.0) for jp in joinpoints)
            return np.column_stack(cols)

        def _fit_piecewise(x_vals, y_vals, joinpoints):
            """OLS fit of piecewise linear model; return (params, rss, bic)."""
            A = _piecewise_design(x_vals, joinpoints)
            try:
                params = np.linalg.lstsq(A, y_vals, rcond=None)[0]
            except np.linalg.LinAlgError:
                return None, np.inf, np.inf
            rss = float(np.sum((y_vals - A @ params) ** 2))
            k = A.shape[1]
            nn = len(y_vals)
            # BIC = n * ln(RSS/n) + k * ln(n); the floor avoids log(0) on a perfect fit.
            bic = nn * np.log(max(rss / nn, 1e-15)) + k * np.log(nn)
            return params, rss, bic

        # ------------------------------------------------------------------
        # Grid search over candidate joinpoints
        # ------------------------------------------------------------------
        # Candidates are interior x values: the outer 15% of points on each
        # side (at least two points) are excluded.
        margin = max(2, int(0.15 * n))
        candidate_x = np.unique(X[margin: n - margin])

        # Thin the candidate grid so the combinatorial search stays cheap.
        step = max(1, len(candidate_x) // 30)
        sampled = candidate_x[::step]

        # Baseline: zero joinpoints (simple linear trend).
        best_params, best_rss, best_bic = _fit_piecewise(X, Y, [])
        rss_linear = best_rss
        best_jps = []
        best_n_jp = 0

        for n_jp in range(1, max_jp + 1):
            if len(candidate_x) < n_jp:
                break
            # combinations() yields ascending tuples, so joinpoints stay sorted.
            for combo in combinations(sampled, n_jp):
                jps = list(combo)
                p, r, b = _fit_piecewise(X, Y, jps)
                if b < best_bic:
                    best_bic = b
                    best_rss = r
                    best_jps = jps
                    best_params = p
                    best_n_jp = n_jp

        # ------------------------------------------------------------------
        # Extract segment slopes
        # ------------------------------------------------------------------
        # params = [intercept, slope, delta1, delta2, ...]
        # Cumulative slope in segment i = slope + sum(delta_j for j <= i)
        segments = []
        # Plain Python floats keep the printed list free of NumPy scalar reprs.
        breakpoints = sorted(float(jp) for jp in best_jps)
        seg_bounds = [float(X[0])] + breakpoints + [float(X[-1])]

        if best_params is not None:
            cum_slope = float(best_params[1]) if len(best_params) > 1 else 0.0
            for seg_i, (x_lo, x_hi) in enumerate(zip(seg_bounds[:-1], seg_bounds[1:])):
                if seg_i > 0 and seg_i - 1 < len(best_params) - 2:
                    cum_slope += float(best_params[seg_i + 1])
                # APC = percent change per unit of x, relative to mean Y in the segment.
                # Both bounds are inclusive, so a point lying exactly on a
                # joinpoint is counted in the two adjacent segments.
                mask = (X >= x_lo) & (X <= x_hi)
                n_in_seg = int(mask.sum())
                y_seg_mean = float(Y[mask].mean()) if n_in_seg > 0 else 1.0
                apc = 100.0 * cum_slope / (y_seg_mean + 1e-12)
                segments.append({
                    "seg": seg_i + 1,
                    "from": x_lo,
                    "to": x_hi,
                    "slope": cum_slope,
                    "apc": apc,
                    "n_pts": n_in_seg,
                })

        # ------------------------------------------------------------------
        # Permutation test for number of joinpoints
        # ------------------------------------------------------------------
        perm_p = None
        if n_perm > 0 and best_n_jp > 0:
            # H0: linear trend; H1: best_n_jp joinpoints.
            # The statistic is the RSS ratio of the two models. The selected
            # joinpoints are held fixed for the permuted fits (no re-search),
            # so the p-value is an approximation.
            obs_stat = rss_linear / (best_rss + 1e-15)

            rng = np.random.default_rng(42)  # fixed seed => reproducible output
            perm_count = 0
            for _ in range(n_perm):
                Y_perm = rng.permutation(Y)
                _, rss_alt_p, _ = _fit_piecewise(X, Y_perm, best_jps)
                _, rss_null_p, _ = _fit_piecewise(X, Y_perm, [])
                perm_stat = rss_null_p / (rss_alt_p + 1e-15)
                if perm_stat >= obs_stat:
                    perm_count += 1
            perm_p = perm_count / n_perm

        # ------------------------------------------------------------------
        # Format output
        # ------------------------------------------------------------------
        max_jp_line = f" Max JP : {max_jp}"
        if requested_jp > search_limit:
            max_jp_line += f" (requested {requested_jp}; search is limited to {search_limit})"

        lines = [
            "\nJoinpoint Trend Analysis",
            _sep(),
            f" Y variable : {y_col}",
            f" X variable : {x_col}",
            f" N points : {n}",
            max_jp_line,
            f" Best model : {best_n_jp} joinpoint(s) BIC = {best_bic:.4f}",
        ]
        if breakpoints:
            lines.append(f" Joinpoints : {[round(jp, 4) for jp in breakpoints]}")
        if perm_p is not None:
            lines.append(
                f" Permutation p-value ({n_perm} perms): {perm_p:.4f}{_sig_stars(perm_p)}"
            )

        lines.append("")
        lines.append("Trend Segments:")
        rows = [
            (
                str(seg["seg"]),
                f"{seg['from']:.2f}",
                f"{seg['to']:.2f}",
                f"{seg['slope']:.6f}",
                f"{seg['apc']:.2f}%",
                str(seg["n_pts"]),
            )
            for seg in segments
        ]
        lines.append(_coef_table(
            rows,
            ["Seg", "From", "To", "Slope", "APC", "N pts"],
        ))

        # ------------------------------------------------------------------
        # Plot (best effort: a plotting failure never discards the table)
        # ------------------------------------------------------------------
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots(figsize=(9, 5))
            try:
                ax.scatter(X, Y, color="#4C72B0", alpha=0.7, s=40, zorder=3,
                           label="Observed")

                # Fitted piecewise line
                if best_params is not None:
                    x_fine = np.linspace(X.min(), X.max(), 400)
                    y_fine = _piecewise_design(x_fine, breakpoints) @ best_params
                    ax.plot(x_fine, y_fine, color="red", linewidth=2.0,
                            label=f"Joinpoint fit ({best_n_jp} JP)")

                for jp in breakpoints:
                    ax.axvline(jp, color="grey", linestyle="--", linewidth=1.0, alpha=0.7)

                ax.set_xlabel(x_col)
                ax.set_ylabel(y_col)
                ax.set_title(f"Joinpoint Trend: {y_col} vs {x_col}")
                ax.legend()
                ax.grid(alpha=0.3)
                fig.tight_layout()

                session.output_dir.mkdir(parents=True, exist_ok=True)
                plot_path = session.output_dir / "joinpoint.png"
                fig.savefig(plot_path, dpi=150)
            finally:
                # Release the figure even when drawing or saving fails.
                plt.close(fig)

            session.plot_paths.append(str(plot_path))
            lines.append(f"\nPlot saved: {plot_path}")

        except Exception as plot_err:
            lines.append(f"\nPlot error: {plot_err}")

        lines.append(_sep())
        return "\n".join(lines)

    except Exception as e:
        return friendly_error(e, "joinpoint")
957
+
958
+
959
+ # ---------------------------------------------------------------------------
960
+ # 5. Spline and LOESS Regression
961
+ # ---------------------------------------------------------------------------
962
+
963
@command("spline", usage="spline <y> <x1> [x2 ...] [--knots=N|<list>] [--type=natural|bs|loess]")
def cmd_spline(session: Session, args: str) -> str:
    """Spline and LOESS smoothing regression.

    Types:
      natural : Natural cubic splines via statsmodels.
      bs      : B-splines via statsmodels.
      loess   : LOWESS smoothing (no parametric assumptions).

    The --knots option accepts an integer (number of interior knots, placed
    at evenly spaced percentiles of each predictor) or a comma-separated list
    of knot positions (e.g. --knots=25,50,75) that is applied to every
    predictor. For --type=loess the smoothing span can be set with --frac
    (default 0.3, clamped to [0.1, 1.0]); only the first predictor is used.

    Examples:
      spline y x --knots=4 --type=natural
      spline y x --type=loess
      spline y x1 x2 --knots=3 --type=bs
      spline y x --knots=20,40,60 --type=natural
    """
    import numpy as np

    ca = CommandArgs(args)
    pos = [p for p in ca.positional if not p.startswith("--")]
    spline_type = ca.options.get("type", "natural").lower()
    knots_opt = ca.options.get("knots", "4")

    if len(pos) < 2:
        return "Usage: spline <y> <x1> [x2 ...] [--knots=N|<list>] [--type=natural|bs|loess]"

    y_col = pos[0]
    x_cols = pos[1:]

    df = session.require_data()

    needed = [y_col] + x_cols
    missing = [c for c in needed if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    if spline_type not in ("natural", "bs", "loess"):
        return (
            f"Unknown type '{spline_type}'. "
            "Choose from: natural, bs, loess"
        )

    try:
        # LOESS smooths against the first predictor only, so the remaining
        # columns are not converted (they may legitimately be non-numeric).
        used_x = x_cols[:1] if spline_type == "loess" else x_cols

        sub = df.select(needed).drop_nulls()
        Y = sub[y_col].to_numpy().astype(float)
        x_arrays = [sub[xc].to_numpy().astype(float) for xc in used_x]

        # drop_nulls() keeps floating-point NaN/inf; remove those rows too.
        finite = np.isfinite(Y)
        for values in x_arrays:
            finite &= np.isfinite(values)
        Y = Y[finite]
        x_arrays = [values[finite] for values in x_arrays]
        n = len(Y)

        if n < 6:
            return "Spline regression requires at least 6 observations."

        # First x column is the primary smoothing variable for plots.
        x_primary = x_arrays[0]

        lines = [
            "\nSpline / LOESS Regression",
            _sep(),
            f" Type : {spline_type}",
            f" Y : {y_col}",
            f" X : {', '.join(x_cols)}",
            f" N : {n}",
        ]

        y_hat = None
        y_hat_lo = None
        y_hat_hi = None
        knot_positions = []

        # ------------------------------------------------------------------
        # Regression splines (natural cubic or B-spline basis)
        # ------------------------------------------------------------------
        if spline_type in ("natural", "bs"):
            try:
                import statsmodels.api as sm
                from patsy import dmatrix
            except ImportError:
                return (
                    "statsmodels and patsy are required for spline regression. "
                    "Run: pip install statsmodels patsy"
                )

            # Parse the knots specification: integer count or explicit list.
            explicit_knots = None
            try:
                n_knots = int(knots_opt)
            except (TypeError, ValueError):
                try:
                    explicit_knots = [float(k.strip()) for k in str(knots_opt).split(",")]
                    n_knots = len(explicit_knots)
                except ValueError:
                    n_knots = 4
                    lines.append(
                        f" (Invalid --knots value '{knots_opt}'; using 4 knots.)"
                    )
            if explicit_knots is None and n_knots < 0:
                return (
                    "--knots must be a non-negative integer or a "
                    "comma-separated list of knot positions."
                )

            def _percentile_knots(values):
                """Place n_knots knots at evenly spaced percentiles of values."""
                percentile_step = 100.0 / (n_knots + 1)
                return [
                    float(np.percentile(values, percentile_step * (i + 1)))
                    for i in range(n_knots)
                ]

            # Count-based knots are derived per predictor so each set lies
            # inside that predictor's own range; an explicit list is shared.
            if explicit_knots is not None:
                knots_by_col = [explicit_knots for _ in x_arrays]
            else:
                knots_by_col = [_percentile_knots(values) for values in x_arrays]
            knot_positions = knots_by_col[0]

            if spline_type == "natural":
                term_template = "cr({var}, knots=[{knots}])"
            else:
                term_template = "bs({var}, knots=[{knots}], include_intercept=False)"

            # patsy evaluates formula terms as Python expressions, so the raw
            # column names are never interpolated; each predictor is passed
            # under a generated alias instead.
            data_dict = {}
            spline_terms = []
            for i, (values, knots) in enumerate(zip(x_arrays, knots_by_col)):
                alias = f"_x{i}"
                data_dict[alias] = values
                knot_str = ", ".join(repr(k) for k in knots)
                spline_terms.append(term_template.format(var=alias, knots=knot_str))
            formula_str = " + ".join(spline_terms)

            try:
                X_design = dmatrix(formula_str, data_dict, return_type="matrix")
                # NOTE(review): patsy formulas normally contribute their own
                # intercept column, so this extra constant looks redundant;
                # kept as in the original fit - confirm before removing.
                X_sm = sm.add_constant(np.asarray(X_design), has_constant="add")

                result = sm.OLS(Y, X_sm).fit()
                y_hat = np.asarray(result.fittedvalues)

                # Band taken from the obs_ci_* columns of the prediction
                # summary (statsmodels' interval for new observations).
                pred_df = result.get_prediction(X_sm).summary_frame(alpha=0.05)
                y_hat_lo = pred_df["obs_ci_lower"].values
                y_hat_hi = pred_df["obs_ci_upper"].values

                rss = float(np.sum((Y - y_hat) ** 2))
                tss = float(np.sum((Y - Y.mean()) ** 2))
                r2 = 1.0 - rss / (tss + 1e-15)
                aic = float(result.aic)
                bic = float(result.bic)

                lines.append(f" Knots : {[round(k, 4) for k in knot_positions]}")
                if explicit_knots is None:
                    # Additional predictors have their own percentile knots.
                    for xc, knots in zip(x_cols[1:], knots_by_col[1:]):
                        lines.append(
                            f" Knots ({xc}) : {[round(k, 4) for k in knots]}"
                        )
                lines += [
                    f" R-squared : {r2:.4f}",
                    f" AIC : {aic:.4f}",
                    f" BIC : {bic:.4f}",
                    f" N params : {result.df_model:.0f} (df resid={result.df_resid:.0f})",
                ]

            except Exception as sm_err:
                # Keep going: the plot below still shows the raw data.
                lines.append(f"\nSpline fit error: {sm_err}")

        # ------------------------------------------------------------------
        # LOESS / LOWESS
        # ------------------------------------------------------------------
        else:
            try:
                import statsmodels.api as sm
            except ImportError:
                return (
                    "statsmodels is required for LOESS. "
                    "Run: pip install statsmodels"
                )

            # LOWESS operates on a single predictor
            x_sm = x_primary
            frac = float(ca.options.get("frac", 0.3))
            frac = max(0.1, min(frac, 1.0))

            lowess_result = sm.nonparametric.lowess(Y, x_sm, frac=frac, it=3)
            # lowess_result columns: [x_sorted, y_smoothed]
            x_loess = lowess_result[:, 0]
            y_loess = lowess_result[:, 1]

            # Interpolate back to original order for residuals
            y_hat = np.interp(x_sm, x_loess, y_loess)

            # Bootstrap confidence band (100 reps, fixed seed for reproducibility)
            rng = np.random.default_rng(42)
            boot_fits = []
            for _ in range(100):
                idx_b = rng.integers(0, n, n)
                x_b = x_sm[idx_b]
                y_b = Y[idx_b]
                sort_b = np.argsort(x_b)
                try:
                    lw_b = sm.nonparametric.lowess(
                        y_b[sort_b], x_b[sort_b], frac=frac, it=1
                    )
                    boot_fits.append(np.interp(x_sm, lw_b[:, 0], lw_b[:, 1]))
                except Exception:
                    # Best effort: a failed resample is simply left out of the band.
                    continue

            if boot_fits:
                boot_arr = np.array(boot_fits)
                y_hat_lo = np.percentile(boot_arr, 2.5, axis=0)
                y_hat_hi = np.percentile(boot_arr, 97.5, axis=0)

            rss = float(np.sum((Y - y_hat) ** 2))
            tss = float(np.sum((Y - Y.mean()) ** 2))
            r2 = 1.0 - rss / (tss + 1e-15)

            lines += [
                f" Bandwidth (frac) : {frac:.2f}",
                f" R-squared (approx): {r2:.4f}",
                " (LOESS uses single predictor; additional X ignored.)",
            ]

        # ------------------------------------------------------------------
        # Residual summary
        # ------------------------------------------------------------------
        resid = None
        if y_hat is not None:
            resid = Y - y_hat
            lines += [
                "",
                "Residual Summary:",
                f" {'Mean residual':<30} {float(resid.mean()):>10.4f}",
                f" {'SD residual':<30} {float(resid.std()):>10.4f}",
                f" {'Min':<30} {float(resid.min()):>10.4f}",
                f" {'Max':<30} {float(resid.max()):>10.4f}",
                f" {'RMSE':<30} {float(np.sqrt(np.mean(resid**2))):>10.4f}",
            ]

        # ------------------------------------------------------------------
        # Plot: fitted curve with CI band (best effort)
        # ------------------------------------------------------------------
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            fig, axes = plt.subplots(1, 2, figsize=(12, 5))
            try:
                # Left: data + fitted + CI
                sort_idx = np.argsort(x_primary)
                x_s = x_primary[sort_idx]

                axes[0].scatter(x_primary, Y, color="#4C72B0", alpha=0.5, s=20,
                                label="Observed", zorder=3)
                if y_hat is not None:
                    axes[0].plot(x_s, y_hat[sort_idx], color="red", linewidth=2.0,
                                 label="Fitted")
                    if y_hat_lo is not None and y_hat_hi is not None:
                        axes[0].fill_between(
                            x_s, y_hat_lo[sort_idx], y_hat_hi[sort_idx],
                            alpha=0.20, color="red", label="95% CI/band"
                        )

                if spline_type in ("natural", "bs") and knot_positions:
                    for kp in knot_positions:
                        axes[0].axvline(kp, color="grey", linestyle=":",
                                        linewidth=0.8, alpha=0.7)

                axes[0].set_xlabel(x_cols[0])
                axes[0].set_ylabel(y_col)
                axes[0].set_title(
                    f"{spline_type.capitalize()} Spline: {y_col} ~ {x_cols[0]}"
                )
                axes[0].legend()
                axes[0].grid(alpha=0.3)

                # Right: residual plot
                if resid is not None:
                    axes[1].scatter(y_hat, resid, alpha=0.5, s=20, color="#2ca02c")
                    axes[1].axhline(0, color="black", linewidth=1.0, linestyle="--")
                    axes[1].set_xlabel("Fitted values")
                    axes[1].set_ylabel("Residuals")
                    axes[1].set_title("Residuals vs Fitted")
                    axes[1].grid(alpha=0.3)

                fig.tight_layout()
                session.output_dir.mkdir(parents=True, exist_ok=True)
                fname = f"spline_{spline_type}.png"
                plot_path = session.output_dir / fname
                fig.savefig(plot_path, dpi=150)
            finally:
                # Release the figure even when drawing or saving fails.
                plt.close(fig)

            session.plot_paths.append(str(plot_path))
            lines.append(f"\nPlot saved: {plot_path}")

        except Exception as plot_err:
            lines.append(f"\nPlot error: {plot_err}")

        lines.append(_sep())
        return "\n".join(lines)

    except Exception as e:
        return friendly_error(e, "spline")