openstat-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. openstat/__init__.py +3 -0
  2. openstat/__main__.py +4 -0
  3. openstat/backends/__init__.py +16 -0
  4. openstat/backends/duckdb_backend.py +70 -0
  5. openstat/backends/polars_backend.py +52 -0
  6. openstat/cli.py +92 -0
  7. openstat/commands/__init__.py +82 -0
  8. openstat/commands/adv_stat_cmds.py +1255 -0
  9. openstat/commands/advanced_ml_cmds.py +576 -0
  10. openstat/commands/advreg_cmds.py +207 -0
  11. openstat/commands/alias_cmds.py +135 -0
  12. openstat/commands/arch_cmds.py +82 -0
  13. openstat/commands/arules_cmds.py +111 -0
  14. openstat/commands/automodel_cmds.py +212 -0
  15. openstat/commands/backend_cmds.py +82 -0
  16. openstat/commands/base.py +170 -0
  17. openstat/commands/bayes_cmds.py +71 -0
  18. openstat/commands/causal_cmds.py +269 -0
  19. openstat/commands/cluster_cmds.py +152 -0
  20. openstat/commands/data_cmds.py +996 -0
  21. openstat/commands/datamanip_cmds.py +672 -0
  22. openstat/commands/dataquality_cmds.py +174 -0
  23. openstat/commands/datetime_cmds.py +176 -0
  24. openstat/commands/dimreduce_cmds.py +184 -0
  25. openstat/commands/discrete_cmds.py +149 -0
  26. openstat/commands/dsl_cmds.py +143 -0
  27. openstat/commands/epi_cmds.py +93 -0
  28. openstat/commands/equiv_tobit_cmds.py +94 -0
  29. openstat/commands/esttab_cmds.py +196 -0
  30. openstat/commands/export_beamer_cmds.py +142 -0
  31. openstat/commands/export_cmds.py +201 -0
  32. openstat/commands/export_extra_cmds.py +240 -0
  33. openstat/commands/factor_cmds.py +180 -0
  34. openstat/commands/groupby_cmds.py +155 -0
  35. openstat/commands/help_cmds.py +237 -0
  36. openstat/commands/i18n_cmds.py +43 -0
  37. openstat/commands/import_extra_cmds.py +561 -0
  38. openstat/commands/influence_cmds.py +134 -0
  39. openstat/commands/iv_cmds.py +106 -0
  40. openstat/commands/manova_cmds.py +105 -0
  41. openstat/commands/mediate_cmds.py +233 -0
  42. openstat/commands/meta_cmds.py +284 -0
  43. openstat/commands/mi_cmds.py +228 -0
  44. openstat/commands/mixed_cmds.py +79 -0
  45. openstat/commands/mixture_changepoint_cmds.py +166 -0
  46. openstat/commands/ml_adv_cmds.py +147 -0
  47. openstat/commands/ml_cmds.py +178 -0
  48. openstat/commands/model_eval_cmds.py +142 -0
  49. openstat/commands/network_cmds.py +288 -0
  50. openstat/commands/nlquery_cmds.py +161 -0
  51. openstat/commands/nonparam_cmds.py +149 -0
  52. openstat/commands/outreg_cmds.py +247 -0
  53. openstat/commands/panel_cmds.py +141 -0
  54. openstat/commands/pdf_cmds.py +226 -0
  55. openstat/commands/pipeline_cmds.py +319 -0
  56. openstat/commands/plot_cmds.py +189 -0
  57. openstat/commands/plugin_cmds.py +79 -0
  58. openstat/commands/posthoc_cmds.py +153 -0
  59. openstat/commands/power_cmds.py +172 -0
  60. openstat/commands/profile_cmds.py +246 -0
  61. openstat/commands/rbridge_cmds.py +81 -0
  62. openstat/commands/regex_cmds.py +104 -0
  63. openstat/commands/report_cmds.py +48 -0
  64. openstat/commands/repro_cmds.py +129 -0
  65. openstat/commands/resampling_cmds.py +109 -0
  66. openstat/commands/reshape_cmds.py +223 -0
  67. openstat/commands/sem_cmds.py +177 -0
  68. openstat/commands/stat_cmds.py +1040 -0
  69. openstat/commands/stata_import_cmds.py +215 -0
  70. openstat/commands/string_cmds.py +124 -0
  71. openstat/commands/surv_cmds.py +145 -0
  72. openstat/commands/survey_cmds.py +153 -0
  73. openstat/commands/textanalysis_cmds.py +192 -0
  74. openstat/commands/ts_adv_cmds.py +136 -0
  75. openstat/commands/ts_cmds.py +195 -0
  76. openstat/commands/tui_cmds.py +111 -0
  77. openstat/commands/ux_cmds.py +191 -0
  78. openstat/commands/validate_cmds.py +270 -0
  79. openstat/commands/viz_adv_cmds.py +312 -0
  80. openstat/commands/viz_extra_cmds.py +251 -0
  81. openstat/commands/watch_cmds.py +69 -0
  82. openstat/config.py +106 -0
  83. openstat/dsl/__init__.py +0 -0
  84. openstat/dsl/parser.py +332 -0
  85. openstat/dsl/tokenizer.py +105 -0
  86. openstat/i18n.py +120 -0
  87. openstat/io/__init__.py +0 -0
  88. openstat/io/loader.py +187 -0
  89. openstat/jupyter/__init__.py +18 -0
  90. openstat/jupyter/display.py +18 -0
  91. openstat/jupyter/magic.py +60 -0
  92. openstat/logging_config.py +59 -0
  93. openstat/plots/__init__.py +0 -0
  94. openstat/plots/plotter.py +437 -0
  95. openstat/plots/surv_plots.py +32 -0
  96. openstat/plots/ts_plots.py +59 -0
  97. openstat/plugins/__init__.py +5 -0
  98. openstat/plugins/manager.py +69 -0
  99. openstat/repl.py +457 -0
  100. openstat/reporting/__init__.py +0 -0
  101. openstat/reporting/eda.py +208 -0
  102. openstat/reporting/report.py +67 -0
  103. openstat/script_runner.py +319 -0
  104. openstat/session.py +133 -0
  105. openstat/stats/__init__.py +0 -0
  106. openstat/stats/advanced_regression.py +269 -0
  107. openstat/stats/arch_garch.py +84 -0
  108. openstat/stats/bayesian.py +103 -0
  109. openstat/stats/causal.py +258 -0
  110. openstat/stats/clustering.py +206 -0
  111. openstat/stats/discrete.py +311 -0
  112. openstat/stats/epidemiology.py +119 -0
  113. openstat/stats/equiv_tobit.py +163 -0
  114. openstat/stats/factor.py +174 -0
  115. openstat/stats/imputation.py +282 -0
  116. openstat/stats/influence.py +78 -0
  117. openstat/stats/iv.py +131 -0
  118. openstat/stats/manova.py +124 -0
  119. openstat/stats/mixed.py +128 -0
  120. openstat/stats/ml.py +275 -0
  121. openstat/stats/ml_advanced.py +117 -0
  122. openstat/stats/model_eval.py +183 -0
  123. openstat/stats/models.py +1342 -0
  124. openstat/stats/nonparametric.py +130 -0
  125. openstat/stats/panel.py +179 -0
  126. openstat/stats/power.py +295 -0
  127. openstat/stats/resampling.py +203 -0
  128. openstat/stats/survey.py +213 -0
  129. openstat/stats/survival.py +196 -0
  130. openstat/stats/timeseries.py +142 -0
  131. openstat/stats/ts_advanced.py +114 -0
  132. openstat/types.py +11 -0
  133. openstat/web/__init__.py +1 -0
  134. openstat/web/app.py +117 -0
  135. openstat/web/session_manager.py +73 -0
  136. openstat/web/static/app.js +117 -0
  137. openstat/web/static/index.html +38 -0
  138. openstat/web/static/style.css +103 -0
  139. openstat_cli-1.0.0.dist-info/METADATA +748 -0
  140. openstat_cli-1.0.0.dist-info/RECORD +143 -0
  141. openstat_cli-1.0.0.dist-info/WHEEL +4 -0
  142. openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
  143. openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,203 @@
