marginaleffects 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/PKG-INFO +1 -1
  2. marginaleffects-0.4.0/marginaleffects/__init__.py +58 -0
  3. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/comparisons.py +105 -111
  4. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/datagrid.py +49 -58
  5. marginaleffects-0.4.0/marginaleffects/docs.py +613 -0
  6. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/equivalence.py +4 -1
  7. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/estimands.py +3 -4
  8. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/hypotheses.py +68 -77
  9. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/hypothesis.py +6 -2
  10. marginaleffects-0.4.0/marginaleffects/linearmodels/__init__.py +30 -0
  11. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/linearmodels/model.py +40 -74
  12. marginaleffects-0.4.0/marginaleffects/plot/__init__.py +25 -0
  13. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/plot/common.py +15 -14
  14. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/plot/comparisons.py +26 -29
  15. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/plot/predictions.py +40 -43
  16. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/plot/slopes.py +24 -28
  17. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/predictions.py +67 -80
  18. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/sanitize_model.py +7 -11
  19. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/sanity.py +25 -1
  20. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/settings.py +11 -20
  21. marginaleffects-0.4.0/marginaleffects/sklearn/__init__.py +26 -0
  22. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/sklearn/model.py +51 -84
  23. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/slopes.py +79 -92
  24. marginaleffects-0.4.0/marginaleffects/statsmodels/__init__.py +26 -0
  25. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/statsmodels/model.py +28 -60
  26. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/uncertainty.py +1 -1
  27. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/utils.py +33 -41
  28. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects.egg-info/PKG-INFO +1 -1
  29. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects.egg-info/SOURCES.txt +0 -1
  30. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/pyproject.toml +3 -3
  31. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_comparisons.py +40 -1
  32. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/utilities.py +3 -3
  33. marginaleffects-0.3.0/marginaleffects/__init__.py +0 -46
  34. marginaleffects-0.3.0/marginaleffects/docs.py +0 -341
  35. marginaleffects-0.3.0/marginaleffects/inject_docs.py +0 -139
  36. marginaleffects-0.3.0/marginaleffects/linearmodels/__init__.py +0 -3
  37. marginaleffects-0.3.0/marginaleffects/plot/__init__.py +0 -5
  38. marginaleffects-0.3.0/marginaleffects/sklearn/__init__.py +0 -3
  39. marginaleffects-0.3.0/marginaleffects/statsmodels/__init__.py +0 -3
  40. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/README.md +0 -0
  41. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/benchmarks/__init__.py +0 -0
  42. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/benchmarks/benchmark_autodiff.py +0 -0
  43. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/_input_utils.py +0 -0
  44. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/__init__.py +0 -0
  45. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/comparisons.py +0 -0
  46. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/glm/__init__.py +0 -0
  47. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/glm/comparisons.py +0 -0
  48. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/glm/families.py +0 -0
  49. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/glm/predictions.py +0 -0
  50. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/linear/__init__.py +0 -0
  51. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/linear/comparisons.py +0 -0
  52. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/linear/predictions.py +0 -0
  53. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/autodiff/utils.py +0 -0
  54. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/by.py +0 -0
  55. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/classes.py +0 -0
  56. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/formulaic_utils.py +0 -0
  57. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/hypotheses_joint.py +0 -0
  58. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/hypothesis_formula.py +0 -0
  59. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/jax_dispatch.py +0 -0
  60. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/model_abstract.py +0 -0
  61. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/pyfixest/__init__.py +0 -0
  62. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/pyfixest/model.py +0 -0
  63. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/result.py +0 -0
  64. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/transform.py +0 -0
  65. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects/validation.py +0 -0
  66. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects.egg-info/dependency_links.txt +0 -0
  67. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects.egg-info/requires.txt +0 -0
  68. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/marginaleffects.egg-info/top_level.txt +0 -0
  69. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/setup.cfg +0 -0
  70. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/__init__.py +0 -0
  71. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/helpers.py +0 -0
  72. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_analytic.py +0 -0
  73. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_autodiff.py +0 -0
  74. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_bugfix.py +0 -0
  75. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_by.py +0 -0
  76. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_categorical.py +0 -0
  77. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_categorical_validation.py +0 -0
  78. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_comparisons_interaction.py +0 -0
  79. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_datagrid_01.py +0 -0
  80. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_datagrid_02.py +0 -0
  81. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_equivalence.py +0 -0
  82. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_formula.py +0 -0
  83. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_formulaic_utils.py +0 -0
  84. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_hypotheses.py +0 -0
  85. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_hypotheses_joint.py +0 -0
  86. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_hypothesis.py +0 -0
  87. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_jss.py +0 -0
  88. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_linearmodels_panelols.py +0 -0
  89. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_missing.py +0 -0
  90. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_newdata.py +0 -0
  91. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_plot_comparisons.py +0 -0
  92. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_plot_predictions.py +0 -0
  93. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_plot_slopes.py +0 -0
  94. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_predictions.py +0 -0
  95. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_pyfixest.py +0 -0
  96. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_sklearn.py +0 -0
  97. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_slopes.py +0 -0
  98. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels.py +0 -0
  99. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_logit.py +0 -0
  100. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_mixedlm.py +0 -0
  101. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_mnlogit.py +0 -0
  102. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_negativebinomial.py +0 -0
  103. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_ols.py +0 -0
  104. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_poisson.py +0 -0
  105. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_probit.py +0 -0
  106. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_quantreg.py +0 -0
  107. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_statsmodels_wls.py +0 -0
  108. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_typical.py +0 -0
  109. {marginaleffects-0.3.0 → marginaleffects-0.4.0}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: marginaleffects
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models.
5
5
  License-Expression: GPL-3.0-or-later
