marginaleffects 0.5.0__tar.gz → 0.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/PKG-INFO +1 -1
  2. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/by.py +1 -3
  3. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/comparisons.py +23 -32
  4. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/estimands.py +4 -4
  5. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/common.py +31 -36
  6. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/comparison.py +2 -3
  7. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/newdata.py +10 -5
  8. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/variables.py +9 -15
  9. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/vcov.py +2 -1
  10. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/uncertainty.py +20 -27
  11. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/utils.py +1 -1
  12. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/PKG-INFO +1 -1
  13. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/pyproject.toml +1 -1
  14. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_bugfix.py +18 -0
  15. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_comparisons.py +1 -1
  16. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_pyfixest.py +1 -1
  17. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/README.md +0 -0
  18. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/benchmarks/__init__.py +0 -0
  19. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/benchmarks/benchmark_autodiff.py +0 -0
  20. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/__init__.py +0 -0
  21. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/__init__.py +0 -0
  22. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/comparisons.py +0 -0
  23. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/dispatch.py +0 -0
  24. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/glm/__init__.py +0 -0
  25. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/glm/comparisons.py +0 -0
  26. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/glm/families.py +0 -0
  27. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/glm/predictions.py +0 -0
  28. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/linear/__init__.py +0 -0
  29. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/linear/comparisons.py +0 -0
  30. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/linear/predictions.py +0 -0
  31. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/autodiff/utils.py +0 -0
  32. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/classes/__init__.py +0 -0
  33. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/classes/model.py +0 -0
  34. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/classes/result.py +0 -0
  35. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/datagrid.py +0 -0
  36. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/datasets.py +0 -0
  37. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/docstrings/__init__.py +0 -0
  38. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/docstrings/params.py +0 -0
  39. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/docstrings/qmd.py +0 -0
  40. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/formula.py +0 -0
  41. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/linearmodels/__init__.py +0 -0
  42. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/linearmodels/model.py +0 -0
  43. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/__init__.py +0 -0
  44. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/comparisons.py +0 -0
  45. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/predictions.py +0 -0
  46. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/plot/slopes.py +0 -0
  47. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/predictions.py +0 -0
  48. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/pyfixest/__init__.py +0 -0
  49. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/pyfixest/model.py +0 -0
  50. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/__init__.py +0 -0
  51. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/by.py +0 -0
  52. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/categorical.py +0 -0
  53. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/deprecated.py +0 -0
  54. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/hypothesis_null.py +0 -0
  55. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/sanitize_model.py +0 -0
  56. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/utils.py +0 -0
  57. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sanitize/validation.py +0 -0
  58. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/settings.py +0 -0
  59. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sklearn/__init__.py +0 -0
  60. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/sklearn/model.py +0 -0
  61. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/slopes.py +0 -0
  62. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/statsmodels/__init__.py +0 -0
  63. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/statsmodels/model.py +0 -0
  64. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/__init__.py +0 -0
  65. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/core.py +0 -0
  66. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/equivalence.py +0 -0
  67. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/formula.py +0 -0
  68. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/joint.py +0 -0
  69. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/test/main.py +0 -0
  70. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects/transform.py +0 -0
  71. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/SOURCES.txt +0 -0
  72. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/dependency_links.txt +0 -0
  73. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/requires.txt +0 -0
  74. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/marginaleffects.egg-info/top_level.txt +0 -0
  75. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/setup.cfg +0 -0
  76. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/__init__.py +0 -0
  77. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/helpers.py +0 -0
  78. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_analytic.py +0 -0
  79. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_autodiff.py +0 -0
  80. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_by.py +0 -0
  81. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_categorical.py +0 -0
  82. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_categorical_validation.py +0 -0
  83. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_comparisons_interaction.py +0 -0
  84. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_datagrid_01.py +0 -0
  85. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_datagrid_02.py +0 -0
  86. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_equivalence.py +0 -0
  87. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_formula.py +0 -0
  88. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_formulaic_utils.py +0 -0
  89. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_hypotheses.py +0 -0
  90. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_hypotheses_joint.py +0 -0
  91. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_hypothesis.py +0 -0
  92. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_jss.py +0 -0
  93. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_linearmodels_panelols.py +0 -0
  94. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_missing.py +0 -0
  95. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_newdata.py +0 -0
  96. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_plot_comparisons.py +0 -0
  97. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_plot_predictions.py +0 -0
  98. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_plot_slopes.py +0 -0
  99. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_predictions.py +0 -0
  100. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_sklearn.py +0 -0
  101. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_slopes.py +0 -0
  102. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels.py +0 -0
  103. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_logit.py +0 -0
  104. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_mixedlm.py +0 -0
  105. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_mnlogit.py +0 -0
  106. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_negativebinomial.py +0 -0
  107. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_ols.py +0 -0
  108. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_ordinal.py +0 -0
  109. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_poisson.py +0 -0
  110. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_probit.py +0 -0
  111. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_quantreg.py +0 -0
  112. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_vcov.py +0 -0
  113. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_statsmodels_wls.py +0 -0
  114. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_typical.py +0 -0
  115. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/test_utils.py +0 -0
  116. {marginaleffects-0.5.0 → marginaleffects-0.5.1}/tests/utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: marginaleffects