1
+ """Bootstrap and permutation test statistics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import polars as pl
7
+
8
+
9
def bootstrap_ci(
    df: pl.DataFrame,
    col: str,
    stat: str = "mean",
    n_boot: int = 2000,
    ci: float = 0.95,
    seed: int = 42,
) -> dict:
    """Bootstrap confidence interval for a statistic.

    Resamples the column with replacement ``n_boot`` times and reports the
    percentile CI, bootstrap SE, and bias of the chosen statistic.
    """
    rng = np.random.default_rng(seed)
    sample = df[col].drop_nulls().to_numpy().astype(float)
    n_obs = len(sample)

    stat_fns = {
        "mean": np.mean,
        "median": np.median,
        "std": np.std,
        "var": np.var,
        "min": np.min,
        "max": np.max,
    }
    fn = stat_fns.get(stat)
    if fn is None:
        raise ValueError(f"Unknown statistic: {stat}. Use: {', '.join(stat_fns)}")

    observed = float(fn(sample))

    # One resample (with replacement) per replicate, same size as the data.
    replicates = np.empty(n_boot)
    for b in range(n_boot):
        replicates[b] = fn(rng.choice(sample, size=n_obs, replace=True))

    alpha = 1 - ci
    lo = float(np.quantile(replicates, alpha / 2))
    hi = float(np.quantile(replicates, 1 - alpha / 2))
    se = float(replicates.std())
    bias = float(replicates.mean() - observed)

    return {
        "test": f"Bootstrap CI ({stat})",
        "col": col,
        "stat": stat,
        "observed": observed,
        "n_obs": n_obs,
        "n_boot": n_boot,
        "ci_level": ci,
        "ci_lo": lo,
        "ci_hi": hi,
        "se_boot": se,
        "bias": bias,
    }
60
+
61
+
62
def bootstrap_diff(
    df: pl.DataFrame,
    col: str,
    by: str,
    stat: str = "mean",
    n_boot: int = 2000,
    ci: float = 0.95,
    seed: int = 42,
) -> dict:
    """Bootstrap CI for the difference in a statistic between two groups.

    Args:
        df: Source data frame.
        col: Numeric column to compare.
        by: Grouping column; must contain exactly two distinct non-null values.
        stat: One of mean/median/std/var/min/max.
        n_boot: Number of bootstrap resamples.
        ci: Confidence level (e.g. 0.95).
        seed: RNG seed for reproducibility.

    Returns:
        Dict with observed difference, percentile CI, bootstrap SE and a
        p-value from the null-centered bootstrap distribution.

    Raises:
        ValueError: If ``by`` does not have exactly 2 groups or ``stat`` is unknown.
    """
    rng = np.random.default_rng(seed)
    groups = df[by].drop_nulls().unique().sort().to_list()
    if len(groups) != 2:
        raise ValueError(f"bootstrap_diff requires exactly 2 groups, got {len(groups)}")

    g1 = df.filter(pl.col(by) == groups[0])[col].drop_nulls().to_numpy().astype(float)
    g2 = df.filter(pl.col(by) == groups[1])[col].drop_nulls().to_numpy().astype(float)

    stat_fns = {
        "mean": np.mean, "median": np.median, "std": np.std,
        "var": np.var, "min": np.min, "max": np.max,
    }
    # Fix: an unknown stat previously fell back silently to the mean while the
    # result was still labeled with the requested stat; raise instead, matching
    # bootstrap_ci.
    if stat not in stat_fns:
        raise ValueError(f"Unknown statistic: {stat}. Use: {', '.join(stat_fns)}")
    fn = stat_fns[stat]
    observed_diff = float(fn(g1) - fn(g2))

    # Resample each group independently and record the per-replicate difference.
    boot_diffs = np.array([
        fn(rng.choice(g1, size=len(g1), replace=True)) -
        fn(rng.choice(g2, size=len(g2), replace=True))
        for _ in range(n_boot)
    ])

    alpha = 1 - ci
    lo = float(np.quantile(boot_diffs, alpha / 2))
    hi = float(np.quantile(boot_diffs, 1 - alpha / 2))
    # Shift bootstrap distribution to null (mean=0) for p-value
    boot_centered = boot_diffs - boot_diffs.mean()
    p_value = float((np.abs(boot_centered) >= np.abs(observed_diff)).mean())

    return {
        "test": f"Bootstrap Difference ({stat})",
        "col": col, "by": by,
        "groups": [str(g) for g in groups],
        "observed_diff": observed_diff,
        "n_boot": n_boot,
        "ci_level": ci,
        "ci_lo": lo,
        "ci_hi": hi,
        "se_boot": float(boot_diffs.std()),
        "p_value": p_value,
    }
112
+
113
+
114
def permutation_test(
    df: pl.DataFrame,
    col: str,
    by: str,
    stat: str = "mean",
    n_perm: int = 2000,
    alternative: str = "two-sided",
    seed: int = 42,
) -> dict:
    """Permutation test for difference between two groups.

    Args:
        df: Source data frame.
        col: Numeric column to compare.
        by: Grouping column; must contain exactly two distinct non-null values.
        stat: One of mean/median/std.
        n_perm: Number of label permutations.
        alternative: "two-sided", "greater", or "less" (anything else is
            treated as "less", matching the original behavior).
        seed: RNG seed for reproducibility.

    Returns:
        Dict with the observed difference, permutation p-value, and a 5%
        reject flag.

    Raises:
        ValueError: If ``by`` does not have exactly 2 groups or ``stat`` is unknown.
    """
    rng = np.random.default_rng(seed)
    groups = df[by].drop_nulls().unique().sort().to_list()
    if len(groups) != 2:
        raise ValueError(f"permutation_test requires exactly 2 groups, got {len(groups)}")

    g1 = df.filter(pl.col(by) == groups[0])[col].drop_nulls().to_numpy().astype(float)
    g2 = df.filter(pl.col(by) == groups[1])[col].drop_nulls().to_numpy().astype(float)

    stat_fns = {"mean": np.mean, "median": np.median, "std": np.std}
    # Fix: an unknown stat previously fell back silently to the mean while the
    # result was still labeled with the requested stat; raise instead.
    if stat not in stat_fns:
        raise ValueError(f"Unknown statistic: {stat}. Use: {', '.join(stat_fns)}")
    fn = stat_fns[stat]

    observed = float(fn(g1) - fn(g2))
    combined = np.concatenate([g1, g2])
    n1 = len(g1)

    # Fix: a dead comprehension previously drew n_perm throwaway permutations
    # (computing the wrong statistic) before being overwritten by this loop;
    # only the correct computation remains.
    perm_stats = np.zeros(n_perm)
    for i in range(n_perm):
        perm = rng.permutation(combined)
        perm_stats[i] = fn(perm[:n1]) - fn(perm[n1:])

    if alternative == "two-sided":
        p_value = float((np.abs(perm_stats) >= np.abs(observed)).mean())
    elif alternative == "greater":
        p_value = float((perm_stats >= observed).mean())
    else:
        p_value = float((perm_stats <= observed).mean())

    return {
        "test": "Permutation Test",
        "col": col, "by": by,
        "stat": stat,
        "groups": [str(g) for g in groups],
        "observed_diff": observed,
        "n_perm": n_perm,
        "alternative": alternative,
        "p_value": p_value,
        "reject_5pct": p_value < 0.05,
    }
167
+
168
+
169
def jackknife_ci(
    df: pl.DataFrame,
    col: str,
    stat: str = "mean",
) -> dict:
    """Jackknife (leave-one-out) bias and standard error estimate."""
    values = df[col].drop_nulls().to_numpy().astype(float)
    n = len(values)
    stat_fns = {
        "mean": np.mean,
        "median": np.median,
        "std": np.std,
        "var": np.var,
        "min": np.min,
        "max": np.max,
    }
    if stat not in stat_fns:
        raise ValueError(f"Unknown statistic: {stat}")
    fn = stat_fns[stat]
    observed = float(fn(values))

    # Recompute the statistic once per left-out observation.
    loo = np.empty(n)
    for i in range(n):
        loo[i] = fn(np.delete(values, i))
    loo_mean = loo.mean()

    # Standard jackknife bias and SE formulas.
    bias = float((n - 1) * (loo_mean - observed))
    se = float(np.sqrt((n - 1) / n * np.sum((loo - loo_mean) ** 2)))

    return {
        "test": f"Jackknife ({stat})",
        "col": col,
        "stat": stat,
        "observed": observed,
        "n_obs": n,
        "bias": bias,
        "se_jackknife": se,
        "bias_corrected": observed - bias,
    }
@@ -0,0 +1,213 @@
1
+ """Survey-weighted estimation: weighted means, WLS, Taylor linearization, DEFF."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import polars as pl
7
+ import statsmodels.api as sm
8
+ from scipy import stats as sp_stats
9
+
10
+ from openstat.stats.models import FitResult
11
+
12
+
13
def weighted_summary(df: pl.DataFrame, cols: list[str], weight_var: str) -> str:
    """Compute weighted summary statistics."""
    frame = df.to_pandas()
    all_weights = frame[weight_var].values

    out = ["Survey-Weighted Summary Statistics:"]
    out.append(f"{'Variable':>15} {'Wt.Mean':>12} {'Wt.SE':>12} {'N':>8}")
    out.append("-" * 50)

    for col in cols:
        raw = frame[col].values
        # Keep only rows where both the value and its weight are present.
        keep = ~np.isnan(raw) & ~np.isnan(all_weights)
        v = raw[keep]
        w = all_weights[keep]
        n = len(v)
        if n == 0:
            out.append(f"{col:>15} {'—':>12} {'—':>12} {0:>8}")
            continue

        # Weighted mean, then the bias-corrected weighted variance for the SE.
        wt_mean = np.average(v, weights=w)
        w_sum = w.sum()
        w_sum2 = (w ** 2).sum()
        wt_var = np.average((v - wt_mean) ** 2, weights=w) * w_sum ** 2 / (w_sum ** 2 - w_sum2)
        wt_se = np.sqrt(wt_var / n)

        out.append(f"{col:>15} {wt_mean:>12.4f} {wt_se:>12.4f} {n:>8}")

    return "\n".join(out)