6
6
  Requires-Python: >=3.10
@@ -0,0 +1,58 @@
1
+ from importlib import import_module
2
+ import importlib.util
3
+ import sys
4
+
5
+
6
+ _EXPORTS = {
7
+ "avg_comparisons": ("marginaleffects.comparisons", "avg_comparisons"),
8
+ "comparisons": ("marginaleffects.comparisons", "comparisons"),
9
+ "datagrid": ("marginaleffects.datagrid", "datagrid"),
10
+ "hypotheses": ("marginaleffects.hypotheses", "hypotheses"),
11
+ "plot_comparisons": ("marginaleffects.plot.comparisons", "plot_comparisons"),
12
+ "plot_predictions": ("marginaleffects.plot.predictions", "plot_predictions"),
13
+ "plot_slopes": ("marginaleffects.plot.slopes", "plot_slopes"),
14
+ "avg_predictions": ("marginaleffects.predictions", "avg_predictions"),
15
+ "predictions": ("marginaleffects.predictions", "predictions"),
16
+ "avg_slopes": ("marginaleffects.slopes", "avg_slopes"),
17
+ "slopes": ("marginaleffects.slopes", "slopes"),
18
+ "fit_statsmodels": ("marginaleffects.statsmodels.model", "fit_statsmodels"),
19
+ "fit_sklearn": ("marginaleffects.sklearn.model", "fit_sklearn"),
20
+ "fit_linearmodels": ("marginaleffects.linearmodels.model", "fit_linearmodels"),
21
+ "get_dataset": ("marginaleffects.utils", "get_dataset"),
22
+ "MarginaleffectsResult": ("marginaleffects.result", "MarginaleffectsResult"),
23
+ "autodiff": ("marginaleffects.settings", "autodiff"),
24
+ "set_autodiff": ("marginaleffects.settings", "set_autodiff"),
25
+ "get_autodiff": ("marginaleffects.settings", "get_autodiff"),
26
+ }
27
+
28
+ if importlib.util.find_spec("jax") is not None:
29
+ _EXPORTS["autodiff_module"] = ("marginaleffects.autodiff", None)
30
+
31
+ __all__ = list(_EXPORTS.keys())
32
+
33
+
34
+ def __getattr__(name):
35
+ if name not in _EXPORTS:
36
+ raise AttributeError(f"module 'marginaleffects' has no attribute '{name}'")
37
+
38
+ module_name, attr_name = _EXPORTS[name]
39
+ module = import_module(module_name)
40
+
41
+ # Rebind exported callables from all loaded modules so importlib-created
42
+ # submodule attributes (e.g. `marginaleffects.comparisons`) do not shadow
43
+ # the intended public API during `from marginaleffects import *`.
44
+ for export_name, (export_module, export_attr) in _EXPORTS.items():
45
+ if export_attr is None:
46
+ continue
47
+
48
+ export_mod = sys.modules.get(export_module)
49
+ if export_mod is not None:
50
+ globals()[export_name] = getattr(export_mod, export_attr)
51
+
52
+ value = module if attr_name is None else globals()[name]
53
+ globals()[name] = value
54
+ return value
55
+
56
+
57
+ def __dir__():
58
+ return sorted(set(globals().keys()) | set(__all__))
@@ -2,11 +2,9 @@ import re
2
2
  from functools import reduce