3
- Version: 0.5.0
3
+ Version: 0.5.1
4
4
  Summary: Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models.
5
5
  License-Expression: GPL-3.0-or-later
6
6
  Requires-Python: >=3.10
@@ -1,5 +1,4 @@
1
1
  import polars as pl
2
- import numpy as np
3
2
  from typing import List, Optional, Tuple
4
3
 
5
4
 
@@ -66,8 +65,7 @@ def _get_by_internal(
66
65
  else:
67
66
  out = pl.DataFrame({"estimate": estimand["estimate"]})
68
67
 
69
- by = [x for x in by if x in out.columns]
70
- by = np.unique(by)
68
+ by = list(dict.fromkeys(x for x in by if x in out.columns))
71
69
 
72
70
  if isinstance(by, list) and len(by) == 0:
73
71
  if return_groups and "rowid" in out.columns:
@@ -1,4 +1,5 @@
1
1
  import re
2
+ import warnings
2
3
  from functools import reduce
3
4
 
4
5
  import numpy as np
@@ -118,28 +119,22 @@ def _build_comparison_frames(newdata, variables, cross):
118
119
  hi.append(hi_row)
119
120
  lo.append(lo_row)
120
121
  else:
121
- hi.append(newdata)
122
- lo.append(newdata)
123
- nd.append(newdata)
122
+ nd_row = newdata.clone()
123
+ hi_row = newdata.clone()
124
+ lo_row = newdata.clone()
124
125
  for v in variables:
125
126
  vcomp = "custom" if callable(v.comparison) else v.comparison
126
- nd[0] = nd[0].with_columns(
127
+ shared = [
127
128
  pl.lit(v.variable).alias("term"),
128
129
  pl.lit(v.lab).alias(f"contrast_{v.variable}"),
129
130
  pl.lit(vcomp).alias("marginaleffects_comparison"),
130
- )
131
- hi[0] = hi[0].with_columns(
132
- pl.lit(v.hi).alias(v.variable),
133
- pl.lit(v.variable).alias("term"),
134
- pl.lit(v.lab).alias(f"contrast_{v.variable}"),
135
- pl.lit(vcomp).alias("marginaleffects_comparison"),
136
- )
137
- lo[0] = lo[0].with_columns(
138
- pl.lit(v.lo).alias(v.variable),
139
- pl.lit(v.variable).alias("term"),
140
- pl.lit(v.lab).alias(f"contrast_{v.variable}"),
141
- pl.lit(vcomp).alias("marginaleffects_comparison"),
142
- )
131
+ ]
132
+ nd_row = nd_row.with_columns(*shared)
133
+ hi_row = hi_row.with_columns(pl.lit(v.hi).alias(v.variable), *shared)
134
+ lo_row = lo_row.with_columns(pl.lit(v.lo).alias(v.variable), *shared)
135
+ nd.append(nd_row)
136
+ hi.append(hi_row)
137
+ lo.append(lo_row)
143
138
  return nd, hi, lo
144
139
 
145
140
 
@@ -166,9 +161,10 @@ def _finalize_counterfactual_frames(
166
161
  pad_df = upcast(pad_df, hi)
167
162
  nd = upcast(nd, hi)
168
163
 
169
- dfs_to_align = [("nd", nd), ("hi", hi), ("lo", lo)]
164
+ dfs = {"nd": nd, "hi": hi, "lo": lo}
170
165
 
171
- for df_name, df in dfs_to_align:
166
+ for df_name in dfs:
167
+ df = dfs[df_name]
172
168
  common_cols = set(pad_df.columns) & set(df.columns)
173
169
  for col in common_cols:
174
170
  pad_dtype = str(pad_df[col].dtype)
@@ -189,8 +185,8 @@ def _finalize_counterfactual_frames(
189
185
  .alias(col)
190
186
  )
191
187
  except Exception as e:
192
- print(
193
- f"Warning: Could not convert List column {col} to strings: {e}"
188
+ warnings.warn(
189
+ f"Could not convert List column {col} to strings: {e}"
194
190
  )
195
191
  try:
196
192
  if col in pad_df.columns and pad_df.height > 0:
@@ -198,7 +194,7 @@ def _finalize_counterfactual_frames(
198
194
  if col in df.columns and df.height > 0:
199
195
  df = df.explode(col)
200
196
  except Exception as e2:
201
- print(f"Warning: Could not explode List column {col}: {e2}")
197
+ warnings.warn(f"Could not explode List column {col}: {e2}")
202
198
  if col in pad_df.columns:
203
199
  pad_df = pad_df.with_columns(
204
200
  pad_df[col].cast(pl.String).alias(col)
@@ -206,12 +202,9 @@ def _finalize_counterfactual_frames(
206
202
  if col in df.columns:
207
203
  df = df.with_columns(df[col].cast(pl.String).alias(col))
208
204
 
209
- if df_name == "nd":
210
- nd = df
211
- elif df_name == "hi":
212
- hi = df
213
- elif df_name == "lo":
214
- lo = df
205
+ dfs[df_name] = df
206
+
207
+ nd, hi, lo = dfs["nd"], dfs["hi"], dfs["lo"]
215
208
 
216
209
  nd = pl.concat([pad_df, nd], how="diagonal")
217
210
  hi = pl.concat([pad_df, hi], how="diagonal")
@@ -221,9 +214,7 @@ def _finalize_counterfactual_frames(
221
214
  categorical_list_cols = []
222
215
  for col in list_cols:
223
216
  dtype_str = str(nd[col].dtype)
224
- if (
225
- "Enum(" in dtype_str or "String" in dtype_str or "UInt32" in dtype_str
226
- ) and col in ["Region"]:
217
+ if "Enum(" in dtype_str or "String" in dtype_str or "UInt32" in dtype_str:
227
218
  categorical_list_cols.append(col)
228
219
 
229
220
  if categorical_list_cols:
@@ -241,7 +232,7 @@ def _prepare_design_matrices(model, nd, hi, lo, pad_rows):
241
232
  lo_X = model.get_exog(lo)
242
233
  nd_X = model.get_exog(nd)
243
234
 
244
- if pad_rows >= 0:
235
+ if pad_rows > 0:
245
236
  nd_X = nd_X[pad_rows:]
246
237
  hi_X = hi_X[pad_rows:]
247
238
  lo_X = lo_X[pad_rows:]
@@ -42,12 +42,12 @@ estimands = {
42
42
  "ratio": lambda hi, lo, eps, x, y, w: prep(hi / lo),
43
43
  "ratioavg": lambda hi, lo, eps, x, y, w: prep(hi.mean() / lo.mean()),
44
44
  "ratioavgwts": lambda hi, lo, eps, x, y, w: prep(
45
- (hi * w).sum() / w.sum() / (lo * w).sum() / w.sum()
45
+ ((hi * w).sum() / w.sum()) / ((lo * w).sum() / w.sum())
46
46
  ),
47
47
  "lnratio": lambda hi, lo, eps, x, y, w: prep(np.log(hi / lo)),
48
48
  "lnratioavg": lambda hi, lo, eps, x, y, w: prep(np.log(hi.mean() / lo.mean())),
49
49
  "lnratioavgwts": lambda hi, lo, eps, x, y, w: prep(
50
- np.log((hi * w).sum() / w.sum() / (lo * w).sum() / w.sum())
50
+ np.log(((hi * w).sum() / w.sum()) / ((lo * w).sum() / w.sum()))
51
51
  ),
52
52
  "lnor": lambda hi, lo, eps, x, y, w: prep(
53
53
  np.log((hi / (1 - hi)) / (lo / (1 - lo)))
@@ -69,7 +69,7 @@ estimands = {
69
69
  "expdydxavg": lambda hi, lo, eps, x, y, w: prep(
70
70
  np.mean(((hi.exp() - lo.exp()) / np.exp(eps)) / eps)
71
71
  ),
72
- "expdydxavgwts": lambda hi, lo, eps, x, y, w: (
73
- prep((((np.exp(hi) - np.exp(lo)) / np.exp(eps)) / eps) * w).sum() / w.sum()
72
+ "expdydxavgwts": lambda hi, lo, eps, x, y, w: prep(
73
+ ((((np.exp(hi) - np.exp(lo)) / np.exp(eps)) / eps) * w).sum() / w.sum()
74
74
  ),
75
75
  }
@@ -1,3 +1,5 @@
1
+ import copy
2
+
1
3
  import numpy as np
2
4
  from ..datagrid import datagrid # noqa
3
5
  from ..sanitize import sanitize_model
@@ -7,7 +9,7 @@ import polars as pl
7
9
  def dt_on_condition(model, condition):
8
10
  model = sanitize_model(model)
9
11
 
10
- condition_new = condition # two pointers to the same object? this looks like a bug
12
+ condition_new = copy.deepcopy(condition)
11
13
 
12
14
  # not sure why newdata gets added
13
15
  modeldata = model.get_modeldata()
@@ -19,28 +21,24 @@ def dt_on_condition(model, condition):
19
21
  first_key = "" # special case when the first element is numeric
20
22
 
21
23
  if isinstance(condition_new, list):
22
- assert all(ele in modeldata.columns for ele in condition_new), (
23
- "All elements of condition must be columns of the model."
24
- )
24
+ if not all(ele in modeldata.columns for ele in condition_new):
25
+ raise ValueError("All elements of condition must be columns of the model.")
25
26
  first_key = condition_new[0]
26
27
  to_datagrid = {key: None for key in condition_new}
27
28
 
28
29
  elif isinstance(condition_new, dict):
29
- assert all(key in modeldata.columns for key in condition_new.keys()), (
30
- "All keys of condition must be columns of the model."
31
- )
30
+ if not all(key in modeldata.columns for key in condition_new.keys()):
31
+ raise ValueError("All keys of condition must be columns of the model.")
32
32
  first_key = next(iter(condition_new))
33
- to_datagrid = (
34
- condition_new # third pointer to the same object? looks like a BUG
35
- )
33
+ to_datagrid = condition_new
36
34
 
37
- # not sure why `newdata` sometimes gets added
38
35
  if isinstance(condition_new, dict) and "newdata" in to_datagrid.keys():
39
36
  condition_new.pop("newdata", None)
40
37
 
41
- assert 1 <= len(condition_new) <= 4, (
42
- f"Lenght of condition must be inclusively between 1 and 4. Got : {len(condition_new)}."
43
- )
38
+ if not (1 <= len(condition_new) <= 4):
39
+ raise ValueError(
40
+ f"Length of condition must be inclusively between 1 and 4. Got: {len(condition_new)}."
41
+ )
44
42
 
45
43
  for key, value in to_datagrid.items():
46
44
  variable_type = model.get_variable_type(key)
@@ -51,20 +49,17 @@ def dt_on_condition(model, condition):
51
49
  )
52
50
 
53
51
  elif variable_type in ["character"]:
54
- # get specified names of the condition
55
- # here is the BUG, we take the values of "species" back from the model
56
52
  to_datagrid[key] = (
57
53
  to_datagrid[key]
58
54
  if to_datagrid[key]
59
55
  else modeldata[key].unique().sort().to_list()
60
56
  )
61
- assert len(to_datagrid[key]) <= 10, (
62
- f"Character type variables of more than 10 unique values are not supported. {key} variable has {len(to_datagrid[key])} unique values."
63
- )
57
+ if len(to_datagrid[key]) > 10:
58
+ raise ValueError(
59
+ f"Character type variables of more than 10 unique values are not supported. {key} variable has {len(to_datagrid[key])} unique values."
60
+ )
64
61
 
65
62
  elif variable_type in ["boolean", "binary"]:
66
- # get specified names of the condition
67
- # here is the BUG, we take the values of "species" back from the model
68
63
  if to_datagrid[key] is None:
69
64
  to_datagrid[key] = modeldata[key].unique().sort().to_list()
70
65
 
@@ -131,15 +126,14 @@ def ordered_cat(dt, k, lab):
131
126
 
132
127
 
133
128
  def validate_plot_args(condition, by, newdata, wts):
134
- assert not (not by and newdata is not None), (
135
- "The `newdata` argument requires a `by` argument."
136
- )
137
- assert not (wts is not None and not by), (
138
- "The `wts` argument requires a `by` argument."
139
- )
140
- assert (condition is None and by) or (condition is not None and not by), (
141
- "One of the `condition` and `by` arguments must be supplied, but not both."
142
- )
129
+ if not by and newdata is not None:
130
+ raise ValueError("The `newdata` argument requires a `by` argument.")
131
+ if wts is not None and not by:
132
+ raise ValueError("The `wts` argument requires a `by` argument.")
133
+ if not ((condition is None and by) or (condition is not None and not by)):
134
+ raise ValueError(
135
+ "One of the `condition` and `by` arguments must be supplied, but not both."
136
+ )
143
137
 
144
138
 
145
139
  def extract_var_list(condition, by):
@@ -158,9 +152,10 @@ def extract_var_list(condition, by):
158
152
 
159
153
  var_list = [x for x in var_list if x not in ["newdata", "model"]]
160
154
 
161
- assert len(var_list) < 5, (
162
- "The `condition` and `by` arguments can have a max length of 4."
163
- )
155
+ if len(var_list) >= 5:
156
+ raise ValueError(
157
+ "The `condition` and `by` arguments can have a max length of 4."
158
+ )
164
159
 
165
160
  return var_list
166
161
 
@@ -286,10 +281,10 @@ def plot_common(model, dt, y_label, var_list, gray=False, points=0):
286
281
  if len(var_list) > 1:
287
282
  if gray:
288
283
  # get the number of unique values in the column "var_list[1]"
289
- unique_values = dt[var_list[1]].unique().len()
290
- if unique_values > 5:
284
+ unique_values = dt[var_list[1]].unique()
285
+ if unique_values.len() > 5:
291
286
  raise ValueError(
292
- f"The number of elements in the second position of the `condition` or `by` argument (variable {var_list[1]}) cannot exceed 5. It has currently {len(unique_values)} elements, with values {unique_values}."
287
+ f"The number of elements in the second position of the `condition` or `by` argument (variable {var_list[1]}) cannot exceed 5. It has currently {unique_values.len()} elements, with values {unique_values.to_list()}."
293
288
  )
294
289
  custom_line_types = [
295
290
  "solid",
@@ -73,8 +73,7 @@ def sanitize_comparison(comparison, by, wts=None):
73
73
  "expdydx": "exp(dY/dX)",
74
74
  }
75
75
 
76
- assert out in lab.keys(), (
77
- f"`comparison` must be one of: {', '.join(list(lab.keys()))}."
78
- )
76
+ if out not in lab.keys():
77
+ raise ValueError(f"`comparison` must be one of: {', '.join(list(lab.keys()))}.")
79
78
 
80
79
  return (out, lab[out])
@@ -1,12 +1,16 @@
1
1
  import numpy as np
2
2
  import polars as pl
3
3
 
4
- from ..datagrid import datagrid
5
- from ..utils import ingest, upcast
6
4
  from ..formula import listwise_deletion
7
5
 
8
6
 
9
7
  def sanitize_newdata(model, newdata, wts, by=[]):
8
+ # Lazy imports to break the `datagrid -> utils -> sanitize -> newdata -> ...`
9
+ # circular import that fires when `datagrid` is the first symbol pulled from
10
+ # marginaleffects in a fresh interpreter (see GH #1724).
11
+ from ..datagrid import datagrid
12
+ from ..utils import ingest, upcast
13
+
10
14
  modeldata = model.get_modeldata()
11
15
 
12
16
  if newdata is None:
@@ -72,9 +76,10 @@ def sanitize_newdata(model, newdata, wts, by=[]):
72
76
  "contrast",
73
77
  "statistic",
74
78
  }
75
- assert not (set(out.columns) & reserved_names), (
76
- f"Input data contain reserved column name(s) : {set(out.columns).intersection(reserved_names)}"
77
- )
79
+ if set(out.columns) & reserved_names:
80
+ raise ValueError(
81
+ f"Input data contain reserved column name(s): {set(out.columns).intersection(reserved_names)}"
82
+ )
78
83
 
79
84
  datagrid_explicit = None
80
85
  if isinstance(newdata, pl.DataFrame) and hasattr(newdata, "datagrid_explicit"):
@@ -11,17 +11,11 @@ HiLo = namedtuple("HiLo", ["variable", "hi", "lo", "lab", "pad", "comparison"])
11
11
 
12
12
 
13
13
  def _clean_global(k, n):
14
- if (
15
- not isinstance(k, list)
16
- and not isinstance(k, pl.Series)
17
- and not isinstance(k, np.ndarray)
18
- ):
19
- out = [k]
14
+ if isinstance(k, (pl.Series, np.ndarray)):
15
+ return pl.Series(k) if len(k) > 1 else pl.Series(np.repeat(k[0], n))
20
16
  if not isinstance(k, list) or len(k) == 1:
21
- out = pl.Series(np.repeat(k, n))
22
- else:
23
- out = pl.Series(k)
24
- return out
17
+ return pl.Series(np.repeat(k, n))
18
+ return pl.Series(k)
25
19
 
26
20
 
27
21
  def _get_one_variable_hi_lo(
@@ -153,9 +147,10 @@ def _get_one_variable_hi_lo(
153
147
 
154
148
  elif callable(value):
155
149
  tmp = value(newdata[variable])
156
- assert tmp.shape[1] == 2, (
157
- f"The function passed to `variables` must return a DataFrame with two columns. Got {tmp.shape[1]}."
158
- )
150
+ if tmp.shape[1] != 2:
151
+ raise ValueError(
152
+ f"The function passed to `variables` must return a DataFrame with two columns. Got {tmp.shape[1]}."
153
+ )
159
154
  lo = tmp[:, 0]
160
155
  hi = tmp[:, 1]
161
156
  lab = "custom"
@@ -225,9 +220,8 @@ def sanitize_variables(
225
220
  )
226
221
 
227
222
  elif isinstance(variables, dict):
228
- for v in variables:
223
+ for v in list(variables.keys()):
229
224
  if v not in newdata.columns:
230
- del variables[v]
231
225
  warn(f"Variable {v} is not in newdata.")
232
226
  else:
233
227
  out.append(
@@ -14,5 +14,6 @@ def sanitize_vcov(vcov, model):
14
14
 
15
15
  V = model.get_vcov(vcov)
16
16
  if V is not None:
17
- assert isinstance(V, np.ndarray), "vcov must be True or a square NumPy array"
17
+ if not isinstance(V, np.ndarray):
18
+ raise TypeError("vcov must be True or a square NumPy array")
18
19
  return V
@@ -6,36 +6,29 @@ import scipy.stats as stats
6
6
 
7
7
 
8
8
  def get_jacobian(func, coefs, eps_vcov=None):
9
- # forward finite difference (faster)
9
+ original_shape = None
10
10
  if coefs.ndim == 2:
11
+ original_shape = coefs.shape
11
12
  if isinstance(coefs, np.ndarray):
12
13
  coefs_flat = coefs.flatten(order="F")
13
14
  else:
14
15
  coefs_flat = coefs.to_numpy().flatten(order="F")
15
- baseline = func(coefs)["estimate"].to_numpy()
16
- jac = np.empty((baseline.shape[0], len(coefs_flat)), dtype=np.float64)
17
- for i, xi in enumerate(coefs_flat):
18
- if eps_vcov is not None:
19
- h = eps_vcov
20
- else:
21
- h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10)
22
- dx = np.copy(coefs_flat)
23
- dx[i] = dx[i] + h
24
- tmp = dx.reshape(coefs.shape, order="F")
25
- jac[:, i] = (func(tmp)["estimate"].to_numpy() - baseline) / h
26
- return jac
27
16
  else:
28
- baseline = func(coefs)["estimate"].to_numpy()
29
- jac = np.empty((baseline.shape[0], len(coefs)), dtype=np.float64)
30
- for i, xi in enumerate(coefs):
31
- if eps_vcov is not None:
32
- h = eps_vcov
33
- else:
34
- h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10)
35
- dx = np.copy(coefs)
36
- dx[i] = dx[i] + h
37
- jac[:, i] = (func(dx)["estimate"].to_numpy() - baseline) / h
38
- return jac
17
+ coefs_flat = np.asarray(coefs)
18
+
19
+ baseline = func(coefs)["estimate"].to_numpy()
20
+ jac = np.empty((baseline.shape[0], len(coefs_flat)), dtype=np.float64)
21
+ for i, xi in enumerate(coefs_flat):
22
+ if eps_vcov is not None:
23
+ h = eps_vcov
24
+ else:
25
+ h = max(abs(xi) * np.sqrt(np.finfo(float).eps), 1e-10)
26
+ dx = np.copy(coefs_flat)
27
+ dx[i] = dx[i] + h
28
+ if original_shape is not None:
29
+ dx = dx.reshape(original_shape, order="F")
30
+ jac[:, i] = (func(dx)["estimate"].to_numpy() - baseline) / h
31
+ return jac
39
32
 
40
33
 
41
34
  def get_se(J, V):
@@ -65,7 +58,7 @@ def get_z_p_ci(df, model, conf_level, hypothesis_null=0):
65
58
  "statistic"
66
59
  )
67
60
  )
68
- if hasattr(model, "df_resid") and isinstance(model.df_resid, float):
61
+ if hasattr(model, "df_resid") and isinstance(model.df_resid, (int, float)):
69
62
  dof = model.df_resid
70
63
  else:
71
64
  dof = np.inf
@@ -93,6 +86,6 @@ def get_z_p_ci(df, model, conf_level, hypothesis_null=0):
93
86
  .map_batches(lambda x: -np.log2(x), return_dtype=pl.Float64)
94
87
  .alias("s_value")
95
88
  )
96
- except Exception as e:
97
- print(f"An exception occurred: {e}")
89
+ except Exception:
90
+ pass
98
91
  return df
@@ -116,7 +116,7 @@ def upcast(df, reference):
116
116
  pl.Float64,
117
117
  ]
118
118
  for col in df.columns:
119
- if col in df.columns and col in reference.columns:
119
+ if col in reference.columns:
120
120
  good = reference[col].dtype
121
121
  bad = df[col].dtype
122
122
  if good != bad:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: marginaleffects
3
- Version: 0.5.0
3
+ Version: 0.5.1
4
4
  Summary: Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models.
5
5
  License-Expression: GPL-3.0-or-later
6
6
  Requires-Python: >=3.10
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "marginaleffects"
3
- version = "0.5.0"
3
+ version = "0.5.1"
4
4
  license = "GPL-3.0-or-later"
5
5
  description = "Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models."
6
6
  readme = "README.md"
@@ -1,3 +1,6 @@
1
+ import subprocess
2
+ import sys
3
+
1
4
  import numpy as np
2
5
  import pandas as pd
3
6
  import polars as pl
@@ -22,3 +25,18 @@ def test_issue_226_np_context():
22
25
  out = predictions(mod, newdata=df)
23
26
  assert isinstance(out, MarginaleffectsResult)
24
27
  assert isinstance(out.data, pl.DataFrame)
28
+
29
+
30
+ def test_issue_1724():
31
+ # Circular import when `datagrid` is the first symbol pulled from
32
+ # marginaleffects in a fresh interpreter. Must run in a subprocess —
33
+ # the in-process pytest run has already warmed the import graph.
34
+ result = subprocess.run(
35
+ [sys.executable, "-c", "from marginaleffects import datagrid"],
36
+ capture_output=True,
37
+ text=True,
38
+ )
39
+ assert result.returncode == 0, (
40
+ f"Fresh-process import of `datagrid` failed.\n"
41
+ f"stdout: {result.stdout}\nstderr: {result.stderr}"
42
+ )
@@ -200,7 +200,7 @@ def test_lift():
200
200
  cmp2 = comparisons(mod, comparison="liftavg")
201
201
  assert cmp1.shape[0] == 32
202
202
  assert cmp2.shape[0] == 1
203
- with pytest.raises(AssertionError):
203
+ with pytest.raises(ValueError):
204
204
  comparisons(mod, comparison="liftr")
205
205
 
206
206
 
@@ -204,7 +204,7 @@ def test_pyfixest_standard_errors_across_models():
204
204
  fit_pois_fe = fepois("Y ~ X1 * X2 * Z1 | f1", data=poisson_data)
205
205
  with pytest.warns(
206
206
  UserWarning,
207
- match="uncertainty in fixed-effects parameters when computing contrasts",
207
+ match="cannot take into account the uncertainty in fixed-effects",
208
208
  ):
209
209
  try:
210
210
  comp_pois_fe = comparisons(fit_pois_fe)