43
+
44
+
45
def fit_weighted_ols(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    weight_var: str,
    strata_var: str | None = None,
    psu_var: str | None = None,
) -> tuple[FitResult, object]:
    """Fit weighted OLS with optional Taylor linearization for complex survey SE.

    Args:
        df: Source data frame.
        dep: Dependent variable column name.
        indeps: Independent variable column names.
        weight_var: Column holding the survey weights.
        strata_var: Optional stratum identifier column.
        psu_var: Optional primary sampling unit (cluster) column.

    Returns:
        Tuple of (FitResult summary, underlying statsmodels WLS results object).
    """
    # Collect every required column so listwise deletion (dropna) sees them all.
    all_cols = [dep] + indeps + [weight_var]
    if strata_var:
        all_cols.append(strata_var)
    if psu_var:
        all_cols.append(psu_var)

    pdf = df.select(all_cols).to_pandas().dropna()
    weights = pdf[weight_var].values
    y = pdf[dep].values
    X = sm.add_constant(pdf[indeps].values)
    var_names = ["const"] + indeps

    # Fit WLS
    model = sm.WLS(y, X, weights=weights)
    result = model.fit()

    # Taylor linearization for SE if PSU/strata provided: design-based
    # (sandwich) SEs replace the model-based WLS ones.
    if psu_var and strata_var:
        vcov = _taylor_linearization(pdf, y, X, weights, result.params,
                                     psu_var, strata_var)
        se = np.sqrt(np.diag(vcov))
    else:
        se = result.bse

    params = {name: float(result.params[i]) for i, name in enumerate(var_names)}
    std_errors = {name: float(se[i]) for i, name in enumerate(var_names)}
    # t-stats are recomputed against whichever SEs were chosen above.
    t_vals = {name: float(result.params[i] / se[i]) if se[i] > 0 else 0.0 for i, name in enumerate(var_names)}
    p_vals = {name: float(2 * (1 - sp_stats.t.cdf(abs(t_vals[name]), result.df_resid))) for name in var_names}
    # NOTE(review): CIs use the 1.96 normal approximation while p-values use
    # the t distribution with df_resid — confirm this mix is intended.
    ci_low = {name: params[name] - 1.96 * std_errors[name] for name in var_names}
    ci_high = {name: params[name] + 1.96 * std_errors[name] for name in var_names}

    # Design metadata is surfaced through the warnings list for display.
    warnings_list = [f"Weight variable: {weight_var}"]
    if strata_var:
        warnings_list.append(f"Strata: {strata_var}")
    if psu_var:
        warnings_list.append(f"PSU: {psu_var}")

    fit = FitResult(
        model_type="Svy: OLS",
        formula=f"{dep} ~ {' + '.join(indeps)}",
        dep_var=dep,
        indep_vars=var_names,
        n_obs=int(len(pdf)),
        params=params,
        std_errors=std_errors,
        t_values=t_vals,
        p_values=p_vals,
        conf_int_low=ci_low,
        conf_int_high=ci_high,
        r_squared=float(result.rsquared),
        warnings=warnings_list,
    )

    return fit, result
