marginaleffects 0.3.2__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/PKG-INFO +2 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/__init__.py +3 -3
- marginaleffects-0.3.2/marginaleffects/jax_dispatch.py → marginaleffects-0.5.0/marginaleffects/autodiff/dispatch.py +6 -6
- marginaleffects-0.5.0/marginaleffects/classes/__init__.py +18 -0
- marginaleffects-0.5.0/marginaleffects/classes/model.py +136 -0
- {marginaleffects-0.3.2/marginaleffects → marginaleffects-0.5.0/marginaleffects/classes}/result.py +12 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/comparisons.py +347 -343
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/datagrid.py +186 -301
- marginaleffects-0.5.0/marginaleffects/datasets.py +137 -0
- marginaleffects-0.5.0/marginaleffects/docstrings/__init__.py +11 -0
- marginaleffects-0.5.0/marginaleffects/docstrings/params.py +562 -0
- marginaleffects-0.5.0/marginaleffects/docstrings/qmd.py +40 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/estimands.py +5 -6
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/linearmodels/model.py +51 -79
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/plot/common.py +36 -1
- marginaleffects-0.5.0/marginaleffects/plot/comparisons.py +91 -0
- marginaleffects-0.5.0/marginaleffects/plot/predictions.py +109 -0
- marginaleffects-0.5.0/marginaleffects/plot/slopes.py +87 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/predictions.py +77 -141
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/pyfixest/model.py +33 -9
- marginaleffects-0.5.0/marginaleffects/sanitize/__init__.py +27 -0
- marginaleffects-0.5.0/marginaleffects/sanitize/by.py +15 -0
- marginaleffects-0.5.0/marginaleffects/sanitize/categorical.py +175 -0
- marginaleffects-0.5.0/marginaleffects/sanitize/comparison.py +80 -0
- marginaleffects-0.5.0/marginaleffects/sanitize/deprecated.py +36 -0
- marginaleffects-0.5.0/marginaleffects/sanitize/hypothesis_null.py +35 -0
- marginaleffects-0.5.0/marginaleffects/sanitize/newdata.py +117 -0
- {marginaleffects-0.3.2/marginaleffects → marginaleffects-0.5.0/marginaleffects/sanitize}/sanitize_model.py +5 -6
- marginaleffects-0.5.0/marginaleffects/sanitize/utils.py +153 -0
- {marginaleffects-0.3.2/marginaleffects → marginaleffects-0.5.0/marginaleffects/sanitize}/validation.py +7 -7
- marginaleffects-0.5.0/marginaleffects/sanitize/variables.py +286 -0
- marginaleffects-0.5.0/marginaleffects/sanitize/vcov.py +18 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/settings.py +11 -20
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/sklearn/model.py +61 -88
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/slopes.py +86 -103
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/statsmodels/model.py +97 -81
- marginaleffects-0.5.0/marginaleffects/test/__init__.py +22 -0
- marginaleffects-0.3.2/marginaleffects/hypothesis.py → marginaleffects-0.5.0/marginaleffects/test/core.py +17 -3
- marginaleffects-0.3.2/marginaleffects/hypothesis_formula.py → marginaleffects-0.5.0/marginaleffects/test/formula.py +3 -3
- marginaleffects-0.3.2/marginaleffects/hypotheses_joint.py → marginaleffects-0.5.0/marginaleffects/test/joint.py +2 -2
- marginaleffects-0.3.2/marginaleffects/hypotheses.py → marginaleffects-0.5.0/marginaleffects/test/main.py +75 -84
- marginaleffects-0.5.0/marginaleffects/transform.py +33 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/uncertainty.py +15 -1
- marginaleffects-0.5.0/marginaleffects/utils.py +358 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects.egg-info/PKG-INFO +2 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects.egg-info/SOURCES.txt +29 -16
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects.egg-info/requires.txt +1 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects.egg-info/top_level.txt +0 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/pyproject.toml +4 -3
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/helpers.py +1 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_autodiff.py +8 -4
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_comparisons.py +38 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_comparisons_interaction.py +1 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_datagrid_02.py +13 -32
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_formulaic_utils.py +1 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_hypothesis.py +1 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_jss.py +14 -4
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_plot_comparisons.py +17 -5
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_plot_predictions.py +5 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_plot_slopes.py +5 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_pyfixest.py +32 -1
- marginaleffects-0.5.0/tests/test_statsmodels_mnlogit.py +118 -0
- marginaleffects-0.5.0/tests/test_statsmodels_ordinal.py +110 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels_quantreg.py +36 -7
- marginaleffects-0.5.0/tests/test_statsmodels_vcov.py +103 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_utils.py +2 -1
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/utilities.py +28 -18
- marginaleffects-0.3.2/marginaleffects/_input_utils.py +0 -65
- marginaleffects-0.3.2/marginaleffects/classes.py +0 -64
- marginaleffects-0.3.2/marginaleffects/docs.py +0 -341
- marginaleffects-0.3.2/marginaleffects/inject_docs.py +0 -139
- marginaleffects-0.3.2/marginaleffects/model_abstract.py +0 -89
- marginaleffects-0.3.2/marginaleffects/plot/comparisons.py +0 -120
- marginaleffects-0.3.2/marginaleffects/plot/predictions.py +0 -134
- marginaleffects-0.3.2/marginaleffects/plot/slopes.py +0 -128
- marginaleffects-0.3.2/marginaleffects/sanity.py +0 -757
- marginaleffects-0.3.2/marginaleffects/transform.py +0 -16
- marginaleffects-0.3.2/marginaleffects/utils.py +0 -579
- marginaleffects-0.3.2/tests/test_statsmodels_mnlogit.py +0 -100
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/README.md +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/benchmarks/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/benchmarks/benchmark_autodiff.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/comparisons.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/glm/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/glm/comparisons.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/glm/families.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/glm/predictions.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/linear/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/linear/comparisons.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/linear/predictions.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/autodiff/utils.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/by.py +0 -0
- /marginaleffects-0.3.2/marginaleffects/formulaic_utils.py → /marginaleffects-0.5.0/marginaleffects/formula.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/linearmodels/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/plot/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/pyfixest/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/sklearn/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects/statsmodels/__init__.py +0 -0
- {marginaleffects-0.3.2/marginaleffects → marginaleffects-0.5.0/marginaleffects/test}/equivalence.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/marginaleffects.egg-info/dependency_links.txt +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/setup.cfg +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/__init__.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_analytic.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_bugfix.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_by.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_categorical.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_categorical_validation.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_datagrid_01.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_equivalence.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_formula.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_hypotheses.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_hypotheses_joint.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_linearmodels_panelols.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_missing.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_newdata.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_predictions.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_sklearn.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_slopes.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels_logit.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels_mixedlm.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels_negativebinomial.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels_ols.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels_poisson.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels_probit.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_statsmodels_wls.py +0 -0
- {marginaleffects-0.3.2 → marginaleffects-0.5.0}/tests/test_typical.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: marginaleffects
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models.
|
|
5
5
|
License-Expression: GPL-3.0-or-later
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -22,6 +22,7 @@ Requires-Dist: matplotlib>=3.7.1; extra == "test"
|
|
|
22
22
|
Requires-Dist: linearmodels>=6.1; extra == "test"
|
|
23
23
|
Requires-Dist: pandas>=2.2.2; extra == "test"
|
|
24
24
|
Requires-Dist: pre-commit>=4.2.0; extra == "test"
|
|
25
|
+
Requires-Dist: numba; extra == "test"
|
|
25
26
|
Requires-Dist: pyarrow>=17.0.0; extra == "test"
|
|
26
27
|
Requires-Dist: pyfixest>=0.28.0; extra == "test"
|
|
27
28
|
Requires-Dist: statsmodels>=0.14.0; extra == "test"
|
|
@@ -7,7 +7,7 @@ _EXPORTS = {
|
|
|
7
7
|
"avg_comparisons": ("marginaleffects.comparisons", "avg_comparisons"),
|
|
8
8
|
"comparisons": ("marginaleffects.comparisons", "comparisons"),
|
|
9
9
|
"datagrid": ("marginaleffects.datagrid", "datagrid"),
|
|
10
|
-
"hypotheses": ("marginaleffects.
|
|
10
|
+
"hypotheses": ("marginaleffects.test", "hypotheses"),
|
|
11
11
|
"plot_comparisons": ("marginaleffects.plot.comparisons", "plot_comparisons"),
|
|
12
12
|
"plot_predictions": ("marginaleffects.plot.predictions", "plot_predictions"),
|
|
13
13
|
"plot_slopes": ("marginaleffects.plot.slopes", "plot_slopes"),
|
|
@@ -18,8 +18,8 @@ _EXPORTS = {
|
|
|
18
18
|
"fit_statsmodels": ("marginaleffects.statsmodels.model", "fit_statsmodels"),
|
|
19
19
|
"fit_sklearn": ("marginaleffects.sklearn.model", "fit_sklearn"),
|
|
20
20
|
"fit_linearmodels": ("marginaleffects.linearmodels.model", "fit_linearmodels"),
|
|
21
|
-
"get_dataset": ("marginaleffects.
|
|
22
|
-
"MarginaleffectsResult": ("marginaleffects.
|
|
21
|
+
"get_dataset": ("marginaleffects.datasets", "get_dataset"),
|
|
22
|
+
"MarginaleffectsResult": ("marginaleffects.classes", "MarginaleffectsResult"),
|
|
23
23
|
"autodiff": ("marginaleffects.settings", "autodiff"),
|
|
24
24
|
"set_autodiff": ("marginaleffects.settings", "set_autodiff"),
|
|
25
25
|
"get_autodiff": ("marginaleffects.settings", "get_autodiff"),
|
|
@@ -46,7 +46,7 @@ def try_jax_predictions(
|
|
|
46
46
|
- by is a list/complex aggregation (only False and True supported initially)
|
|
47
47
|
- vcov is None (no SEs needed anyway)
|
|
48
48
|
"""
|
|
49
|
-
from
|
|
49
|
+
from ..settings import is_autodiff_enabled
|
|
50
50
|
|
|
51
51
|
# Check global setting first
|
|
52
52
|
if not is_autodiff_enabled():
|
|
@@ -83,7 +83,7 @@ def try_jax_predictions(
|
|
|
83
83
|
|
|
84
84
|
# Select appropriate autodiff module based on model type
|
|
85
85
|
if config["model_type"] == "linear":
|
|
86
|
-
from .
|
|
86
|
+
from . import linear as ad_module
|
|
87
87
|
|
|
88
88
|
result = ad_module.predictions.predictions(
|
|
89
89
|
beta=beta,
|
|
@@ -92,7 +92,7 @@ def try_jax_predictions(
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
elif config["model_type"] == "glm":
|
|
95
|
-
from .
|
|
95
|
+
from . import glm as ad_module
|
|
96
96
|
|
|
97
97
|
result = ad_module.predictions.predictions(
|
|
98
98
|
beta=beta,
|
|
@@ -147,7 +147,7 @@ def try_jax_comparisons(
|
|
|
147
147
|
- comparison is not a supported string type
|
|
148
148
|
- cross is True (cross comparisons not supported)
|
|
149
149
|
"""
|
|
150
|
-
from
|
|
150
|
+
from ..settings import is_autodiff_enabled
|
|
151
151
|
|
|
152
152
|
# Check global setting first
|
|
153
153
|
if not is_autodiff_enabled():
|
|
@@ -207,9 +207,9 @@ def try_jax_comparisons(
|
|
|
207
207
|
|
|
208
208
|
# Select appropriate autodiff module based on model type
|
|
209
209
|
if config["model_type"] == "linear":
|
|
210
|
-
from .
|
|
210
|
+
from . import linear as ad_module
|
|
211
211
|
elif config["model_type"] == "glm":
|
|
212
|
-
from .
|
|
212
|
+
from . import glm as ad_module
|
|
213
213
|
else:
|
|
214
214
|
return None
|
|
215
215
|
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .result import MarginaleffectsResult, MarginaleffectsDataFrame
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"MarginaleffectsResult",
|
|
5
|
+
"MarginaleffectsDataFrame",
|
|
6
|
+
"ModelAbstract",
|
|
7
|
+
"ModelVault",
|
|
8
|
+
]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def __getattr__(name):
|
|
12
|
+
if name in ("ModelAbstract", "ModelVault"):
|
|
13
|
+
from .model import ModelAbstract, ModelVault
|
|
14
|
+
|
|
15
|
+
globals()["ModelAbstract"] = ModelAbstract
|
|
16
|
+
globals()["ModelVault"] = ModelVault
|
|
17
|
+
return ModelAbstract if name == "ModelAbstract" else ModelVault
|
|
18
|
+
raise AttributeError(f"module 'marginaleffects.classes' has no attribute {name!r}")
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
from ..sanitize.validation import ModelValidation
|
|
9
|
+
from .. import formula as fml
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
|
|
13
|
+
class ModelVault:
|
|
14
|
+
"""Typed container for model metadata shared across all adapters."""
|
|
15
|
+
|
|
16
|
+
coef: Optional[np.ndarray] = None
|
|
17
|
+
coefnames: Optional[np.ndarray] = None
|
|
18
|
+
formula: Optional[str] = None
|
|
19
|
+
formula_engine: str = "formulaic"
|
|
20
|
+
modeldata: Optional[pl.DataFrame] = None
|
|
21
|
+
package: str = "unknown"
|
|
22
|
+
vcov: Optional[np.ndarray] = None
|
|
23
|
+
variables_type: Dict[str, str] = field(default_factory=dict)
|
|
24
|
+
variable_names: Optional[List[str]] = None
|
|
25
|
+
engine_running: Optional[Any] = None
|
|
26
|
+
# statsmodels-specific
|
|
27
|
+
design_info_patsy: Optional[Any] = None
|
|
28
|
+
pandas_categorical_orders: Dict[str, list] = field(default_factory=dict)
|
|
29
|
+
# sklearn-specific
|
|
30
|
+
model_spec: Optional[Any] = None
|
|
31
|
+
original_columns: Optional[List[str]] = None
|
|
32
|
+
# linearmodels-specific
|
|
33
|
+
multiindex: Optional[List[str]] = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ModelAbstract(ModelValidation, ABC):
|
|
37
|
+
def __init__(self, model, vault: ModelVault):
|
|
38
|
+
self.model = model
|
|
39
|
+
self.vault = vault
|
|
40
|
+
self.validation()
|
|
41
|
+
|
|
42
|
+
def get_modeldata(self) -> Optional[pl.DataFrame]:
|
|
43
|
+
return self.vault.modeldata
|
|
44
|
+
|
|
45
|
+
def get_vcov(self, vcov=False) -> Optional[np.ndarray]:
|
|
46
|
+
return self.vault.vcov
|
|
47
|
+
|
|
48
|
+
def get_coef(self) -> Optional[np.ndarray]:
|
|
49
|
+
return self.vault.coef
|
|
50
|
+
|
|
51
|
+
def get_coefnames(self) -> Optional[np.ndarray]:
|
|
52
|
+
return self.vault.coefnames
|
|
53
|
+
|
|
54
|
+
def get_engine_running(self) -> Optional[Any]:
|
|
55
|
+
return self.vault.engine_running
|
|
56
|
+
|
|
57
|
+
def get_formula(self) -> Optional[str]:
|
|
58
|
+
return self.vault.formula
|
|
59
|
+
|
|
60
|
+
def get_formula_engine(self) -> str:
|
|
61
|
+
return self.vault.formula_engine
|
|
62
|
+
|
|
63
|
+
def get_package(self) -> str:
|
|
64
|
+
return self.vault.package
|
|
65
|
+
|
|
66
|
+
def get_variable_type(self, name=None) -> Dict[str, str]:
|
|
67
|
+
variables = self.vault.variables_type
|
|
68
|
+
if isinstance(name, str) and name in variables:
|
|
69
|
+
return variables[name]
|
|
70
|
+
else:
|
|
71
|
+
return variables
|
|
72
|
+
|
|
73
|
+
def set_variable_type(self, name, value):
|
|
74
|
+
self.vault.variables_type[name] = value
|
|
75
|
+
|
|
76
|
+
def find_variables(self) -> Optional[List[str]]:
|
|
77
|
+
if self.vault.variable_names is not None:
|
|
78
|
+
return self.vault.variable_names
|
|
79
|
+
|
|
80
|
+
formula = self.get_formula()
|
|
81
|
+
if isinstance(formula, str):
|
|
82
|
+
out = fml.parse_variables(self.get_formula())
|
|
83
|
+
else:
|
|
84
|
+
out = None
|
|
85
|
+
|
|
86
|
+
self.vault.variable_names = out
|
|
87
|
+
|
|
88
|
+
return out
|
|
89
|
+
|
|
90
|
+
def find_response(self) -> Optional[str]:
|
|
91
|
+
vars = self.find_variables()
|
|
92
|
+
if vars is None:
|
|
93
|
+
return None
|
|
94
|
+
else:
|
|
95
|
+
return vars[0]
|
|
96
|
+
|
|
97
|
+
def find_predictors(self) -> Optional[List[str]]:
|
|
98
|
+
vars = self.find_variables()
|
|
99
|
+
if vars is None:
|
|
100
|
+
return None
|
|
101
|
+
else:
|
|
102
|
+
return vars[1:]
|
|
103
|
+
|
|
104
|
+
def get_exog(self, newdata: pl.DataFrame):
|
|
105
|
+
"""Convert newdata into the design matrix format expected by get_predict.
|
|
106
|
+
|
|
107
|
+
Subclasses may override this to handle model-specific formula engines.
|
|
108
|
+
The default implementation uses the model's formula to build design matrices.
|
|
109
|
+
"""
|
|
110
|
+
from ..formula import model_matrices
|
|
111
|
+
|
|
112
|
+
if self.vault.design_info_patsy is not None:
|
|
113
|
+
f = self.vault.design_info_patsy
|
|
114
|
+
else:
|
|
115
|
+
f = self.get_formula()
|
|
116
|
+
|
|
117
|
+
if callable(f):
|
|
118
|
+
_, exog = f(newdata)
|
|
119
|
+
else:
|
|
120
|
+
_, exog = model_matrices(
|
|
121
|
+
f, newdata, formula_engine=self.get_formula_engine()
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
return exog
|
|
125
|
+
|
|
126
|
+
@abstractmethod
|
|
127
|
+
def get_predict(self, params: np.ndarray, newdata) -> pl.DataFrame:
|
|
128
|
+
pass
|
|
129
|
+
|
|
130
|
+
def __getattr__(self, name: str) -> Any:
|
|
131
|
+
"""Forward attribute access to the underlying fitted model."""
|
|
132
|
+
try:
|
|
133
|
+
return object.__getattribute__(self, name)
|
|
134
|
+
except AttributeError:
|
|
135
|
+
# Forward to the wrapped model
|
|
136
|
+
return getattr(self.model, name)
|
{marginaleffects-0.3.2/marginaleffects → marginaleffects-0.5.0/marginaleffects/classes}/result.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from typing import Any, Dict, Iterable, Iterator, Optional
|
|
5
6
|
|
|
@@ -218,3 +219,14 @@ class MarginaleffectsResult:
|
|
|
218
219
|
# Helpful explicit method for conversions
|
|
219
220
|
def to_polars(self) -> pl.DataFrame:
|
|
220
221
|
return self._data
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
# Backwards compatibility alias
|
|
225
|
+
class MarginaleffectsDataFrame(MarginaleffectsResult):
|
|
226
|
+
def __init__(self, *args, **kwargs):
|
|
227
|
+
warnings.warn(
|
|
228
|
+
"MarginaleffectsDataFrame is deprecated; use MarginaleffectsResult instead.",
|
|
229
|
+
DeprecationWarning,
|
|
230
|
+
stacklevel=2,
|
|
231
|
+
)
|
|
232
|
+
super().__init__(*args, **kwargs)
|