marginaleffects 0.5.0__tar.gz → 0.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/PKG-INFO +1 -1
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/by.py +1 -3
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/comparisons.py +23 -32
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/estimands.py +4 -4
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/common.py +31 -36
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/comparison.py +2 -3
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/newdata.py +10 -5
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/variables.py +9 -15
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/vcov.py +2 -1
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/uncertainty.py +20 -27
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/utils.py +1 -1
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/PKG-INFO +1 -1
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/pyproject.toml +1 -1
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_bugfix.py +18 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_comparisons.py +1 -1
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_pyfixest.py +1 -1
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/README.md +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/benchmarks/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/benchmarks/benchmark_autodiff.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/comparisons.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/dispatch.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/glm/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/glm/comparisons.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/glm/families.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/glm/predictions.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/linear/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/linear/comparisons.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/linear/predictions.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/utils.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/classes/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/classes/model.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/classes/result.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/datagrid.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/datasets.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/docstrings/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/docstrings/params.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/docstrings/qmd.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/formula.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/linearmodels/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/linearmodels/model.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/comparisons.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/predictions.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/slopes.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/predictions.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/pyfixest/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/pyfixest/model.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/by.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/categorical.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/deprecated.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/hypothesis_null.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/sanitize_model.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/utils.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/validation.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/settings.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sklearn/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sklearn/model.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/slopes.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/statsmodels/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/statsmodels/model.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/core.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/equivalence.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/formula.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/joint.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/main.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/transform.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/SOURCES.txt +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/dependency_links.txt +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/requires.txt +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/top_level.txt +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/setup.cfg +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/__init__.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/helpers.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_analytic.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_autodiff.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_by.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_categorical.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_categorical_validation.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_comparisons_interaction.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_datagrid_01.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_datagrid_02.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_equivalence.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_formula.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_formulaic_utils.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_hypotheses.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_hypotheses_joint.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_hypothesis.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_jss.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_linearmodels_panelols.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_missing.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_newdata.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_plot_comparisons.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_plot_predictions.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_plot_slopes.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_predictions.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_sklearn.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_slopes.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_logit.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_mixedlm.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_mnlogit.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_negativebinomial.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_ols.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_ordinal.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_poisson.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_probit.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_quantreg.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_vcov.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_wls.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_typical.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_utils.py +0 -0
- {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/utilities.py +0 -0
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import polars as pl
|
|
2
|
-
import numpy as np
|
|
3
2
|
from typing import List, Optional, Tuple
|
|
4
3
|
|
|
5
4
|
|
|
@@ -66,8 +65,7 @@ def _get_by_internal(
|
|
|
66
65
|
else:
|
|
67
66
|
out = pl.DataFrame({"estimate": estimand["estimate"]})
|
|
68
67
|
|
|
69
|
-
by =
|
|
70
|
-
by = np.unique(by)
|
|
68
|
+
by = list(dict.fromkeys(x for x in by if x in out.columns))
|
|
71
69
|
|
|
72
70
|
if isinstance(by, list) and len(by) == 0:
|
|
73
71
|
if return_groups and "rowid" in out.columns:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import re
|
|
2
|
+
import warnings
|
|
2
3
|
from functools import reduce
|
|
3
4
|
|
|
4
5
|
import numpy as np
|
|
@@ -118,28 +119,22 @@ def _build_comparison_frames(newdata, variables, cross):
|
|
|
118
119
|
hi.append(hi_row)
|
|
119
120
|
lo.append(lo_row)
|
|
120
121
|
else:
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
122
|
+
nd_row = newdata.clone()
|
|
123
|
+
hi_row = newdata.clone()
|
|
124
|
+
lo_row = newdata.clone()
|
|
124
125
|
for v in variables:
|
|
125
126
|
vcomp = "custom" if callable(v.comparison) else v.comparison
|
|
126
|
-
|
|
127
|
+
shared = [
|
|
127
128
|
pl.lit(v.variable).alias("term"),
|
|
128
129
|
pl.lit(v.lab).alias(f"contrast_{v.variable}"),
|
|
129
130
|
pl.lit(vcomp).alias("marginaleffects_comparison"),
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
lo[0] = lo[0].with_columns(
|
|
138
|
-
pl.lit(v.lo).alias(v.variable),
|
|
139
|
-
pl.lit(v.variable).alias("term"),
|
|
140
|
-
pl.lit(v.lab).alias(f"contrast_{v.variable}"),
|
|
141
|
-
pl.lit(vcomp).alias("marginaleffects_comparison"),
|
|
142
|
-
)
|
|
131
|
+
]
|
|
132
|
+
nd_row = nd_row.with_columns(*shared)
|
|
133
|
+
hi_row = hi_row.with_columns(pl.lit(v.hi).alias(v.variable), *shared)
|
|
134
|
+
lo_row = lo_row.with_columns(pl.lit(v.lo).alias(v.variable), *shared)
|
|
135
|
+
nd.append(nd_row)
|
|
136
|
+
hi.append(hi_row)
|
|
137
|
+
lo.append(lo_row)
|
|
143
138
|
return nd, hi, lo
|
|
144
139
|
|
|
145
140
|
|
|
@@ -166,9 +161,10 @@ def _finalize_counterfactual_frames(
|
|
|
166
161
|
pad_df = upcast(pad_df, hi)
|
|
167
162
|
nd = upcast(nd, hi)
|
|
168
163
|
|
|
169
|
-
|
|
164
|
+
dfs = {"nd": nd, "hi": hi, "lo": lo}
|
|
170
165
|
|
|
171
|
-
for df_name
|
|
166
|
+
for df_name in dfs:
|
|
167
|
+
df = dfs[df_name]
|
|
172
168
|
common_cols = set(pad_df.columns) & set(df.columns)
|
|
173
169
|
for col in common_cols:
|
|
174
170
|
pad_dtype = str(pad_df[col].dtype)
|
|
@@ -189,8 +185,8 @@ def _finalize_counterfactual_frames(
|
|
|
189
185
|
.alias(col)
|
|
190
186
|
)
|
|
191
187
|
except Exception as e:
|
|
192
|
-
|
|
193
|
-
f"
|
|
188
|
+
warnings.warn(
|
|
189
|
+
f"Could not convert List column {col} to strings: {e}"
|
|
194
190
|
)
|
|
195
191
|
try:
|
|
196
192
|
if col in pad_df.columns and pad_df.height > 0:
|
|
@@ -198,7 +194,7 @@ def _finalize_counterfactual_frames(
|
|
|
198
194
|
if col in df.columns and df.height > 0:
|
|
199
195
|
df = df.explode(col)
|
|
200
196
|
except Exception as e2:
|
|
201
|
-
|
|
197
|
+
warnings.warn(f"Could not explode List column {col}: {e2}")
|
|
202
198
|
if col in pad_df.columns:
|
|
203
199
|
pad_df = pad_df.with_columns(
|
|
204
200
|
pad_df[col].cast(pl.String).alias(col)
|
|
@@ -206,12 +202,9 @@ def _finalize_counterfactual_frames(
|
|
|
206
202
|
if col in df.columns:
|
|
207
203
|
df = df.with_columns(df[col].cast(pl.String).alias(col))
|
|
208
204
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
hi = df
|
|
213
|
-
elif df_name == "lo":
|
|
214
|
-
lo = df
|
|
205
|
+
dfs[df_name] = df
|
|
206
|
+
|
|
207
|
+
nd, hi, lo = dfs["nd"], dfs["hi"], dfs["lo"]
|
|
215
208
|
|
|
216
209
|
nd = pl.concat([pad_df, nd], how="diagonal")
|
|
217
210
|
hi = pl.concat([pad_df, hi], how="diagonal")
|
|
@@ -221,9 +214,7 @@ def _finalize_counterfactual_frames(
|
|
|
221
214
|
categorical_list_cols = []
|
|
222
215
|
for col in list_cols:
|
|
223
216
|
dtype_str = str(nd[col].dtype)
|
|
224
|
-
if (
|
|
225
|
-
"Enum(" in dtype_str or "String" in dtype_str or "UInt32" in dtype_str
|
|
226
|
-
) and col in ["Region"]:
|
|
217
|
+
if "Enum(" in dtype_str or "String" in dtype_str or "UInt32" in dtype_str:
|
|
227
218
|
categorical_list_cols.append(col)
|
|
228
219
|
|
|
229
220
|
if categorical_list_cols:
|
|
@@ -241,7 +232,7 @@ def _prepare_design_matrices(model, nd, hi, lo, pad_rows):
|
|
|
241
232
|
lo_X = model.get_exog(lo)
|
|
242
233
|
nd_X = model.get_exog(nd)
|
|
243
234
|
|
|
244
|
-
if pad_rows
|
|
235
|
+
if pad_rows > 0:
|
|
245
236
|
nd_X = nd_X[pad_rows:]
|
|
246
237
|
hi_X = hi_X[pad_rows:]
|
|
247
238
|
lo_X = lo_X[pad_rows:]
|
|
@@ -42,12 +42,12 @@ estimands = {
|
|
|
42
42
|
"ratio": lambda hi, lo, eps, x, y, w: prep(hi / lo),
|
|
43
43
|
"ratioavg": lambda hi, lo, eps, x, y, w: prep(hi.mean() / lo.mean()),
|
|
44
44
|
"ratioavgwts": lambda hi, lo, eps, x, y, w: prep(
|
|
45
|
-
(hi * w).sum() / w.sum() / (lo * w).sum() / w.sum()
|
|
45
|
+
((hi * w).sum() / w.sum()) / ((lo * w).sum() / w.sum())
|
|
46
46
|
),
|
|
47
47
|
"lnratio": lambda hi, lo, eps, x, y, w: prep(np.log(hi / lo)),
|
|
48
48
|
"lnratioavg": lambda hi, lo, eps, x, y, w: prep(np.log(hi.mean() / lo.mean())),
|
|
49
49
|
"lnratioavgwts": lambda hi, lo, eps, x, y, w: prep(
|
|
50
|
-
np.log((hi * w).sum() / w.sum() / (lo * w).sum() / w.sum())
|
|
50
|
+
np.log(((hi * w).sum() / w.sum()) / ((lo * w).sum() / w.sum()))
|
|
51
51
|
),
|
|
52
52
|
"lnor": lambda hi, lo, eps, x, y, w: prep(
|
|
53
53
|
np.log((hi / (1 - hi)) / (lo / (1 - lo)))
|
|
@@ -69,7 +69,7 @@ estimands = {
|
|
|
69
69
|
"expdydxavg": lambda hi, lo, eps, x, y, w: prep(
|
|
70
70
|
np.mean(((hi.exp() - lo.exp()) / np.exp(eps)) / eps)
|
|
71
71
|
),
|
|
72
|
-
"expdydxavgwts": lambda hi, lo, eps, x, y, w: (
|
|
73
|
-
|
|
72
|
+
"expdydxavgwts": lambda hi, lo, eps, x, y, w: prep(
|
|
73
|
+
((((np.exp(hi) - np.exp(lo)) / np.exp(eps)) / eps) * w).sum() / w.sum()
|
|
74
74
|
),
|
|
75
75
|
}
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
from ..datagrid import datagrid # noqa
|
|
3
5
|
from ..sanitize import sanitize_model
|
|
@@ -7,7 +9,7 @@ import polars as pl
|
|
|
7
9
|
def dt_on_condition(model, condition):
|
|
8
10
|
model = sanitize_model(model)
|
|
9
11
|
|
|
10
|
-
condition_new = condition
|
|
12
|
+
condition_new = copy.deepcopy(condition)
|
|
11
13
|
|
|
12
14
|
# not sure why newdata gets added
|
|
13
15
|
modeldata = model.get_modeldata()
|
|
@@ -19,28 +21,24 @@ def dt_on_condition(model, condition):
|
|
|
19
21
|
first_key = "" # special case when the first element is numeric
|
|
20
22
|
|
|
21
23
|
if isinstance(condition_new, list):
|
|
22
|
-
|
|
23
|
-
"All elements of condition must be columns of the model."
|
|
24
|
-
)
|
|
24
|
+
if not all(ele in modeldata.columns for ele in condition_new):
|
|
25
|
+
raise ValueError("All elements of condition must be columns of the model.")
|
|
25
26
|
first_key = condition_new[0]
|
|
26
27
|
to_datagrid = {key: None for key in condition_new}
|
|
27
28
|
|
|
28
29
|
elif isinstance(condition_new, dict):
|
|
29
|
-
|
|
30
|
-
"All keys of condition must be columns of the model."
|
|
31
|
-
)
|
|
30
|
+
if not all(key in modeldata.columns for key in condition_new.keys()):
|
|
31
|
+
raise ValueError("All keys of condition must be columns of the model.")
|
|
32
32
|
first_key = next(iter(condition_new))
|
|
33
|
-
to_datagrid =
|
|
34
|
-
condition_new # third pointer to the same object? looks like a BUG
|
|
35
|
-
)
|
|
33
|
+
to_datagrid = condition_new
|
|
36
34
|
|
|
37
|
-
# not sure why `newdata` sometimes gets added
|
|
38
35
|
if isinstance(condition_new, dict) and "newdata" in to_datagrid.keys():
|
|
39
36
|
condition_new.pop("newdata", None)
|
|
40
37
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
38
|
+
if not (1 <= len(condition_new) <= 4):
|
|
39
|
+
raise ValueError(
|
|
40
|
+
f"Length of condition must be inclusively between 1 and 4. Got: {len(condition_new)}."
|
|
41
|
+
)
|
|
44
42
|
|
|
45
43
|
for key, value in to_datagrid.items():
|
|
46
44
|
variable_type = model.get_variable_type(key)
|
|
@@ -51,20 +49,17 @@ def dt_on_condition(model, condition):
|
|
|
51
49
|
)
|
|
52
50
|
|
|
53
51
|
elif variable_type in ["character"]:
|
|
54
|
-
# get specified names of the condition
|
|
55
|
-
# here is the BUG, we take the values of "species" back from the model
|
|
56
52
|
to_datagrid[key] = (
|
|
57
53
|
to_datagrid[key]
|
|
58
54
|
if to_datagrid[key]
|
|
59
55
|
else modeldata[key].unique().sort().to_list()
|
|
60
56
|
)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
57
|
+
if len(to_datagrid[key]) > 10:
|
|
58
|
+
raise ValueError(
|
|
59
|
+
f"Character type variables of more than 10 unique values are not supported. {key} variable has {len(to_datagrid[key])} unique values."
|
|
60
|
+
)
|
|
64
61
|
|
|
65
62
|
elif variable_type in ["boolean", "binary"]:
|
|
66
|
-
# get specified names of the condition
|
|
67
|
-
# here is the BUG, we take the values of "species" back from the model
|
|
68
63
|
if to_datagrid[key] is None:
|
|
69
64
|
to_datagrid[key] = modeldata[key].unique().sort().to_list()
|
|
70
65
|
|
|
@@ -131,15 +126,14 @@ def ordered_cat(dt, k, lab):
|
|
|
131
126
|
|
|
132
127
|
|
|
133
128
|
def validate_plot_args(condition, by, newdata, wts):
|
|
134
|
-
|
|
135
|
-
"The `newdata` argument requires a `by` argument."
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
)
|
|
129
|
+
if not by and newdata is not None:
|
|
130
|
+
raise ValueError("The `newdata` argument requires a `by` argument.")
|
|
131
|
+
if wts is not None and not by:
|
|
132
|
+
raise ValueError("The `wts` argument requires a `by` argument.")
|
|
133
|
+
if not ((condition is None and by) or (condition is not None and not by)):
|
|
134
|
+
raise ValueError(
|
|
135
|
+
"One of the `condition` and `by` arguments must be supplied, but not both."
|
|
136
|
+
)
|
|
143
137
|
|
|
144
138
|
|
|
145
139
|
def extract_var_list(condition, by):
|
|
@@ -158,9 +152,10 @@ def extract_var_list(condition, by):
|
|
|
158
152
|
|
|
159
153
|
var_list = [x for x in var_list if x not in ["newdata", "model"]]
|
|
160
154
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
155
|
+
if len(var_list) >= 5:
|
|
156
|
+
raise ValueError(
|
|
157
|
+
"The `condition` and `by` arguments can have a max length of 4."
|
|
158
|
+
)
|
|
164
159
|
|
|
165
160
|
return var_list
|
|
166
161
|
|
|
@@ -286,10 +281,10 @@ def plot_common(model, dt, y_label, var_list, gray=False, points=0):
|
|
|
286
281
|
if len(var_list) > 1:
|
|
287
282
|
if gray:
|
|
288
283
|
# get the number of unique values in the column "var_list[1]"
|
|
289
|
-
unique_values = dt[var_list[1]].unique()
|
|
290
|
-
if unique_values > 5:
|
|
284
|
+
unique_values = dt[var_list[1]].unique()
|
|
285
|
+
if unique_values.len() > 5:
|
|
291
286
|
raise ValueError(
|
|
292
|
-
f"The number of elements in the second position of the `condition` or `by` argument (variable {var_list[1]}) cannot exceed 5. It has currently {len(
|
|
287
|
+
f"The number of elements in the second position of the `condition` or `by` argument (variable {var_list[1]}) cannot exceed 5. It has currently {unique_values.len()} elements, with values {unique_values.to_list()}."
|
|
293
288
|
)
|
|
294
289
|
custom_line_types = [
|
|
295
290
|
"solid",
|
|
@@ -73,8 +73,7 @@ def sanitize_comparison(comparison, by, wts=None):
|
|
|
73
73
|
"expdydx": "exp(dY/dX)",
|
|
74
74
|
}
|
|
75
75
|
|
|
76
|
-
|
|
77
|
-
f"`comparison` must be one of: {', '.join(list(lab.keys()))}."
|
|
78
|
-
)
|
|
76
|
+
if out not in lab.keys():
|
|
77
|
+
raise ValueError(f"`comparison` must be one of: {', '.join(list(lab.keys()))}.")
|
|
79
78
|
|
|
80
79
|
return (out, lab[out])
|
|
@@ -1,12 +1,16 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import polars as pl
|
|
3
3
|
|
|
4
|
-
from ..datagrid import datagrid
|
|
5
|
-
from ..utils import ingest, upcast
|
|
6
4
|
from ..formula import listwise_deletion
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
def sanitize_newdata(model, newdata, wts, by=[]):
|
|
8
|
+
# Lazy imports to break the `datagrid -> utils -> sanitize -> newdata -> ...`
|
|
9
|
+
# circular import that fires when `datagrid` is the first symbol pulled from
|
|
10
|
+
# marginaleffects in a fresh interpreter (see GH #1724).
|
|
11
|
+
from ..datagrid import datagrid
|
|
12
|
+
from ..utils import ingest, upcast
|
|
13
|
+
|
|
10
14
|
modeldata = model.get_modeldata()
|
|
11
15
|
|
|
12
16
|
if newdata is None:
|
|
@@ -72,9 +76,10 @@ def sanitize_newdata(model, newdata, wts, by=[]):
|
|
|
72
76
|
"contrast",
|
|
73
77
|
"statistic",
|
|
74
78
|
}
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
79
|
+
if set(out.columns) & reserved_names:
|
|
80
|
+
raise ValueError(
|
|
81
|
+
f"Input data contain reserved column name(s): {set(out.columns).intersection(reserved_names)}"
|
|
82
|
+
)
|
|
78
83
|
|
|
79
84
|
datagrid_explicit = None
|
|
80
85
|
if isinstance(newdata, pl.DataFrame) and hasattr(newdata, "datagrid_explicit"):
|
|
@@ -11,17 +11,11 @@ HiLo = namedtuple("HiLo", ["variable", "hi", "lo", "lab", "pad", "comparison"])
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
def _clean_global(k, n):
|
|
14
|
-
if (
|
|
15
|
-
|
|
16
|
-
and not isinstance(k, pl.Series)
|
|
17
|
-
and not isinstance(k, np.ndarray)
|
|
18
|
-
):
|
|
19
|
-
out = [k]
|
|
14
|
+
if isinstance(k, (pl.Series, np.ndarray)):
|
|
15
|
+
return pl.Series(k) if len(k) > 1 else pl.Series(np.repeat(k[0], n))
|
|
20
16
|
if not isinstance(k, list) or len(k) == 1:
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
out = pl.Series(k)
|
|
24
|
-
return out
|
|
17
|
+
return pl.Series(np.repeat(k, n))
|
|
18
|
+
return pl.Series(k)
|
|
25
19
|
|
|
26
20
|
|
|
27
21
|
def _get_one_variable_hi_lo(
|
|
@@ -153,9 +147,10 @@ def _get_one_variable_hi_lo(
|
|
|
153
147
|
|
|
154
148
|
elif callable(value):
|
|
155
149
|
tmp = value(newdata[variable])
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
150
|
+
if tmp.shape[1] != 2:
|
|
151
|
+
raise ValueError(
|
|
152
|
+
f"The function passed to `variables` must return a DataFrame with two columns. Got {tmp.shape[1]}."
|
|
153
|
+
)
|
|
159
154
|
lo = tmp[:, 0]
|
|
160
155
|
hi = tmp[:, 1]
|
|
161
156
|
lab = "custom"
|
|
@@ -225,9 +220,8 @@ def sanitize_variables(
|
|
|
225
220
|
)
|
|
226
221
|
|
|
227
222
|
elif isinstance(variables, dict):
|
|
228
|
-
for v in variables:
|
|
223
|
+
for v in list(variables.keys()):
|
|
229
224
|
if v not in newdata.columns:
|
|
230
|
-
del variables[v]
|
|
231
225
|
warn(f"Variable {v} is not in newdata.")
|
|
232
226
|
else:
|
|
233
227
|
out.append(
|
|
@@ -14,5 +14,6 @@ def sanitize_vcov(vcov, model):
|
|
|
14
14
|
|
|
15
15
|
V = model.get_vcov(vcov)
|
|
16
16
|
if V is not None:
|
|
17
|
-
|
|
17
|
+
if not isinstance(V, np.ndarray):
|
|
18
|
+
raise TypeError("vcov must be True or a square NumPy array")
|
|
18
19
|
return V
|
|
@@ -6,36 +6,29 @@ import scipy.stats as stats
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
def get_jacobian(func, coefs, eps_vcov=None):
|
|
9
|
-
|
|
9
|
+
original_shape = None
|
|
10
10
|
if coefs.ndim == 2:
|
|
11
|
+
original_shape = coefs.shape
|
|
11
12
|
if isinstance(coefs, np.ndarray):
|
|
12
13
|
coefs_flat = coefs.flatten(order="F")
|
|
13
14
|
else:
|
|
14
15
|
coefs_flat = coefs.to_numpy().flatten(order="F")
|
|
15
|
-
baseline = func(coefs)["estimate"].to_numpy()
|
|
16
|
-
jac = np.empty((baseline.shape[0], len(coefs_flat)), dtype=np.float64)
|
|
17
|
-
for i, xi in enumerate(coefs_flat):
|
|
18
|
-
if eps_vcov is not None:
|
|
19
|
-
h = eps_vcov
|
|
20
|
-
else:
|
|
21
|
-
h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10)
|
|
22
|
-
dx = np.copy(coefs_flat)
|
|
23
|
-
dx[i] = dx[i] + h
|
|
24
|
-
tmp = dx.reshape(coefs.shape, order="F")
|
|
25
|
-
jac[:, i] = (func(tmp)["estimate"].to_numpy() - baseline) / h
|
|
26
|
-
return jac
|
|
27
16
|
else:
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
17
|
+
coefs_flat = np.asarray(coefs)
|
|
18
|
+
|
|
19
|
+
baseline = func(coefs)["estimate"].to_numpy()
|
|
20
|
+
jac = np.empty((baseline.shape[0], len(coefs_flat)), dtype=np.float64)
|
|
21
|
+
for i, xi in enumerate(coefs_flat):
|
|
22
|
+
if eps_vcov is not None:
|
|
23
|
+
h = eps_vcov
|
|
24
|
+
else:
|
|
25
|
+
h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10)
|
|
26
|
+
dx = np.copy(coefs_flat)
|
|
27
|
+
dx[i] = dx[i] + h
|
|
28
|
+
if original_shape is not None:
|
|
29
|
+
dx = dx.reshape(original_shape, order="F")
|
|
30
|
+
jac[:, i] = (func(dx)["estimate"].to_numpy() - baseline) / h
|
|
31
|
+
return jac
|
|
39
32
|
|
|
40
33
|
|
|
41
34
|
def get_se(J, V):
|
|
@@ -65,7 +58,7 @@ def get_z_p_ci(df, model, conf_level, hypothesis_null=0):
|
|
|
65
58
|
"statistic"
|
|
66
59
|
)
|
|
67
60
|
)
|
|
68
|
-
if hasattr(model, "df_resid") and isinstance(model.df_resid, float):
|
|
61
|
+
if hasattr(model, "df_resid") and isinstance(model.df_resid, (int, float)):
|
|
69
62
|
dof = model.df_resid
|
|
70
63
|
else:
|
|
71
64
|
dof = np.inf
|
|
@@ -93,6 +86,6 @@ def get_z_p_ci(df, model, conf_level, hypothesis_null=0):
|
|
|
93
86
|
.map_batches(lambda x: -np.log2(x), return_dtype=pl.Float64)
|
|
94
87
|
.alias("s_value")
|
|
95
88
|
)
|
|
96
|
-
except Exception
|
|
97
|
-
|
|
89
|
+
except Exception:
|
|
90
|
+
pass
|
|
98
91
|
return df
|
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
import sys
|
|
3
|
+
|
|
1
4
|
import numpy as np
|
|
2
5
|
import pandas as pd
|
|
3
6
|
import polars as pl
|
|
@@ -22,3 +25,18 @@ def test_issue_226_np_context():
|
|
|
22
25
|
out = predictions(mod, newdata=df)
|
|
23
26
|
assert isinstance(out, MarginaleffectsResult)
|
|
24
27
|
assert isinstance(out.data, pl.DataFrame)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_issue_1724():
|
|
31
|
+
# Circular import when `datagrid` is the first symbol pulled from
|
|
32
|
+
# marginaleffects in a fresh interpreter. Must run in a subprocess —
|
|
33
|
+
# the in-process pytest run has already warmed the import graph.
|
|
34
|
+
result = subprocess.run(
|
|
35
|
+
[sys.executable, "-c", "from marginaleffects import datagrid"],
|
|
36
|
+
capture_output=True,
|
|
37
|
+
text=True,
|
|
38
|
+
)
|
|
39
|
+
assert result.returncode == 0, (
|
|
40
|
+
f"Fresh-process import of `datagrid` failed.\n"
|
|
41
|
+
f"stdout: {result.stdout}\nstderr: {result.stderr}"
|
|
42
|
+
)
|
|
@@ -200,7 +200,7 @@ def test_lift():
|
|
|
200
200
|
cmp2 = comparisons(mod, comparison="liftavg")
|
|
201
201
|
assert cmp1.shape[0] == 32
|
|
202
202
|
assert cmp2.shape[0] == 1
|
|
203
|
-
with pytest.raises(
|
|
203
|
+
with pytest.raises(ValueError):
|
|
204
204
|
comparisons(mod, comparison="liftr")
|
|
205
205
|
|
|
206
206
|
|
|
@@ -204,7 +204,7 @@ def test_pyfixest_standard_errors_across_models():
|
|
|
204
204
|
fit_pois_fe = fepois("Y ~ X1 * X2 * Z1 | f1", data=poisson_data)
|
|
205
205
|
with pytest.warns(
|
|
206
206
|
UserWarning,
|
|
207
|
-
match="uncertainty in fixed-effects
|
|
207
|
+
match="cannot take into account the uncertainty in fixed-effects",
|
|
208
208
|
):
|
|
209
209
|
try:
|
|
210
210
|
comp_pois_fe = comparisons(fit_pois_fe)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/linear/comparisons.py
RENAMED
|
File without changes
|
{marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/linear/predictions.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|