108
+
109
+
110
def fit_weighted_logit(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    weight_var: str,
) -> tuple[FitResult, object]:
    """Fit weighted logistic regression.

    Args:
        df: Source data frame.
        dep: Binary dependent variable column.
        indeps: Independent variable column names.
        weight_var: Column holding the survey/frequency weights.

    Returns:
        Tuple of (FitResult summary, underlying statsmodels GLM results object).
    """
    all_cols = [dep] + indeps + [weight_var]
    pdf = df.select(all_cols).to_pandas().dropna()
    weights = pdf[weight_var].values
    y = pdf[dep].values
    X = sm.add_constant(pdf[indeps].values)
    var_names = ["const"] + indeps

    # BUG FIX: the previous code called Logit.fit(..., freq_weights=weights),
    # but Logit.fit forwards unknown kwargs to the optimizer, so the weights
    # were never applied. Use a binomial GLM, where freq_weights is a
    # supported constructor argument.
    model = sm.GLM(y, X, family=sm.families.Binomial(), freq_weights=weights)
    result = model.fit()

    params = {name: float(result.params[i]) for i, name in enumerate(var_names)}
    std_errors = {name: float(result.bse[i]) for i, name in enumerate(var_names)}
    t_vals = {name: float(result.tvalues[i]) for i, name in enumerate(var_names)}
    p_vals = {name: float(result.pvalues[i]) for i, name in enumerate(var_names)}
    ci = result.conf_int()
    ci_low = {name: float(ci[i, 0]) for i, name in enumerate(var_names)}
    ci_high = {name: float(ci[i, 1]) for i, name in enumerate(var_names)}

    # McFadden pseudo R^2: GLM results expose llf/llnull but not prsquared.
    llnull = float(result.llnull)
    pseudo_r2 = 1.0 - float(result.llf) / llnull if llnull != 0.0 else 0.0

    fit = FitResult(
        model_type="Svy: Logit",
        formula=f"{dep} ~ {' + '.join(indeps)}",
        dep_var=dep,
        indep_vars=var_names,
        n_obs=int(len(pdf)),
        params=params,
        std_errors=std_errors,
        t_values=t_vals,
        p_values=p_vals,
        conf_int_low=ci_low,
        conf_int_high=ci_high,
        pseudo_r2=pseudo_r2,
        log_likelihood=float(result.llf),
        warnings=[f"Weight variable: {weight_var}"],
    )

    return fit, result