3
3
 
4
4
  import numpy as np
5
- import patsy
6
5
  import polars as pl
7
6
 
8
7
  from .estimands import estimands
9
- from .hypothesis import get_hypothesis
10
8
  from .sanitize_model import sanitize_model
11
9
  from .sanity import (
12
10
  sanitize_variables,
@@ -21,15 +19,8 @@ from .utils import (
21
19
  finalize_result,
22
20
  call_avg,
23
21
  )
24
- from .pyfixest import ModelPyfixest
25
- from .sklearn import ModelSklearn
26
- from .linearmodels import ModelLinearmodels
27
22
  from ._input_utils import prepare_base_inputs
28
- from .docs import (
29
- DocsDetails,
30
- DocsParameters,
31
- docstring_returns,
32
- )
23
+ from .docs import doc
33
24
 
34
25
 
35
26
  def _cross_postprocess(cross):
@@ -286,11 +277,19 @@ def _finalize_counterfactual_frames(
286
277
 
287
278
 
288
279
  def _prepare_design_matrices(model, nd, hi, lo, pad_rows):
289
- if isinstance(model, (ModelPyfixest, ModelLinearmodels, ModelSklearn)):
280
+ package = model.get_package() if hasattr(model, "get_package") else None
281
+ typename = type(model).__name__.lower()
282
+ uses_native_design = package in {"pyfixest", "linearmodels", "sklearn"} or any(
283
+ x in typename for x in ("modelpyfixest", "modellinearmodels", "modelsklearn")
284
+ )
285
+
286
+ if uses_native_design:
290
287
  hi_X = hi
291
288
  lo_X = lo
292
289
  nd_X = nd
293
290
  else:
291
+ import patsy
292
+
294
293
  fml = re.sub(r".*~", "", model.get_formula())
295
294
  hi_X = patsy.dmatrix(fml, hi.to_pandas())
296
295
  lo_X = patsy.dmatrix(fml, lo.to_pandas())
@@ -316,6 +315,99 @@ def _collect_comparison_functions(variables):
316
315
  return comparison_functions
317
316
 
318
317
 
318
+ @doc("""
319
+
320
+ # `comparisons()`
321
+
322
+ `comparisons()` and `avg_comparisons()` are functions for predicting the outcome variable at different regressor values and comparing those predictions by computing a difference, ratio, or some other function. These functions can return many quantities of interest, such as contrasts, differences, risk ratios, changes in log odds, lift, slopes, elasticities, average treatment effect (on the treated or untreated), etc.
323
+
324
+ - `comparisons()`: unit-level (conditional) estimates.
325
+ - `avg_comparisons()`: average (marginal) estimates.
326
+
327
+ See the package website and vignette for examples:
328
+
329
+ - https://marginaleffects.com/chapters/comparisons.html
330
+ - https://marginaleffects.com
331
+
332
+ ## Parameters
333
+
334
+ {param_model}
335
+
336
+ {param_variables_comparison}
337
+
338
+ {param_newdata_comparison}
339
+
340
+ - `comparison`: (str or callable) String specifying how pairs of predictions should be compared, or a callable function to compute custom estimates. See the Comparisons section below for definitions of each transformation.
341
+ - Acceptable strings: difference, differenceavg, differenceavgwts, dydx, eyex, eydx, dyex, dydxavg, eyexavg, eydxavg, dyexavg, dydxavgwts, eyexavgwts, eydxavgwts, dyexavgwts, ratio, ratioavg, ratioavgwts, lnratio, lnratioavg, lnratioavgwts, lnor, lnoravg, lnoravgwts, lift, liftavg, liftavg, expdydx, expdydxavg, expdydxavgwts
342
+ - Callable: A function that accepts any subset of the named arguments `hi`, `lo`, `eps`, `x`, `y`, and `w`, and returns a numeric value or array. For example: `lambda hi, lo: hi / lo` for ratios, `lambda hi, lo: (hi - lo) / lo * 100` for percent changes, or a named function like `def lnor(hi, lo): return np.log((hi.mean() / (1 - hi.mean())) / (lo.mean() / (1 - lo.mean())))`.
343
+
344
+ {param_by}
345
+
346
+ {param_transform}
347
+
348
+ {param_hypothesis}
349
+
350
+ {param_wts}
351
+
352
+ {param_vcov}
353
+
354
+ {param_equivalence}
355
+
356
+ {param_cross}
357
+
358
+ {param_conf_level}
359
+
360
+ {param_eps}
361
+
362
+ {param_eps_vcov}
363
+
364
+ {returns}
365
+
366
+ ## Examples
367
+ ```py
368
+ from marginaleffects import *
369
+ import numpy as np
370
+
371
+ import statsmodels.api as sm
372
+ import statsmodels.formula.api as smf
373
+ data = get_dataset("thornton")
374
+ model = smf.ols("outcome ~ distance + incentive", data=data).fit()
375
+
376
+ # Basic comparisons
377
+ comparisons(model)
378
+
379
+ avg_comparisons(model)
380
+
381
+ comparisons(model, hypothesis=0)
382
+
383
+ avg_comparisons(model, hypothesis=0)
384
+
385
+ comparisons(model, by="agecat")
386
+
387
+ avg_comparisons(model, by="agecat")
388
+
389
+ # Custom comparisons with functions
390
+ # Ratio comparison using lambda
391
+ comparisons(model, variables="distance",
392
+ comparison=lambda hi, lo: hi / lo)
393
+
394
+ # Percent change using lambda
395
+ comparisons(model, variables="distance",
396
+ comparison=lambda hi, lo: (hi - lo) / lo * 100)
397
+
398
+ # Custom function with flexible signature
399
+ def lnor(hi, lo):
400
+ hi = np.asarray(hi)
401
+ lo = np.asarray(lo)
402
+ return np.log((hi.mean() / (1 - hi.mean())) / (lo.mean() / (1 - lo.mean())))
403
+
404
+ comparisons(model, variables="distance", comparison=lnor)
405
+ ```
406
+
407
+ ## Details
408
+ {details_tost}
409
+
410
+ {details_order_of_operations}""")
319
411
  def comparisons(
320
412
  model,
321
413
  variables=None,
@@ -333,13 +425,6 @@ def comparisons(
333
425
  eps_vcov=None,
334
426
  **kwargs,
335
427
  ):
336
- """
337
- `comparisons()` and `avg_comparisons()` are functions for predicting the outcome variable at different regressor values and comparing those predictions by computing a difference, ratio, or some other function. These functions can return many quantities of interest, such as contrasts, differences, risk ratios, changes in log odds, lift, slopes, elasticities, average treatment effect (on the treated or untreated), etc.
338
-
339
- For more information, visit the website: https://marginaleffects.com/
340
-
341
- Or type: `help(comparisons)`
342
- """
343
428
  hypothesis = handle_deprecated_hypotheses_argument(hypothesis, kwargs, stacklevel=2)
344
429
  if kwargs:
345
430
  unexpected = ", ".join(sorted(kwargs.keys()))
@@ -464,6 +549,8 @@ def comparisons(
464
549
 
465
550
  # === END JAX EARLY EXIT ===
466
551
 
552
+ from .hypothesis import get_hypothesis
553
+
467
554
  # inner() takes the `hi` and `lo` matrices, computes predictions, compares
468
555
  # them, and aggregates the results based on the `by` argument. This gives us
469
556
  # the final quantity of interest. We wrap this in a function because it will
@@ -615,13 +702,6 @@ def avg_comparisons(
615
702
  eps=1e-4,
616
703
  **kwargs,
617
704
  ):
618
- """
619
- `comparisons()` and `avg_comparisons()` are functions for predicting the outcome variable at different regressor values and comparing those predictions by computing a difference, ratio, or some other function. These functions can return many quantities of interest, such as contrasts, differences, risk ratios, changes in log odds, lift, slopes, elasticities, average treatment effect (on the treated or untreated), etc.
620
-
621
- For more information, visit the website: https://marginaleffects.com/
622
-
623
- Or type: `help(avg_comparisons)`
624
- """
625
705
  return call_avg(
626
706
  comparisons,
627
707
  model=model,
@@ -641,90 +721,4 @@ def avg_comparisons(
641
721
  )
642
722
 
643
723
 
644
- docs_comparisons = (
645
- """
646
-
647
- # `comparisons()`
648
-
649
- `comparisons()` and `avg_comparisons()` are functions for predicting the outcome variable at different regressor values and comparing those predictions by computing a difference, ratio, or some other function. These functions can return many quantities of interest, such as contrasts, differences, risk ratios, changes in log odds, lift, slopes, elasticities, average treatment effect (on the treated or untreated), etc.
650
-
651
- * `comparisons()`: unit-level (conditional) estimates.
652
- * `avg_comparisons()`: average (marginal) estimates.
653
-
654
- See the package website and vignette for examples:
655
-
656
- * https://marginaleffects.com/chapters/comparisons.html
657
- * https://marginaleffects.com
658
-
659
- ## Parameters
660
- """
661
- + DocsParameters.docstring_model
662
- + DocsParameters.docstring_variables("comparison")
663
- + DocsParameters.docstring_newdata("comparison")
664
- + """
665
- * `comparison`: (str or callable) String specifying how pairs of predictions should be compared, or a callable function to compute custom estimates. See the Comparisons section below for definitions of each transformation.
666
-
667
- * Acceptable strings: difference, differenceavg, differenceavgwts, dydx, eyex, eydx, dyex, dydxavg, eyexavg, eydxavg, dyexavg, dydxavgwts, eyexavgwts, eydxavgwts, dyexavgwts, ratio, ratioavg, ratioavgwts, lnratio, lnratioavg, lnratioavgwts, lnor, lnoravg, lnoravgwts, lift, liftavg, liftavg, expdydx, expdydxavg, expdydxavgwts
668
-
669
- * Callable: A function that takes `hi`, `lo`, `eps`, `x`, `y`, and `w` as arguments and returns a numeric array. This allows computing custom comparisons like `lambda hi, lo, eps, x, y, w: hi / lo` for ratios or `lambda hi, lo, eps, x, y, w: (hi - lo) / lo * 100` for percent changes.
670
- """
671
- + DocsParameters.docstring_by
672
- + DocsParameters.docstring_transform
673
- + DocsParameters.docstring_hypothesis
674
- + DocsParameters.docstring_wts
675
- + DocsParameters.docstring_vcov
676
- + DocsParameters.docstring_equivalence
677
- + DocsParameters.docstring_cross
678
- + DocsParameters.docstring_conf_level
679
- + DocsParameters.docstring_eps
680
- + DocsParameters.docstring_eps_vcov
681
- + docstring_returns
682
- + """
683
- ## Examples
684
- ```py
685
- from marginaleffects import *
686
- import numpy as np
687
-
688
- import statsmodels.api as sm
689
- import statsmodels.formula.api as smf
690
- data = get_dataset("thornton")
691
- model = smf.ols("outcome ~ distance + incentive", data=data).fit()
692
-
693
- # Basic comparisons
694
- comparisons(model)
695
-
696
- avg_comparisons(model)
697
-
698
- comparisons(model, hypothesis=0)
699
-
700
- avg_comparisons(model, hypothesis=0)
701
-
702
- comparisons(model, by="agecat")
703
-
704
- avg_comparisons(model, by="agecat")
705
-
706
- # Custom comparisons with lambda functions
707
- # Ratio comparison using lambda
708
- comparisons(model, variables="distance",
709
- comparison=lambda hi, lo, eps, x, y, w: hi / lo)
710
-
711
- # Percent change using lambda
712
- comparisons(model, variables="distance",
713
- comparison=lambda hi, lo, eps, x, y, w: (hi - lo) / lo * 100)
714
-
715
- # Log ratio using lambda
716
- comparisons(model, variables="distance",
717
- comparison=lambda hi, lo, eps, x, y, w: np.log(hi / lo))
718
- ```
719
-
720
- ## Details
721
- """
722
- + DocsDetails.docstring_tost
723
- + DocsDetails.docstring_order_of_operations
724
- + "" # add comparisons argument functions section as in R at https://marginaleffects.com/man/r/comparisons.html
725
- )
726
-
727
-
728
- comparisons.__doc__ = docs_comparisons
729
-
730
724
  avg_comparisons.__doc__ = comparisons.__doc__
@@ -25,11 +25,58 @@ def datagrid(
25
25
  **kwargs,
26
26
  ):
27
27
  """
28
+ # `datagrid()`
29
+
28
30
  Generate a data grid of user-specified values for use in the 'newdata' argument of the 'predictions()', 'comparisons()', and 'slopes()' functions.
29
31
 
30
- For more information, visit the website: https://marginaleffects.com/
32
+ This is useful to define where in the predictor space we want to evaluate the quantities of interest. Ex: the predicted outcome or slope for a 37 year old college graduate.
33
+
34
+ ## Parameters
35
+ * model: (object, optional)
36
+ Model object.
37
+ * (one and only one of the `model` and `newdata` arguments can be used.)
38
+ * newdata: (DataFrame, optional)
39
+ Data frame used to define the predictor space.
40
+ * (one and only one of the `model` and `newdata` arguments can be used.)
41
+ * grid_type: (str, optional)
42
+ Determines the functions to apply to each variable. The defaults can be overridden by defining individual variables explicitly in the `**kwargs`, or by supplying a function to one of the `FUN_*` arguments.
43
+ * "mean_or_mode": Character, factor, logical, and binary variables are set to their modes. Numeric, integer, and other variables are set to their means.
44
+ * "balanced": Each unique level of character, factor, logical, and binary variables are preserved. Numeric, integer, and other variables are set to their means. Warning: When there are many variables and many levels per variable, a balanced grid can be very large. In those cases, it is better to use `grid_type="mean_or_mode"` and to specify the unique levels of a subset of named variables explicitly.
45
+ * "counterfactual": the entire dataset is duplicated for each combination of the variable values specified in `**kwargs`. Variables not explicitly supplied to `datagrid()` are set to their observed values in the original dataset.
46
+ * FUN_numeric: (Callable, optional)
47
+ The function to be applied to numeric variables.
48
+ * FUN_other: (Callable, optional)
49
+ The function to be applied to other variable types.
50
+ * **kwargs
51
+ * Named arguments where the name is the variable name and the value is a list of values to use in the grid. If a variable is not specified, it is set to its mean or mode depending on the `grid_type` argument.
52
+
53
+ ## Returns
54
+ (polars.DataFrame)
55
+ * DataFrame where each row corresponds to one combination of the named predictors supplied by the user. Variables which are not explicitly defined are held at their mean or mode.
31
56
 
32
- Or type: `help(datagrid)`
57
+ ## Examples
58
+ ```py
59
+ import polars as pl
60
+ import statsmodels.formula.api as smf
61
+ from marginaleffects import *
62
+ data = get_dataset("thornton")
63
+
64
+ # The output only has 2 rows, and all the variables except `hp` are at their mean or mode.
65
+ datagrid(newdata = data, village = [43, 11])
66
+
67
+ # We get the same result by feeding a model instead of a DataFrame
68
+ mod = smf.ols("outcome ~ incentive + distance", data).fit()
69
+ datagrid(model = mod, village = [43, 11])
70
+
71
+ # Use in `marginaleffects` to compute "Typical Marginal Effects". When used in `slopes()` or `predictions()` we do not need to specify the `model` or `newdata` arguments.
72
+ nd = datagrid(mod, village = [43, 11])
73
+ slopes(mod, newdata = nd)
74
+
75
+ # The full dataset is duplicated with each observation given counterfactual values of 43 and 11 for the `village` variable.
76
+ # The original `thornton` includes 2884 rows, so the resulting dataset includes 5768 rows.
77
+ dg = datagrid(newdata = data, village = [43, 11], grid_type = "counterfactual")
78
+ dg.shape
79
+ ```
33
80
  """
34
81
 
35
82
  # allow predictions() to pass `model` argument automatically
@@ -495,59 +542,3 @@ def _datagridcf(model=None, newdata=None, by=None, **kwargs):
495
542
  result.datagrid_explicit = list(kwargs.keys())
496
543
 
497
544
  return result
498
-
499
-
500
- datagrid.__doc__ = """
501
- # `datagrid()`
502
-
503
- Generate a data grid of user-specified values for use in the 'newdata' argument of the 'predictions()', 'comparisons()', and 'slopes()' functions.
504
-
505
- This is useful to define where in the predictor space we want to evaluate the quantities of interest. Ex: the predicted outcome or slope for a 37 year old college graduate.
506
-
507
- ## Parameters
508
- * model: (object, optional)
509
- Model object.
510
- * (one and only one of the `model` and `newdata` arguments can be used.)
511
- * newdata: (DataFrame, optional)
512
- Data frame used to define the predictor space.
513
- * (one and only one of the `model` and `newdata` arguments can be used.)
514
- * grid_type: (str, optional)
515
- Determines the functions to apply to each variable. The defaults can be overridden by defining individual variables explicitly in the `**kwargs`, or by supplying a function to one of the `FUN_*` arguments.
516
- * "mean_or_mode": Character, factor, logical, and binary variables are set to their modes. Numeric, integer, and other variables are set to their means.
517
- * "balanced": Each unique level of character, factor, logical, and binary variables are preserved. Numeric, integer, and other variables are set to their means. Warning: When there are many variables and many levels per variable, a balanced grid can be very large. In those cases, it is better to use `grid_type="mean_or_mode"` and to specify the unique levels of a subset of named variables explicitly.
518
- * "counterfactual": the entire dataset is duplicated for each combination of the variable values specified in `**kwargs`. Variables not explicitly supplied to `datagrid()` are set to their observed values in the original dataset.
519
- * FUN_numeric: (Callable, optional)
520
- The function to be applied to numeric variables.
521
- * FUN_other: (Callable, optional)
522
- The function to be applied to other variable types.
523
- * **kwargs
524
- * Named arguments where the name is the variable name and the value is a list of values to use in the grid. If a variable is not specified, it is set to its mean or mode depending on the `grid_type` argument.
525
-
526
- ## Returns
527
- (polars.DataFrame)
528
- * DataFrame where each row corresponds to one combination of the named predictors supplied by the user. Variables which are not explicitly defined are held at their mean or mode.
529
-
530
- ## Examples
531
- ```py
532
- import polars as pl
533
- import statsmodels.formula.api as smf
534
- from marginaleffects import *
535
- data = get_dataset("thornton")
536
-
537
- # The output only has 2 rows, and all the variables except `hp` are at their mean or mode.
538
- datagrid(newdata = data, village = [43, 11])
539
-
540
- # We get the same result by feeding a model instead of a DataFrame
541
- mod = smf.ols("outcome ~ incentive + distance", data).fit()
542
- datagrid(model = mod, village = [43, 11])
543
-
544
- # Use in `marginaleffects` to compute "Typical Marginal Effects". When used in `slopes()` or `predictions()` we do not need to specify the `model` or `newdata` arguments.
545
- nd = datagrid(mod, village = [43, 11])
546
- slopes(mod, newdata = nd)
547
-
548
- # The full dataset is duplicated with each observation given counterfactual values of 43 and 11 for the `village` variable.
549
- # The original `thornton` includes 2884 rows, so the resulting dataset includes 5768 rows.
550
- dg = datagrid(newdata = data, village = [43, 11], grid_type = "counterfactual")
551
- dg.shape
552
- ```
553
- """