openstat-cli 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstat/__init__.py +3 -0
- openstat/__main__.py +4 -0
- openstat/backends/__init__.py +16 -0
- openstat/backends/duckdb_backend.py +70 -0
- openstat/backends/polars_backend.py +52 -0
- openstat/cli.py +92 -0
- openstat/commands/__init__.py +82 -0
- openstat/commands/adv_stat_cmds.py +1255 -0
- openstat/commands/advanced_ml_cmds.py +576 -0
- openstat/commands/advreg_cmds.py +207 -0
- openstat/commands/alias_cmds.py +135 -0
- openstat/commands/arch_cmds.py +82 -0
- openstat/commands/arules_cmds.py +111 -0
- openstat/commands/automodel_cmds.py +212 -0
- openstat/commands/backend_cmds.py +82 -0
- openstat/commands/base.py +170 -0
- openstat/commands/bayes_cmds.py +71 -0
- openstat/commands/causal_cmds.py +269 -0
- openstat/commands/cluster_cmds.py +152 -0
- openstat/commands/data_cmds.py +996 -0
- openstat/commands/datamanip_cmds.py +672 -0
- openstat/commands/dataquality_cmds.py +174 -0
- openstat/commands/datetime_cmds.py +176 -0
- openstat/commands/dimreduce_cmds.py +184 -0
- openstat/commands/discrete_cmds.py +149 -0
- openstat/commands/dsl_cmds.py +143 -0
- openstat/commands/epi_cmds.py +93 -0
- openstat/commands/equiv_tobit_cmds.py +94 -0
- openstat/commands/esttab_cmds.py +196 -0
- openstat/commands/export_beamer_cmds.py +142 -0
- openstat/commands/export_cmds.py +201 -0
- openstat/commands/export_extra_cmds.py +240 -0
- openstat/commands/factor_cmds.py +180 -0
- openstat/commands/groupby_cmds.py +155 -0
- openstat/commands/help_cmds.py +237 -0
- openstat/commands/i18n_cmds.py +43 -0
- openstat/commands/import_extra_cmds.py +561 -0
- openstat/commands/influence_cmds.py +134 -0
- openstat/commands/iv_cmds.py +106 -0
- openstat/commands/manova_cmds.py +105 -0
- openstat/commands/mediate_cmds.py +233 -0
- openstat/commands/meta_cmds.py +284 -0
- openstat/commands/mi_cmds.py +228 -0
- openstat/commands/mixed_cmds.py +79 -0
- openstat/commands/mixture_changepoint_cmds.py +166 -0
- openstat/commands/ml_adv_cmds.py +147 -0
- openstat/commands/ml_cmds.py +178 -0
- openstat/commands/model_eval_cmds.py +142 -0
- openstat/commands/network_cmds.py +288 -0
- openstat/commands/nlquery_cmds.py +161 -0
- openstat/commands/nonparam_cmds.py +149 -0
- openstat/commands/outreg_cmds.py +247 -0
- openstat/commands/panel_cmds.py +141 -0
- openstat/commands/pdf_cmds.py +226 -0
- openstat/commands/pipeline_cmds.py +319 -0
- openstat/commands/plot_cmds.py +189 -0
- openstat/commands/plugin_cmds.py +79 -0
- openstat/commands/posthoc_cmds.py +153 -0
- openstat/commands/power_cmds.py +172 -0
- openstat/commands/profile_cmds.py +246 -0
- openstat/commands/rbridge_cmds.py +81 -0
- openstat/commands/regex_cmds.py +104 -0
- openstat/commands/report_cmds.py +48 -0
- openstat/commands/repro_cmds.py +129 -0
- openstat/commands/resampling_cmds.py +109 -0
- openstat/commands/reshape_cmds.py +223 -0
- openstat/commands/sem_cmds.py +177 -0
- openstat/commands/stat_cmds.py +1040 -0
- openstat/commands/stata_import_cmds.py +215 -0
- openstat/commands/string_cmds.py +124 -0
- openstat/commands/surv_cmds.py +145 -0
- openstat/commands/survey_cmds.py +153 -0
- openstat/commands/textanalysis_cmds.py +192 -0
- openstat/commands/ts_adv_cmds.py +136 -0
- openstat/commands/ts_cmds.py +195 -0
- openstat/commands/tui_cmds.py +111 -0
- openstat/commands/ux_cmds.py +191 -0
- openstat/commands/validate_cmds.py +270 -0
- openstat/commands/viz_adv_cmds.py +312 -0
- openstat/commands/viz_extra_cmds.py +251 -0
- openstat/commands/watch_cmds.py +69 -0
- openstat/config.py +106 -0
- openstat/dsl/__init__.py +0 -0
- openstat/dsl/parser.py +332 -0
- openstat/dsl/tokenizer.py +105 -0
- openstat/i18n.py +120 -0
- openstat/io/__init__.py +0 -0
- openstat/io/loader.py +187 -0
- openstat/jupyter/__init__.py +18 -0
- openstat/jupyter/display.py +18 -0
- openstat/jupyter/magic.py +60 -0
- openstat/logging_config.py +59 -0
- openstat/plots/__init__.py +0 -0
- openstat/plots/plotter.py +437 -0
- openstat/plots/surv_plots.py +32 -0
- openstat/plots/ts_plots.py +59 -0
- openstat/plugins/__init__.py +5 -0
- openstat/plugins/manager.py +69 -0
- openstat/repl.py +457 -0
- openstat/reporting/__init__.py +0 -0
- openstat/reporting/eda.py +208 -0
- openstat/reporting/report.py +67 -0
- openstat/script_runner.py +319 -0
- openstat/session.py +133 -0
- openstat/stats/__init__.py +0 -0
- openstat/stats/advanced_regression.py +269 -0
- openstat/stats/arch_garch.py +84 -0
- openstat/stats/bayesian.py +103 -0
- openstat/stats/causal.py +258 -0
- openstat/stats/clustering.py +206 -0
- openstat/stats/discrete.py +311 -0
- openstat/stats/epidemiology.py +119 -0
- openstat/stats/equiv_tobit.py +163 -0
- openstat/stats/factor.py +174 -0
- openstat/stats/imputation.py +282 -0
- openstat/stats/influence.py +78 -0
- openstat/stats/iv.py +131 -0
- openstat/stats/manova.py +124 -0
- openstat/stats/mixed.py +128 -0
- openstat/stats/ml.py +275 -0
- openstat/stats/ml_advanced.py +117 -0
- openstat/stats/model_eval.py +183 -0
- openstat/stats/models.py +1342 -0
- openstat/stats/nonparametric.py +130 -0
- openstat/stats/panel.py +179 -0
- openstat/stats/power.py +295 -0
- openstat/stats/resampling.py +203 -0
- openstat/stats/survey.py +213 -0
- openstat/stats/survival.py +196 -0
- openstat/stats/timeseries.py +142 -0
- openstat/stats/ts_advanced.py +114 -0
- openstat/types.py +11 -0
- openstat/web/__init__.py +1 -0
- openstat/web/app.py +117 -0
- openstat/web/session_manager.py +73 -0
- openstat/web/static/app.js +117 -0
- openstat/web/static/index.html +38 -0
- openstat/web/static/style.css +103 -0
- openstat_cli-1.0.0.dist-info/METADATA +748 -0
- openstat_cli-1.0.0.dist-info/RECORD +143 -0
- openstat_cli-1.0.0.dist-info/WHEEL +4 -0
- openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
- openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
"""Plotting: histogram, scatter, line, box, bar, heatmap via matplotlib."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import matplotlib
|
|
8
|
+
matplotlib.use("Agg") # non-interactive backend
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import numpy as np
|
|
11
|
+
import polars as pl
|
|
12
|
+
|
|
13
|
+
from openstat.config import get_config
|
|
14
|
+
from openstat.types import NUMERIC_DTYPES
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _validate_col(df: pl.DataFrame, col: str) -> None:
|
|
18
|
+
if col not in df.columns:
|
|
19
|
+
raise ValueError(f"Column not found: {col}")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _unique_path(directory: Path, stem: str, suffix: str = ".png") -> Path:
|
|
23
|
+
"""Return a path that does not collide with existing files.
|
|
24
|
+
|
|
25
|
+
First try ``stem.suffix``; if it exists, try ``stem_2.suffix``, etc.
|
|
26
|
+
"""
|
|
27
|
+
candidate = directory / f"{stem}{suffix}"
|
|
28
|
+
if not candidate.exists():
|
|
29
|
+
return candidate
|
|
30
|
+
counter = 2
|
|
31
|
+
while True:
|
|
32
|
+
candidate = directory / f"{stem}_{counter}{suffix}"
|
|
33
|
+
if not candidate.exists():
|
|
34
|
+
return candidate
|
|
35
|
+
counter += 1
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def plot_histogram(
    df: pl.DataFrame, col: str, output_dir: Path, *, bins: int = 30
) -> Path:
    """Create a histogram of ``col`` and save it as a PNG.

    Null values are dropped before plotting. Figure size and DPI are read
    from the active config.

    Args:
        df: Source data.
        col: Column to plot.
        output_dir: Directory for the image; created if missing.
        bins: Passed to ``Axes.hist`` as ``bins``.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        ValueError: If ``col`` is not a column of ``df``.
    """
    cfg = get_config()
    _validate_col(df, col)
    data = df[col].drop_nulls().to_numpy()

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        ax.hist(data, bins=bins, edgecolor="white", alpha=0.85, color="#4C72B0")
        ax.set_xlabel(col)
        ax.set_ylabel("Frequency")
        ax.set_title(f"Histogram of {col}")
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, f"hist_{col}")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Close even when plotting or saving fails, so a failed call does not
        # leave the figure registered with pyplot.
        plt.close(fig)
    return path
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def plot_scatter(
    df: pl.DataFrame, y_col: str, x_col: str, output_dir: Path
) -> Path:
    """Create a scatter plot of ``y_col`` against ``x_col`` and save it as a PNG.

    Rows with a null in either column are dropped, so the plotted pairs stay
    aligned.

    Args:
        df: Source data.
        y_col: Column for the vertical axis.
        x_col: Column for the horizontal axis.
        output_dir: Directory for the image; created if missing.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        ValueError: If either column is missing from ``df``.
    """
    cfg = get_config()
    _validate_col(df, y_col)
    _validate_col(df, x_col)

    sub = df.select([x_col, y_col]).drop_nulls()
    x = sub[x_col].to_numpy()
    y = sub[y_col].to_numpy()

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        ax.scatter(x, y, alpha=0.6, s=20, color="#4C72B0")
        ax.set_xlabel(x_col)
        ax.set_ylabel(y_col)
        ax.set_title(f"{y_col} vs {x_col}")
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, f"scatter_{y_col}_vs_{x_col}")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def plot_line(
    df: pl.DataFrame, y_col: str, x_col: str, output_dir: Path
) -> Path:
    """Create a line plot of ``y_col`` over ``x_col`` and save it as a PNG.

    Rows with a null in either column are dropped and the remaining rows are
    sorted by ``x_col`` so the line runs left to right.

    Args:
        df: Source data.
        y_col: Column for the vertical axis.
        x_col: Column for the horizontal axis (also the sort key).
        output_dir: Directory for the image; created if missing.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        ValueError: If either column is missing from ``df``.
    """
    cfg = get_config()
    _validate_col(df, y_col)
    _validate_col(df, x_col)

    sub = df.select([x_col, y_col]).drop_nulls().sort(x_col)
    x = sub[x_col].to_numpy()
    y = sub[y_col].to_numpy()

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        ax.plot(x, y, marker="o", markersize=3, linewidth=1.5, color="#4C72B0")
        ax.set_xlabel(x_col)
        ax.set_ylabel(y_col)
        ax.set_title(f"{y_col} over {x_col}")
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, f"line_{y_col}_vs_{x_col}")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def plot_box(
    df: pl.DataFrame, col: str, output_dir: Path, *, group_col: str | None = None
) -> Path:
    """Create a box plot of ``col``, optionally one box per group. Save to PNG.

    With ``group_col`` set, one box is drawn for each distinct group value
    (sorted) that has at least one non-null ``col`` value; groups with no
    data are skipped.

    Args:
        df: Source data.
        col: Column whose distribution is plotted.
        output_dir: Directory for the image; created if missing.
        group_col: Optional column to split the data by.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        ValueError: If ``col`` or ``group_col`` is missing from ``df``.
    """
    _validate_col(df, col)
    if group_col:
        _validate_col(df, group_col)

    cfg = get_config()
    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        if group_col:
            data = []
            labels = []
            for g in df[group_col].unique().sort().to_list():
                vals = df.filter(pl.col(group_col) == g)[col].drop_nulls().to_numpy()
                if len(vals) > 0:
                    data.append(vals)
                    labels.append(str(g))
            # NOTE(review): ``tick_labels`` is kept from the original call;
            # confirm the minimum supported matplotlib accepts this keyword.
            ax.boxplot(data, tick_labels=labels)
            ax.set_xlabel(group_col)
            ax.set_title(f"{col} by {group_col}")
        else:
            data = df[col].drop_nulls().to_numpy()
            ax.boxplot(data)
            ax.set_title(f"Box plot of {col}")

        ax.set_ylabel(col)
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        name_suffix = f"_{col}_by_{group_col}" if group_col else f"_{col}"
        path = _unique_path(output_dir, f"box{name_suffix}")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def plot_bar(
    df: pl.DataFrame, col: str, output_dir: Path, *, group_col: str | None = None
) -> Path:
    """Create a bar chart and save it as a PNG.

    With ``group_col`` set, shows the mean of ``col`` per group (sorted by
    group). Otherwise shows counts of the values of ``col``, limited to the
    20 most frequent values.

    Args:
        df: Source data.
        col: Column to summarise.
        output_dir: Directory for the image; created if missing.
        group_col: Optional grouping column; switches the chart to group means.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        ValueError: If ``col`` or ``group_col`` is missing from ``df``.
    """
    # Validate every column before a figure exists, so a bad column name
    # cannot leave an open figure behind.
    _validate_col(df, col)
    if group_col:
        _validate_col(df, group_col)

    cfg = get_config()
    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        if group_col:
            agg = (
                df.group_by(group_col)
                .agg(pl.col(col).mean().alias("mean"))
                .sort(group_col)
            )
            labels = [str(v) for v in agg[group_col].to_list()]
            values = agg["mean"].to_numpy()
            ax.bar(labels, values, color="#4C72B0", alpha=0.85, edgecolor="white")
            ax.set_xlabel(group_col)
            ax.set_ylabel(f"Mean of {col}")
            ax.set_title(f"Mean {col} by {group_col}")
        else:
            counts = (
                df.group_by(col).len()
                .sort("len", descending=True)
                .rename({"len": "count"})
                .head(20)
            )
            labels = [str(v) for v in counts[col].to_list()]
            values = counts["count"].to_numpy()
            ax.bar(labels, values, color="#4C72B0", alpha=0.85, edgecolor="white")
            ax.set_xlabel(col)
            ax.set_ylabel("Count")
            ax.set_title(f"Bar chart of {col}")

        # Rotate the labels on this axes explicitly rather than through
        # pyplot's global "current axes".
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        name_suffix = f"_{col}_by_{group_col}" if group_col else f"_{col}"
        path = _unique_path(output_dir, f"bar{name_suffix}")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def plot_heatmap(
    df: pl.DataFrame, cols: list[str] | None, output_dir: Path
) -> Path:
    """Create a Pearson correlation heatmap for numeric columns and save it as a PNG.

    Args:
        df: Source data.
        cols: Columns to include. Non-numeric entries are ignored. When
            empty or ``None``, every numeric column of ``df`` is used.
        output_dir: Directory for the image; created if missing.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        ValueError: If a requested column is missing, or fewer than two
            numeric columns remain.
    """
    cfg = get_config()
    if cols:
        for c in cols:
            _validate_col(df, c)
        num_cols = [c for c in cols if df[c].dtype in NUMERIC_DTYPES]
    else:
        num_cols = [c for c in df.columns if df[c].dtype in NUMERIC_DTYPES]

    if len(num_cols) < 2:
        raise ValueError("Need at least 2 numeric columns for a heatmap")

    # Rows with a null in any selected column are dropped for all pairs.
    sub = df.select(num_cols).drop_nulls()
    n = len(num_cols)
    corr_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            r = sub.select(pl.corr(num_cols[i], num_cols[j])).item()
            # A missing correlation is shown as 0.0.
            corr_matrix[i, j] = r if r is not None else 0.0

    # The figure grows with the number of columns so cell labels stay legible.
    fig, ax = plt.subplots(figsize=(max(6, n + 2), max(5, n + 1)))
    try:
        im = ax.imshow(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")
        fig.colorbar(im, ax=ax, label="Pearson r")

        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(num_cols, rotation=45, ha="right")
        ax.set_yticklabels(num_cols)

        # Annotate cells; strong correlations sit on dark colours, so use
        # white text there.
        for i in range(n):
            for j in range(n):
                val = corr_matrix[i, j]
                color = "white" if abs(val) > 0.5 else "black"
                ax.text(j, i, f"{val:.2f}", ha="center", va="center", color=color, fontsize=9)

        ax.set_title("Correlation Heatmap")
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, "heatmap_corr")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# ---------------------------------------------------------------------------
|
|
244
|
+
# Diagnostic plots (post-estimation)
|
|
245
|
+
# ---------------------------------------------------------------------------
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def plot_residuals_vs_fitted(
    fitted: np.ndarray, residuals: np.ndarray, output_dir: Path
) -> Path:
    """Plot residuals against fitted values with a zero reference line. Save to PNG.

    Args:
        fitted: Fitted values (horizontal axis).
        residuals: Residuals (vertical axis), paired with ``fitted``.
        output_dir: Directory for the image; created if missing.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.
    """
    cfg = get_config()
    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        ax.scatter(fitted, residuals, alpha=0.6, s=20, color="#4C72B0")
        ax.axhline(y=0, color="red", linestyle="--", linewidth=1)
        ax.set_xlabel("Fitted values")
        ax.set_ylabel("Residuals")
        ax.set_title("Residuals vs Fitted")
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, "resid_vs_fitted")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def plot_qq(std_residuals: np.ndarray, output_dir: Path) -> Path:
    """Normal Q-Q plot of standardized residuals, saved as a PNG.

    Sorted residuals are plotted against normal quantiles at plotting
    positions ``i / (n + 1)`` for ``i = 1..n``, with a dashed ``y = x``
    reference line.

    Args:
        std_residuals: Standardized residuals.
        output_dir: Directory for the image; created if missing.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        ValueError: If ``std_residuals`` is empty.
    """
    sorted_resid = np.sort(std_residuals)
    n = len(sorted_resid)
    if n == 0:
        # Fail before any figure exists; min()/max() below would otherwise
        # raise on the empty arrays after the figure was created.
        raise ValueError("Need at least one residual for a Q-Q plot")

    from scipy import stats as sp_stats

    cfg = get_config()
    theoretical = sp_stats.norm.ppf(np.arange(1, n + 1) / (n + 1))

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        ax.scatter(theoretical, sorted_resid, alpha=0.6, s=20, color="#4C72B0")
        # Add 45-degree reference line spanning both data ranges.
        lims = [
            min(theoretical.min(), sorted_resid.min()),
            max(theoretical.max(), sorted_resid.max()),
        ]
        ax.plot(lims, lims, "r--", linewidth=1)
        ax.set_xlabel("Theoretical quantiles")
        ax.set_ylabel("Standardized residuals")
        ax.set_title("Normal Q-Q Plot")
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, "qq_plot")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def plot_scale_location(
    fitted: np.ndarray, std_residuals: np.ndarray, output_dir: Path
) -> Path:
    """Scale-Location plot (sqrt of abs standardized residuals vs fitted). Save to PNG.

    Args:
        fitted: Fitted values (horizontal axis).
        std_residuals: Standardized residuals, paired with ``fitted``.
        output_dir: Directory for the image; created if missing.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.
    """
    cfg = get_config()
    sqrt_abs_resid = np.sqrt(np.abs(std_residuals))

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        ax.scatter(fitted, sqrt_abs_resid, alpha=0.6, s=20, color="#4C72B0")
        ax.set_xlabel("Fitted values")
        ax.set_ylabel(r"$\sqrt{|Standardized\ residuals|}$")
        ax.set_title("Scale-Location")
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, "scale_location")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# ---------------------------------------------------------------------------
|
|
316
|
+
# Coefficient plot (post-estimation)
|
|
317
|
+
# ---------------------------------------------------------------------------
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def plot_coef(
    params: dict,
    ci_lower: dict,
    ci_upper: dict,
    output_dir: Path,
    *,
    title: str = "Coefficient Plot",
    drop_const: bool = True,
) -> Path:
    """Coefficient plot with confidence-interval error bars. Saves to PNG.

    Args:
        params: Mapping of term name to point estimate.
        ci_lower: Mapping of term name to lower interval bound.
        ci_upper: Mapping of term name to upper interval bound.
        output_dir: Directory for the image; created if missing.
        title: Plot title.
        drop_const: When true, terms named ``const``, ``Intercept`` or
            ``_cons`` are left out. If that would leave nothing to plot,
            all terms are shown instead.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        KeyError: If ``ci_lower`` or ``ci_upper`` lacks a plotted term.
    """
    cfg = get_config()

    _CONST_NAMES = {"const", "Intercept", "_cons"}
    names = [k for k in params if not (drop_const and k in _CONST_NAMES)]
    if not names:
        # Only constant terms were present; plot them rather than nothing.
        names = list(params.keys())

    coefs = np.array([params[n] for n in names])
    # errorbar expects distances from the point estimate, not absolute bounds.
    err_lower = np.array([params[n] - ci_lower[n] for n in names])
    err_upper = np.array([ci_upper[n] - params[n] for n in names])

    # Grow the figure with the number of terms so labels do not overlap.
    fig_h = max(cfg.plot_figsize_h, len(names) * 0.55 + 1.5)
    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, fig_h))
    try:
        y_pos = np.arange(len(names))
        ax.errorbar(
            coefs, y_pos,
            xerr=[err_lower, err_upper],
            fmt="o",
            color="#4C72B0",
            ecolor="#4C72B0",
            capsize=4,
            linewidth=1.5,
            markersize=6,
        )
        ax.axvline(0, color="gray", linestyle="--", linewidth=1, alpha=0.8)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(names)
        ax.set_xlabel("Coefficient (95% CI)")
        ax.set_title(title)
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, "coef_plot")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def plot_interaction(
    df: pl.DataFrame,
    y_col: str,
    x_col: str,
    mod_col: str,
    output_dir: Path,
    *,
    n_levels: int = 3,
) -> Path:
    """Interaction plot: y vs x, one fitted line per level of a moderator.

    For a numeric moderator the rows are split into three bands: below
    ``mean - sd``, within one ``sd`` of the mean, and above ``mean + sd``.
    For any other dtype, the first ``n_levels`` sorted unique values each
    form a group. Each group gets a scatter of its points and, when it has
    more than one point, a degree-1 least-squares line.

    Args:
        df: Source data.
        y_col: Outcome column (vertical axis).
        x_col: Predictor column (horizontal axis).
        mod_col: Moderator column used to split the data.
        output_dir: Directory for the image; created if missing.
        n_levels: Maximum number of groups for a non-numeric moderator.

    Returns:
        Path of the written PNG, chosen by ``_unique_path``.

    Raises:
        ValueError: If a column is missing, or a numeric moderator has no
            non-null values.
    """
    cfg = get_config()
    _validate_col(df, y_col)
    _validate_col(df, x_col)
    _validate_col(df, mod_col)

    mod_series = df[mod_col].drop_nulls()
    is_numeric_mod = mod_series.dtype in (
        pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64,
        pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
    )

    if is_numeric_mod:
        mu = mod_series.mean()
        if mu is None:
            raise ValueError(f"No non-null values in moderator column: {mod_col}")
        # std() can be None (e.g. too few values); fall back to zero spread
        # so the band arithmetic below stays well-defined.
        sd = mod_series.std() or 0.0
        cuts = {
            f"{mod_col} Low (−1SD)": df.filter(pl.col(mod_col) < mu - sd),
            f"{mod_col} Mean": df.filter(
                (pl.col(mod_col) >= mu - sd) & (pl.col(mod_col) <= mu + sd)
            ),
            f"{mod_col} High (+1SD)": df.filter(pl.col(mod_col) > mu + sd),
        }
    else:
        unique_vals = mod_series.unique().sort().to_list()[:n_levels]
        cuts = {
            f"{mod_col}={v}": df.filter(pl.col(mod_col) == v)
            for v in unique_vals
        }

    colors = ["#4C72B0", "#DD8452", "#55A868", "#C44E52", "#8172B3"]

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        for i, (label, sub) in enumerate(cuts.items()):
            # Keep only rows where BOTH x and y are present. Dropping nulls
            # per column would misalign the pairs or give unequal lengths.
            pairs = sub.filter(
                pl.col(x_col).is_not_null() & pl.col(y_col).is_not_null()
            )
            if pairs.is_empty():
                continue
            x_vals = pairs[x_col].to_numpy()
            y_vals = pairs[y_col].to_numpy()
            color = colors[i % len(colors)]
            # Regression line for this group (needs at least two points).
            if len(x_vals) > 1:
                m, b = np.polyfit(x_vals, y_vals, 1)
                x_range = np.linspace(x_vals.min(), x_vals.max(), 50)
                ax.plot(x_range, m * x_range + b, color=color, label=label, linewidth=2)
            ax.scatter(x_vals, y_vals, alpha=0.25, color=color, s=20)

        ax.set_xlabel(x_col)
        ax.set_ylabel(y_col)
        ax.set_title(f"Interaction: {y_col} ~ {x_col} × {mod_col}")
        ax.legend(loc="best", fontsize=9)
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = _unique_path(output_dir, "interaction_plot")
        fig.savefig(path, dpi=cfg.plot_dpi)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Survival analysis plots: Kaplan-Meier curves."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
|
|
9
|
+
from openstat.config import get_config
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def plot_km(kmf_objects, output_dir: Path, group_var: str | None = None) -> Path:
    """Plot Kaplan-Meier survival curve(s) on one axes and save as a PNG.

    Args:
        kmf_objects: A single fitted estimator, or a list/tuple of them.
            Each must provide ``plot_survival_function(ax=...)``.
            NOTE(review): presumably lifelines ``KaplanMeierFitter``
            instances - confirm against the caller.
        output_dir: Directory for the image; created if missing.
        group_var: Optional grouping-variable name, used in the title and
            in the file name.

    Returns:
        ``output_dir / "km_<group_var>.png"`` or ``output_dir / "km.png"``.
        An existing file of that name is overwritten.
    """
    cfg = get_config()
    if isinstance(kmf_objects, (list, tuple)):
        fitters = kmf_objects
    else:
        fitters = [kmf_objects]

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        for kmf in fitters:
            kmf.plot_survival_function(ax=ax)

        ax.set_title("Kaplan-Meier Survival Estimate" + (f" by {group_var}" if group_var else ""))
        ax.set_xlabel("Time")
        ax.set_ylabel("Survival Probability")
        ax.set_ylim(0, 1)

        name = f"km_{group_var}" if group_var else "km"
        # Make sure the target directory exists before saving.
        output_dir.mkdir(parents=True, exist_ok=True)
        path = output_dir / f"{name}.png"
        fig.savefig(path, dpi=cfg.plot_dpi, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Time series plots: ACF, PACF, forecast."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
|
|
10
|
+
from openstat.config import get_config
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def plot_acf(series: np.ndarray, var_name: str, output_dir: Path, lags: int = 40) -> Path:
    """Plot the autocorrelation function of ``series`` and save as a PNG.

    The lag count is capped at ``len(series) // 2 - 1`` and never allowed
    to go below zero.

    Args:
        series: Observations.
        var_name: Variable name, used in the title and the file name.
        output_dir: Directory for the image; created if missing.
        lags: Requested maximum number of lags.

    Returns:
        ``output_dir / "acf_<var_name>.png"``. An existing file of that
        name is overwritten.

    Raises:
        ValueError: If ``series`` is empty.
    """
    n_obs = len(series)
    if n_obs == 0:
        # Fail early with a clear message, before any figure is created.
        raise ValueError("Cannot plot ACF of an empty series")

    from statsmodels.graphics.tsaplots import plot_acf as sm_plot_acf

    cfg = get_config()
    # For very short series ``n_obs // 2 - 1`` is negative; clamp it so a
    # negative lag count is never passed on.
    n_lags = max(0, min(lags, n_obs // 2 - 1))

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        sm_plot_acf(series, lags=n_lags, ax=ax)
        ax.set_title(f"ACF: {var_name}")

        output_dir.mkdir(parents=True, exist_ok=True)
        path = output_dir / f"acf_{var_name}.png"
        fig.savefig(path, dpi=cfg.plot_dpi, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def plot_pacf(series: np.ndarray, var_name: str, output_dir: Path, lags: int = 40) -> Path:
    """Plot the partial autocorrelation function of ``series`` and save as a PNG.

    The lag count is capped at ``len(series) // 2 - 1`` and never allowed
    to go below zero. The estimate uses statsmodels' ``"ywm"`` method.

    Args:
        series: Observations.
        var_name: Variable name, used in the title and the file name.
        output_dir: Directory for the image; created if missing.
        lags: Requested maximum number of lags.

    Returns:
        ``output_dir / "pacf_<var_name>.png"``. An existing file of that
        name is overwritten.

    Raises:
        ValueError: If ``series`` is empty.
    """
    n_obs = len(series)
    if n_obs == 0:
        # Fail early with a clear message, before any figure is created.
        raise ValueError("Cannot plot PACF of an empty series")

    from statsmodels.graphics.tsaplots import plot_pacf as sm_plot_pacf

    cfg = get_config()
    # For very short series ``n_obs // 2 - 1`` is negative; clamp it so a
    # negative lag count is never passed on.
    n_lags = max(0, min(lags, n_obs // 2 - 1))

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        sm_plot_pacf(series, lags=n_lags, ax=ax, method="ywm")
        ax.set_title(f"PACF: {var_name}")

        output_dir.mkdir(parents=True, exist_ok=True)
        path = output_dir / f"pacf_{var_name}.png"
        fig.savefig(path, dpi=cfg.plot_dpi, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def plot_forecast(actual: np.ndarray, forecast: np.ndarray, var_name: str, output_dir: Path) -> Path:
    """Plot actual values followed by their forecast extension; save as a PNG.

    Both series are plotted against integer positions: ``actual`` on
    ``0 .. len(actual) - 1`` and ``forecast`` continuing from
    ``len(actual)``. A dotted vertical line marks where the forecast starts.

    Args:
        actual: Observed values.
        forecast: Forecast values that follow ``actual``.
        var_name: Variable name, used in the title and the file name.
        output_dir: Directory for the image; created if missing.

    Returns:
        ``output_dir / "forecast_<var_name>.png"``. An existing file of
        that name is overwritten.
    """
    cfg = get_config()
    n_actual = len(actual)
    n_fc = len(forecast)

    fig, ax = plt.subplots(figsize=(cfg.plot_figsize_w, cfg.plot_figsize_h))
    try:
        ax.plot(range(n_actual), actual, label="Actual", color="blue")
        ax.plot(range(n_actual, n_actual + n_fc), forecast, label="Forecast", color="red", linestyle="--")
        ax.axvline(x=n_actual, color="gray", linestyle=":", alpha=0.5)
        ax.set_title(f"Forecast: {var_name}")
        ax.legend()

        output_dir.mkdir(parents=True, exist_ok=True)
        path = output_dir / f"forecast_{var_name}.png"
        fig.savefig(path, dpi=cfg.plot_dpi, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)
    return path
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Plugin discovery and lifecycle management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from importlib.metadata import entry_points
|
|
7
|
+
|
|
8
|
+
from openstat.logging_config import get_logger
|
|
9
|
+
|
|
10
|
+
log = get_logger("plugins")
|
|
11
|
+
|
|
12
|
+
ENTRY_POINT_GROUP = "openstat_plugin"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class PluginInfo:
    """Descriptive record for one installed plugin.

    Only ``name`` is required; every other field has a default.
    """

    # Plugin name (required).
    name: str
    # Version string; "0.0.0" when the plugin does not supply one.
    version: str = "0.0.0"
    # Free-form description; empty by default.
    description: str = ""
    # Presumably the command names the plugin provides (verify against
    # plugin authors' usage). Each instance gets its own fresh list.
    commands: list[str] = field(default_factory=list)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class PluginManager:
    """Discovers and loads OpenStat plugins via entry_points.

    Successfully loaded plugins and load failures are tracked separately,
    both keyed by entry-point name.
    """

    def __init__(self) -> None:
        # entry-point name -> metadata of a successfully loaded plugin
        self._loaded: dict[str, PluginInfo] = {}
        # entry-point name -> message of the most recent load failure
        self._errors: dict[str, str] = {}

    @property
    def plugins(self) -> dict[str, PluginInfo]:
        """Copy of the loaded-plugin mapping (edits do not affect the manager)."""
        return dict(self._loaded)

    @property
    def errors(self) -> dict[str, str]:
        """Copy of the load-failure mapping (edits do not affect the manager)."""
        return dict(self._errors)

    def discover(self) -> list[str]:
        """Discover and load all installed plugins.

        Each entry point in ``ENTRY_POINT_GROUP`` is loaded. If the loaded
        object has a ``setup`` attribute it is called; a returned
        ``PluginInfo`` is stored, and any other return value is replaced by
        a default record carrying the entry-point name. A failure in one
        plugin is recorded in ``errors`` and logged, and does not stop the
        remaining plugins from loading.

        Returns list of successfully loaded plugin names.
        """
        loaded: list[str] = []
        for ep in entry_points(group=ENTRY_POINT_GROUP):
            try:
                module = ep.load()
                if hasattr(module, "setup"):
                    info = module.setup()
                    if not isinstance(info, PluginInfo):
                        info = PluginInfo(name=ep.name)
                else:
                    # Module loaded — commands registered via @command at import
                    info = PluginInfo(name=ep.name)
                self._loaded[ep.name] = info
                # Drop any failure recorded by an earlier discover() call so
                # ``errors`` does not report a plugin that now loads fine.
                self._errors.pop(ep.name, None)
                loaded.append(ep.name)
                log.info("Loaded plugin: %s", ep.name)
            except Exception as exc:
                # Plugin boundary: third-party code may raise anything, and
                # one broken plugin must not block the others.
                self._errors[ep.name] = str(exc)
                log.warning("Failed to load plugin %s: %s", ep.name, exc)
        return loaded

    def list_plugins(self) -> list[PluginInfo]:
        """Return the metadata records of all loaded plugins."""
        return list(self._loaded.values())

    def get_info(self, name: str) -> PluginInfo | None:
        """Return the record for plugin ``name``, or ``None`` if not loaded."""
        return self._loaded.get(name)
|