153
+
154
+
155
+ def _taylor_linearization(pdf, y, X, weights, beta, psu_var, strata_var):
156
+ """Taylor series (sandwich) variance estimation for complex survey designs."""
157
+ import pandas as pd
158
+
159
+ resid = y - X @ beta
160
+ score = X * (resid * weights)[:, np.newaxis]
161
+
162
+ strata = pdf[strata_var].values
163
+ psu = pdf[psu_var].values
164
+
165
+ unique_strata = np.unique(strata)
166
+ n_h = len(unique_strata)
167
+ k = X.shape[1]
168
+ meat = np.zeros((k, k))
169
+
170
+ for h in unique_strata:
171
+ mask_h = strata == h
172
+ psus_in_h = np.unique(psu[mask_h])
173
+ m_h = len(psus_in_h)
174
+ if m_h < 2:
175
+ continue
176
+
177
+ # Sum of scores within each PSU
178
+ score_psu = np.zeros((m_h, k))
179
+ for j, p in enumerate(psus_in_h):
180
+ mask_p = mask_h & (psu == p)
181
+ score_psu[j] = score[mask_p].sum(axis=0)
182
+
183
+ score_mean = score_psu.mean(axis=0)
184
+ for j in range(m_h):
185
+ diff = score_psu[j] - score_mean
186
+ meat += np.outer(diff, diff) * m_h / (m_h - 1)
187
+
188
+ bread = np.linalg.inv(X.T @ np.diag(weights) @ X)
189
+ vcov = bread @ meat @ bread
190
+ return vcov
191
+
192
+
193
def compute_deff(df: pl.DataFrame, col: str, weight_var: str,
                 psu_var: str | None, strata_var: str | None) -> float:
    """Compute design effect (DEFF) for a variable via Kish's approximation.

    DEFF = var(complex design) / var(SRS of same size); here approximated by
    Kish's weighting design effect n * sum(w^2) / (sum w)^2.

    Args:
        df: Source data frame.
        col: Variable of interest (only used for listwise deletion here).
        weight_var: Column holding the survey weights.
        psu_var: Optional PSU column — included only so rows with a missing
            PSU are dropped; it does not otherwise enter the approximation.
        strata_var: Optional stratum column — same role as psu_var.

    Returns:
        DEFF, floored at 1.0 (a design cannot be "better" than SRS here).
    """
    cols = [col, weight_var]
    if psu_var:
        cols.append(psu_var)
    if strata_var:
        cols.append(strata_var)
    pdf = df.select(cols).to_pandas().dropna()
    weights = pdf[weight_var].values
    n = len(pdf)
    # Guard: the formula is undefined with no rows / zero total weight
    # (previously produced NaN); fall back to the SRS baseline.
    if n == 0:
        return 1.0

    w_sum = weights.sum()
    w_sum2 = (weights ** 2).sum()
    if w_sum == 0:
        return 1.0
    # Kish's approximation of design effect:
    # 1 + (sum(w^2)/(sum w)^2 - 1/n) * n  ==  n * sum(w^2) / (sum w)^2
    deff = 1 + (w_sum2 / w_sum ** 2 - 1 / n) * n
    return max(deff, 1.0)
@@ -0,0 +1,196 @@
1
+ """Survival analysis: Cox PH, Kaplan-Meier, log-rank test."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import polars as pl
7
+
8
+ from openstat.stats.models import FitResult
9
+
10
+
11
+ def _try_import_lifelines():
12
+ try:
13
+ import lifelines # noqa: F401
14
+ except ImportError:
15
+ raise ImportError(
16
+ "Survival analysis requires lifelines. "
17
+ "Install it with: pip install openstat[survival]"
18
+ )
19
+
20
+
21
def fit_cox_ph(
    df: pl.DataFrame,
    time_var: str,
    event_var: str,
    covariates: list[str],
    robust: bool = False,
) -> tuple[FitResult, object]:
    """Fit a Cox Proportional Hazards model.

    Args:
        df: Source data frame.
        time_var: Duration column.
        event_var: Event indicator column (passed to lifelines as event_col;
            presumably 1 = event, 0 = censored — TODO confirm coding).
        covariates: Covariate column names.
        robust: Request robust (sandwich) standard errors from lifelines.

    Returns:
        Tuple of (FitResult summary, fitted lifelines CoxPHFitter).
    """
    _try_import_lifelines()
    from lifelines import CoxPHFitter

    cols = [time_var, event_var] + covariates
    pdf = df.select(cols).to_pandas().dropna()

    cph = CoxPHFitter()
    cph.fit(pdf, duration_col=time_var, event_col=event_var, robust=robust)

    # lifelines exposes per-covariate estimates through the summary DataFrame.
    summary = cph.summary
    params = {name: float(summary.loc[name, "coef"]) for name in covariates}
    std_errors = {name: float(summary.loc[name, "se(coef)"]) for name in covariates}
    z_values = {name: float(summary.loc[name, "z"]) for name in covariates}
    p_values = {name: float(summary.loc[name, "p"]) for name in covariates}

    # CI column labels vary across lifelines versions (e.g. "coef lower 95%"),
    # so find them by substring rather than hard-coding a name.
    ci_cols = [c for c in summary.columns if "lower" in c.lower()]
    ci_low_col = ci_cols[0] if ci_cols else None
    ci_high_cols = [c for c in summary.columns if "upper" in c.lower()]
    ci_high_col = ci_high_cols[0] if ci_high_cols else None

    conf_low = {}
    conf_high = {}
    for name in covariates:
        conf_low[name] = float(summary.loc[name, ci_low_col]) if ci_low_col else 0.0
        conf_high[name] = float(summary.loc[name, ci_high_col]) if ci_high_col else 0.0

    # Fit diagnostics are surfaced through the warnings list for display.
    warnings_list = [
        f"Concordance: {cph.concordance_index_:.4f}",
        f"Partial log-likelihood: {cph.log_likelihood_:.2f}",
    ]

    # Hazard ratios
    hr_lines = ["Hazard Ratios:"]
    for name in covariates:
        hr = float(summary.loc[name, "exp(coef)"])
        hr_lines.append(f"  {name}: {hr:.4f}")
    warnings_list.append("\n".join(hr_lines))

    fit = FitResult(
        model_type="Cox PH",
        formula=f"h(t) ~ {' + '.join(covariates)}",
        dep_var=f"{time_var} (event: {event_var})",
        indep_vars=covariates,
        n_obs=int(len(pdf)),
        params=params,
        std_errors=std_errors,
        t_values=z_values,  # z statistics; stored in the t_values slot
        p_values=p_values,
        conf_int_low=conf_low,
        conf_int_high=conf_high,
        log_likelihood=float(cph.log_likelihood_),
        warnings=warnings_list,
    )

    return fit, cph
84
+
85
+
86
def kaplan_meier(
    df: pl.DataFrame,
    time_var: str,
    event_var: str,
    group_var: str | None = None,
) -> tuple[str, object | list]:
    """Fit Kaplan-Meier survival curves.

    Returns summary string and fitted KMF object(s).
    """
    _try_import_lifelines()
    from lifelines import KaplanMeierFitter

    extra = [group_var] if group_var else []
    pdf = df.select([time_var, event_var] + extra).to_pandas().dropna()

    if group_var is None:
        # Single pooled curve.
        kmf = KaplanMeierFitter()
        kmf.fit(pdf[time_var], event_observed=pdf[event_var])
        median = kmf.median_survival_time_
        if np.isfinite(median):
            median_line = f"  Median survival time: {median:.2f}"
        else:
            median_line = "  Median survival time: not reached"
        lines = [
            f"Kaplan-Meier Estimate (N={len(pdf)})",
            median_line,
            f"  Events: {int(pdf[event_var].sum())}",
            f"  Censored: {int(len(pdf) - pdf[event_var].sum())}",
        ]
        return "\n".join(lines), kmf

    # One curve per group level.
    groups = sorted(pdf[group_var].unique())
    kmfs = []
    lines = [f"Kaplan-Meier Estimates by {group_var}:"]
    for g in groups:
        sub = pdf[pdf[group_var] == g]
        kmf = KaplanMeierFitter()
        kmf.fit(sub[time_var], event_observed=sub[event_var], label=str(g))
        kmfs.append(kmf)
        median = kmf.median_survival_time_
        if np.isfinite(median):
            entry = f"\n  Group {g} (N={len(sub)}):\n    Median survival: {median:.2f}"
        else:
            entry = f"\n  Group {g} (N={len(sub)}):\n    Median survival: not reached"
        lines.append(entry)
        lines.append(f"    Events: {int(sub[event_var].sum())}")
    return "\n".join(lines), kmfs
130
+
131
+
132
def log_rank_test(
    df: pl.DataFrame,
    time_var: str,
    event_var: str,
    group_var: str,
) -> str:
    """Log-rank test comparing survival between groups."""
    _try_import_lifelines()
    from lifelines.statistics import logrank_test

    pdf = df.select([time_var, event_var, group_var]).to_pandas().dropna()
    groups = sorted(pdf[group_var].unique())

    if len(groups) < 2:
        return "Log-rank test requires at least 2 groups."

    if len(groups) == 2:
        # Single comparison between the two levels.
        first = pdf[pdf[group_var] == groups[0]]
        second = pdf[pdf[group_var] == groups[1]]
        result = logrank_test(
            first[time_var], second[time_var],
            event_observed_A=first[event_var],
            event_observed_B=second[event_var],
        )
        report = [
            f"Log-Rank Test: {group_var}",
            f"  Groups: {groups[0]} vs {groups[1]}",
            f"  Test statistic: {result.test_statistic:.4f}",
            f"  p-value: {result.p_value:.4f}",
        ]
        if result.p_value < 0.05:
            report.append("  ⚠ Significant difference in survival between groups")
        else:
            report.append("  ✓ No significant difference in survival")
        return "\n".join(report)

    # More than two groups: run every pairwise comparison.
    report = [f"Log-Rank Tests (pairwise): {group_var}"]
    for i, ga in enumerate(groups):
        for gb in groups[i + 1:]:
            da = pdf[pdf[group_var] == ga]
            db = pdf[pdf[group_var] == gb]
            result = logrank_test(
                da[time_var], db[time_var],
                event_observed_A=da[event_var],
                event_observed_B=db[event_var],
            )
            flag = "*" if result.p_value < 0.05 else ""
            report.append(f"  {ga} vs {gb}: chi2={result.test_statistic:.3f}, p={result.p_value:.4f}{flag}")
    return "\n".join(report)
182
+
183
+
184
def schoenfeld_test(cph_result) -> str:
    """Test proportional hazards assumption via Schoenfeld residuals."""
    try:
        outcome = cph_result.check_assumptions(show_plots=False, p_value_threshold=1.0)
        report = ["Proportional Hazards Test (Schoenfeld Residuals):"]
        # check_assumptions returns summary or prints it
        if outcome is None:
            report.append("    PH assumption appears satisfied for all covariates.")
        else:
            report.append(str(outcome))
        return "\n".join(report)
    except Exception as e:
        # Best-effort: surface the failure message rather than crashing.
        return f"PH test: {e}"