google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
meridian/model/eda/eda_outcome.py
ADDED

@@ -0,0 +1,200 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Meridian EDA Outcome."""
+
+import dataclasses
+import enum
+import typing
+import pandas as pd
+import xarray as xr
+
+__all__ = [
+    "EDASeverity",
+    "EDAFinding",
+    "AnalysisLevel",
+    "AnalysisArtifact",
+    "PairwiseCorrArtifact",
+    "StandardDeviationArtifact",
+    "VIFArtifact",
+    "KpiInvariabilityArtifact",
+    "EDACheckType",
+    "ArtifactType",
+    "EDAOutcome",
+]
+
+
+@enum.unique
+class EDASeverity(enum.Enum):
+  """Enumeration for the severity of an EDA check's finding."""
+
+  # For the non-critical findings.
+  INFO = enum.auto()
+  # For the non-critical findings that require user attention.
+  ATTENTION = enum.auto()
+  # For unacceptable, model-blocking data errors.
+  ERROR = enum.auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class EDAFinding:
+  """Encapsulates a single, specific finding from an EDA check.
+
+  Attributes:
+    severity: The severity level of the finding.
+    explanation: A human-readable description about the EDA check and a
+      potential actionable guidance on how to address or interpret this
+      specific finding.
+  """
+
+  severity: EDASeverity
+  explanation: str
+
+
+@enum.unique
+class AnalysisLevel(enum.Enum):
+  """Enumeration for the level of an analysis.
+
+  Attributes:
+    OVERALL: Computed across all geos and time. When the analysis is performed
+      on national data, this level is equivalent to the NATIONAL level.
+    NATIONAL: Computed across time for data aggregated to the national level.
+      When the analysis is performed on national data, this level is equivalent
+      to the OVERALL level.
+    GEO: Computed across time, for each geo.
+  """
+
+  OVERALL = enum.auto()
+  NATIONAL = enum.auto()
+  GEO = enum.auto()
+
+
+@dataclasses.dataclass(frozen=True)
+class AnalysisArtifact:
+  """Base dataclass for analysis artifacts.
+
+  Specific EDA artifacts should inherit from this class to store check-specific
+  data for downstream processing (e.g., plotting).
+
+  Attributes:
+    level: The level of the analysis.
+  """
+
+  level: AnalysisLevel
+
+
+@dataclasses.dataclass(frozen=True)
+class PairwiseCorrArtifact(AnalysisArtifact):
+  """Encapsulates artifacts from a single pairwise correlation analysis.
+
+  Attributes:
+    corr_matrix: Pairwise correlation matrix.
+    extreme_corr_var_pairs: DataFrame of variable pairs exceeding the
+      correlation threshold.
+    extreme_corr_threshold: The threshold used to identify extreme correlation
+      pairs.
+  """
+
+  corr_matrix: xr.DataArray
+  extreme_corr_var_pairs: pd.DataFrame
+  extreme_corr_threshold: float
+
+
+@dataclasses.dataclass(frozen=True)
+class StandardDeviationArtifact(AnalysisArtifact):
+  """Encapsulates artifacts from a standard deviation analysis.
+
+  Attributes:
+    variable: The variable for which standard deviation is calculated.
+    std_ds: Dataset with stdev_with_outliers and stdev_without_outliers.
+    outlier_df: DataFrame with outliers.
+  """
+
+  variable: str
+  std_ds: xr.Dataset
+  outlier_df: pd.DataFrame
+
+
+@dataclasses.dataclass(frozen=True)
+class VIFArtifact(AnalysisArtifact):
+  """Encapsulates artifacts from a single VIF analysis.
+
+  Attributes:
+    vif_da: DataArray with VIF values.
+    outlier_df: DataFrame with extreme VIF values.
+  """
+
+  vif_da: xr.DataArray
+  outlier_df: pd.DataFrame
+
+
+@dataclasses.dataclass(frozen=True)
+class KpiInvariabilityArtifact(AnalysisArtifact):
+  """Encapsulates artifacts from a KPI invariability analysis.
+
+  Attributes:
+    kpi_da: DataArray of the KPI that is examined for variability.
+    kpi_stdev: The standard deviation of the KPI, which is used to test the KPI
+      invariability.
+  """
+
+  kpi_da: xr.DataArray
+  kpi_stdev: xr.DataArray
+
+
+@enum.unique
+class EDACheckType(enum.Enum):
+  """Enumeration for the type of an EDA check."""
+
+  PAIRWISE_CORRELATION = enum.auto()
+  STANDARD_DEVIATION = enum.auto()
+  MULTICOLLINEARITY = enum.auto()
+  KPI_INVARIABILITY = enum.auto()
+
+
+ArtifactType = typing.TypeVar("ArtifactType", bound="AnalysisArtifact")
+
+
+@dataclasses.dataclass(frozen=True)
+class EDAOutcome(typing.Generic[ArtifactType]):
+  """A dataclass for the outcomes of a single EDA check function.
+
+  An EDA check function can discover multiple issues. This object groups all of
+  those individual issues, reported as a list of `EDAFinding` objects.
+
+  Attributes:
+    check_type: The type of the EDA check that is being performed.
+    findings: A list of all individual issues discovered by the check.
+    analysis_artifacts: A list of analysis artifacts from the EDA check.
+  """
+
+  check_type: EDACheckType
+  findings: list[EDAFinding]
+  analysis_artifacts: list[ArtifactType]
+
+  @property
+  def get_geo_artifact(self) -> ArtifactType | None:
+    """Returns the geo-level analysis artifact."""
+    for artifact in self.analysis_artifacts:
+      if artifact.level == AnalysisLevel.GEO:
+        return artifact
+    return None
+
+  @property
+  def get_national_artifact(self) -> ArtifactType | None:
+    """Returns the national-level analysis artifact."""
+    for artifact in self.analysis_artifacts:
+      if artifact.level == AnalysisLevel.NATIONAL:
+        return artifact
+    return None
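The new `eda_outcome` module is pure data plumbing: each EDA check returns an `EDAOutcome` that bundles one or more `EDAFinding`s with level-tagged artifacts. A minimal sketch of how the pieces compose is below; it assumes the module is importable as `meridian.model.eda.eda_outcome` (matching the wheel layout listed above), and the DataFrame columns are illustrative, not the engine's actual output schema.

import pandas as pd
import xarray as xr

from meridian.model.eda import eda_outcome

# One finding plus the national-level artifact that backs it.
finding = eda_outcome.EDAFinding(
    severity=eda_outcome.EDASeverity.ATTENTION,
    explanation="Two variables exceed the 0.9 correlation threshold.",
)
artifact = eda_outcome.PairwiseCorrArtifact(
    level=eda_outcome.AnalysisLevel.NATIONAL,
    corr_matrix=xr.DataArray([[1.0, 0.95], [0.95, 1.0]]),
    extreme_corr_var_pairs=pd.DataFrame(
        {"var1": ["tv"], "var2": ["radio"], "corr": [0.95]}
    ),
    extreme_corr_threshold=0.9,
)
outcome = eda_outcome.EDAOutcome(
    check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
    findings=[finding],
    analysis_artifacts=[artifact],
)

# Note: `get_geo_artifact` and `get_national_artifact` are properties, not methods.
assert outcome.get_national_artifact is artifact
assert outcome.get_geo_artifact is None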
meridian/model/eda/eda_spec.py
ADDED

@@ -0,0 +1,84 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Meridian EDA Spec."""
+
+import dataclasses
+from typing import Any, Callable, Dict, TypeAlias
+
+__all__ = [
+    "AggregationConfig",
+    "VIFSpec",
+    "EDASpec",
+]
+
+AggregationFn: TypeAlias = Callable[..., Any]
+AggregationMap: TypeAlias = Dict[str, AggregationFn]
+_DEFAULT_VIF_THRESHOLD = 1000
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class AggregationConfig:
+  """A configuration for customizing variable aggregation functions.
+
+  The aggregation function can be called in the form `f(x, axis=axis, **kwargs)`
+  to return the result of reducing an `np.ndarray` over an integer valued axis.
+  It's recommended to explicitly define the aggregation functions instead of
+  using lambdas.
+
+  Attributes:
+    control_variables: A dictionary mapping control variable names to
+      aggregation functions. Defaults to `np.sum` if a variable is not
+      specified.
+    non_media_treatments: A dictionary mapping non-media variable names to
+      aggregation functions. Defaults to `np.sum` if a variable is not
+      specified.
+  """
+
+  control_variables: AggregationMap = dataclasses.field(default_factory=dict)
+  non_media_treatments: AggregationMap = dataclasses.field(default_factory=dict)
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class VIFSpec:
+  """A spec for the EDA VIF check.
+
+  Attributes:
+    geo_threshold: The threshold for geo-level VIF.
+    overall_threshold: The threshold for overall VIF.
+    national_threshold: The threshold for national VIF.
+  """
+
+  geo_threshold: float = _DEFAULT_VIF_THRESHOLD
+  overall_threshold: float = _DEFAULT_VIF_THRESHOLD
+  national_threshold: float = _DEFAULT_VIF_THRESHOLD
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class EDASpec:
+  """A container for all user-configurable EDA check specs.
+
+  This object allows users to customize the behavior of the EDA checks
+  by passing a single configuration object into the EDAEngine constructor,
+  avoiding a large number of arguments.
+
+  Attributes:
+    aggregation_config: A configuration object for custom aggregation functions.
+    vif_spec: A configuration object for the EDA VIF check.
+  """
+
+  aggregation_config: AggregationConfig = dataclasses.field(
+      default_factory=AggregationConfig
+  )
+  vif_spec: VIFSpec = dataclasses.field(default_factory=VIFSpec)
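`EDASpec` follows the same pattern as the model spec: one frozen, keyword-only container that, per its docstring, is passed into the `EDAEngine` constructor. A hedged sketch of customizing it; the control-variable name and threshold values are made up for illustration.

import numpy as np

from meridian.model.eda import eda_spec

spec = eda_spec.EDASpec(
    aggregation_config=eda_spec.AggregationConfig(
        # Index-like controls are typically averaged rather than summed
        # when aggregating; named functions are preferred over lambdas.
        control_variables={"price_index": np.mean},
    ),
    # Tighten the VIF thresholds below the package default of 1000.
    vif_spec=eda_spec.VIFSpec(geo_threshold=500.0, overall_threshold=500.0),
)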
meridian/model/eda/meridian_eda.py
ADDED

@@ -0,0 +1,220 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module containing Meridian related exploratory data analysis (EDA) functionalities."""
+from __future__ import annotations
+
+from typing import Literal, TYPE_CHECKING, Union
+
+import altair as alt
+from meridian import constants
+from meridian.model.eda import constants as eda_constants
+import pandas as pd
+
+if TYPE_CHECKING:
+  from meridian.model import model  # pylint: disable=g-bad-import-order,g-import-not-at-top
+
+__all__ = [
+    'MeridianEDA',
+]
+
+
+class MeridianEDA:
+  """Class for running pre-modeling exploratory data analysis for Meridian InputData."""
+
+  _PAIRWISE_CORR_COLOR_SCALE = alt.Scale(
+      domain=[-1.0, 0.0, 1.0],
+      range=['#1f78b4', '#f7f7f7', '#e34a33'],  # Blue-light grey-Orange
+      type='linear',
+  )
+
+  def __init__(
+      self,
+      meridian: model.Meridian,
+  ):
+    self._meridian = meridian
+
+  def generate_and_save_report(self, filename: str, filepath: str):
+    """Generates and saves the 2 page HTML report containing findings in EDA about given InputData.
+
+    Args:
+      filename: The filename for the generated HTML output.
+      filepath: The path to the directory where the file will be saved.
+    """
+    # TODO: Implement.
+    raise NotImplementedError()
+
+  def plot_pairwise_correlation(
+      self, geos: Union[int, list[str], Literal['nationalize']] = 1
+  ) -> alt.Chart:
+    """Plots the Pairwise Correlation data.
+
+    Args:
+      geos: Defines which geos to plot. - int: The number of top geos to plot,
+        ranked by population. - list[str]: A specific list of geo names to plot.
+        - 'nationalize': Aggregates all geos into a single national view.
+        Defaults to 1 (plotting the top geo). If the data is already at a
+        national level, this parameter is ignored and a national plot is
+        generated.
+
+    Returns:
+      Altair chart(s) of the Pairwise Correlation data.
+    """
+    geos_to_plot = self._validate_and_get_geos_to_plot(geos)
+    is_national = self._meridian.is_national
+    nationalize_geos = geos == 'nationalize'
+
+    if is_national or nationalize_geos:
+      pairwise_corr_artifact = (
+          self._meridian.eda_engine.check_national_pairwise_corr().get_national_artifact
+      )
+      if pairwise_corr_artifact is None:
+        raise ValueError('EDAOutcome does not have national artifact.')
+    else:
+      pairwise_corr_artifact = (
+          self._meridian.eda_engine.check_geo_pairwise_corr().get_geo_artifact
+      )
+      if pairwise_corr_artifact is None:
+        raise ValueError('EDAOutcome does not have geo artifact.')
+    pairwise_corr_data = pairwise_corr_artifact.corr_matrix.to_dataframe()
+
+    charts = []
+    for geo_to_plot in geos_to_plot:
+      title = (
+          'Pairwise correlations among all treatments and controls for'
+          f' {geo_to_plot}'
+      )
+
+      if not (is_national or nationalize_geos):
+        plot_data = (
+            pairwise_corr_data.xs(geo_to_plot, level=constants.GEO)
+            .rename_axis(
+                index=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2]
+            )
+            .reset_index()
+        )
+      else:
+        plot_data = pairwise_corr_data.rename_axis(
+            index=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2]
+        ).reset_index()
+      plot_data.columns = [
+          eda_constants.VARIABLE_1,
+          eda_constants.VARIABLE_2,
+          eda_constants.CORRELATION,
+      ]
+      unique_variables = plot_data[eda_constants.VARIABLE_1].unique()
+      variable_to_index = {name: i for i, name in enumerate(unique_variables)}
+
+      plot_data['idx1'] = plot_data[eda_constants.VARIABLE_1].map(
+          variable_to_index
+      )
+      plot_data['idx2'] = plot_data[eda_constants.VARIABLE_2].map(
+          variable_to_index
+      )
+      lower_triangle_data = plot_data[plot_data['idx2'] > plot_data['idx1']]
+
+      charts.append(
+          self._plot_2d_heatmap(lower_triangle_data, title, unique_variables)
+      )
+    final_chart = (
+        alt.vconcat(*charts)
+        .resolve_legend(color='independent')
+        .configure_axis(labelAngle=315)
+        .configure_title(anchor='start')
+        .configure_view(stroke=None)
+    )
+    return final_chart
+
+  def _plot_2d_heatmap(
+      self, data: pd.DataFrame, title: str, unique_variables: list[str]
+  ) -> alt.Chart:
+    """Plots a 2D heatmap."""
+    # Base chart with position encodings
+    base = (
+        alt.Chart(data)
+        .encode(
+            x=alt.X(
+                f'{eda_constants.VARIABLE_1}:N',
+                title=None,
+                sort=unique_variables,
+                scale=alt.Scale(domain=unique_variables),
+            ),
+            y=alt.Y(
+                f'{eda_constants.VARIABLE_2}:N',
+                title=None,
+                sort=unique_variables,
+                scale=alt.Scale(domain=unique_variables),
+            ),
+        )
+        .properties(title=title)
+    )
+
+    # Heatmap layer (rectangles)
+    heatmap = base.mark_rect().encode(
+        color=alt.Color(
+            f'{eda_constants.CORRELATION}:Q',
+            scale=self._PAIRWISE_CORR_COLOR_SCALE,
+            legend=alt.Legend(title=eda_constants.CORRELATION),
+        ),
+        tooltip=[
+            eda_constants.VARIABLE_1,
+            eda_constants.VARIABLE_2,
+            alt.Tooltip(f'{eda_constants.CORRELATION}:Q', format='.3f'),
+        ],
+    )
+
+    # Text annotation layer (values)
+    text = base.mark_text().encode(
+        text=alt.Text(f'{eda_constants.CORRELATION}:Q', format='.3f'),
+        color=alt.value('black'),
+    )
+
+    # Combine layers and apply final configurations
+    chart = (heatmap + text).properties(width=350, height=350)
+
+    return chart
+
+  def _generate_pairwise_correlation_report(self) -> str:
+    """Creates the HTML snippet for Pairwise Correlation report section."""
+    # TODO: Implement.
+    raise NotImplementedError()
+
+  def _validate_and_get_geos_to_plot(
+      self, geos: Union[int, list[str], Literal['nationalize']]
+  ) -> list[str]:
+    """Validates and returns the geos to plot."""
+    ## Validate
+    is_national = self._meridian.is_national
+    if is_national or geos == 'nationalize':
+      geos_to_plot = [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]
+    elif isinstance(geos, int):
+      if geos > len(self._meridian.input_data.geo) or geos <= 0:
+        raise ValueError(
+            'geos must be a positive integer less than or equal to the number'
+            ' of geos in the data.'
+        )
+      geos_to_plot = self._meridian.input_data.get_n_top_largest_geos(geos)
+    else:
+      geos_to_plot = geos
+
+    if (
+        not is_national and geos != 'nationalize'
+    ):  # if national then geos_to_plot will be ignored
+      for geo in geos_to_plot:
+        if geo not in self._meridian.input_data.geo:
+          raise ValueError(f'Geo {geo} does not exist in the data.')
+      if len(geos_to_plot) != len(set(geos_to_plot)):
+        raise ValueError('geos must not contain duplicate values.')
+
+    return geos_to_plot
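Putting the new pieces together, a plausible usage sketch for the plotting entry point; here `mmm` stands for a constructed `model.Meridian` instance (EDA is pre-modeling, per the class docstring), and the geo names are placeholders.

from meridian.model.eda import meridian_eda

eda = meridian_eda.MeridianEDA(mmm)

# Top 3 geos by population: one lower-triangle heatmap per geo, stacked vertically.
chart = eda.plot_pairwise_correlation(geos=3)

# Specific geos, or a single aggregated national view.
chart = eda.plot_pairwise_correlation(geos=["geo_1", "geo_2"])
chart = eda.plot_pairwise_correlation(geos="nationalize")

chart.save("pairwise_correlation.html")  # standard Altair export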
meridian/model/knots.py
CHANGED
@@ -250,12 +250,28 @@ class AKS:
   def __init__(self, data: input_data.InputData):
     self._data = data

-  def automatic_knot_selection(
+  def automatic_knot_selection(
+      self,
+      base_penalty: np.ndarray | None = None,
+      min_internal_knots: int = 1,
+      max_internal_knots: int | None = None,
+  ) -> AKSResult:
     """Calculates the optimal number of knots for Meridian model using Automatic knot selection with A-spline.

+    Args:
+      base_penalty: A vector of positive penalty values. The adaptive spline
+        regression is performed for every value of penalty.
+      min_internal_knots: The minimum number of internal knots. Defaults to 1.
+      max_internal_knots: The maximum number of internal knots. If None, this
+        value is calculated as the number of initial knots minus the total count
+        of all treatment and control variables. Otherwise, the user-provided
+        value will be used.
+
     Returns:
       Selected knots and the corresponding B-spline model.
     """
+    if base_penalty is None:
+      base_penalty = self._BASE_PENALTY
     n_times = len(self._data.time)
     n_geos = len(self._data.geo)

@@ -265,11 +281,12 @@ class AKS:
         np.repeat([range(n_times)], n_geos, axis=0), (n_geos * n_times,)
     )

-    knots
-
+    knots = self._calculate_initial_knots(x)
+    max_internal_knots = self._calculate_and_validate_max_internal_knots(
+        knots, min_internal_knots, max_internal_knots
     )
     geo_scaling_factor = 1 / np.sqrt(len(self._data.geo))
-    penalty = geo_scaling_factor *
+    penalty = geo_scaling_factor * base_penalty

     aspline = self.aspline(x=x, y=y, knots=knots, penalty=penalty)
     n_knots = np.array([len(x) for x in aspline[constants.KNOTS_SELECTED]])
@@ -288,7 +305,7 @@ class AKS:
   def _calculate_initial_knots(
       self,
       x: np.ndarray,
-  ) ->
+  ) -> np.ndarray:
     """Calculates initial knots based on unique x values.

     Args:
@@ -298,38 +315,7 @@ class AKS:
     Returns:
       A tuple containing:
       - The calculated knots.
-      - The minimum number of internal knots.
-      - The maximum number of internal knots.
     """
-    n_media = (
-        len(self._data.media_channel)
-        if self._data.media_channel is not None
-        else 0
-    )
-    n_rf = (
-        len(self._data.rf_channel) if self._data.rf_channel is not None else 0
-    )
-    n_organic_media = (
-        len(self._data.organic_media_channel)
-        if self._data.organic_media_channel is not None
-        else 0
-    )
-    n_organic_rf = (
-        len(self._data.organic_rf_channel)
-        if self._data.organic_rf_channel is not None
-        else 0
-    )
-    n_non_media = (
-        len(self._data.non_media_channel)
-        if self._data.non_media_channel is not None
-        else 0
-    )
-    n_controls = (
-        len(self._data.control_variable)
-        if self._data.control_variable is not None
-        else 0
-    )
-
     x_vals_unique = np.unique(x)
     min_x_data, max_x_data = x_vals_unique.min(), x_vals_unique.max()
     knots = x_vals_unique[
@@ -340,18 +326,39 @@ class AKS:
     # fewer degree of freedom than the total number of knots to function.
     # Dropping the final knot is a natural and practical choice because it
     # often has minimal impact on the overall model fit.
-
-
-
-
-
-
-
-
-
-
-
+    return knots[:-1]
+
+  def _calculate_and_validate_max_internal_knots(
+      self,
+      knots: np.ndarray,
+      min_internal_knots: int = 1,
+      max_internal_knots: int | None = None,
+  ) -> int:
+    """Calculates the max internal knots, and validates the range.
+
+    Args:
+      knots: Initial knots to calculate max internal knots from.
+      min_internal_knots: The minimum number of internal knots.
+      max_internal_knots: The maximum number of internal knots. If None, this
+        value will be calculated, otherwise will use user provided value.
+
+    Returns:
+      The maximum number of internal knots.
+    """
+    n_features = sum(
+        len(feature) if feature is not None else 0
+        for feature in [
+            self._data.media_channel,
+            self._data.rf_channel,
+            self._data.organic_media_channel,
+            self._data.organic_rf_channel,
+            self._data.non_media_channel,
+            self._data.control_variable,
+        ]
     )
+
+    if max_internal_knots is None:
+      max_internal_knots = len(knots) - n_features
     if min_internal_knots > len(knots):
       raise ValueError(
           'The minimum number of internal knots cannot be greater than the'
@@ -362,8 +369,7 @@ class AKS:
           'The maximum number of internal knots cannot be less than the minimum'
           ' number of internal knots.'
       )
-
-    return knots, min_internal_knots, max_internal_knots
+    return max_internal_knots

   def aspline(
       self,
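The AKS refactor keeps the selection logic but widens `automatic_knot_selection` so callers can supply their own penalty grid and bound the internal-knot search. A hedged sketch of calling it: `data` stands for an existing `input_data.InputData`, and the penalty grid is illustrative (the default remains the class-level `_BASE_PENALTY`).

import numpy as np

from meridian.model import knots

aks = knots.AKS(data)

# Defaults: the _BASE_PENALTY grid, and max_internal_knots falling back to
# len(initial knots) - (count of treatment + control variables).
result = aks.automatic_knot_selection()

# Explicitly bounded search with a custom penalty grid.
result = aks.automatic_knot_selection(
    base_penalty=np.logspace(-3.0, 3.0, num=13),
    min_internal_knots=2,
    max_internal_knots=20,
)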
meridian/model/media.py
CHANGED
@@ -63,8 +63,8 @@ class MediaTensors:
     media_spend: A tensor constructed from `InputData.media_spend`.
     media_transformer: A `MediaTransformer` to scale media tensors using the
       model's media data.
-    media_scaled: The media tensor
-      value.
+    media_scaled: The media tensor after pre-modeling transformations including
+      population scaling and scaling by the median non-zero value.
     prior_media_scaled_counterfactual: A tensor containing `media_scaled` values
       corresponding to the counterfactual scenario required for the prior
       calculation. For ROI priors, the counterfactual scenario is where media is
@@ -169,8 +169,9 @@ class OrganicMediaTensors:
     organic_media: A tensor constructed from `InputData.organic_media`.
     organic_media_transformer: A `MediaTransformer` to scale media tensors using
       the model's organic media data.
-    organic_media_scaled: The organic media tensor
-      by the
+    organic_media_scaled: The organic media tensor after pre-modeling
+      transformations including population scaling and scaling by the media
+      non-zero value.
     """

   organic_media: backend.Tensor | None = None
@@ -214,8 +215,8 @@ class RfTensors:
     rf_spend: A tensor constructed from `InputData.rf_spend`.
     reach_transformer: A `MediaTransformer` to scale RF tensors using the
       model's RF data.
-    reach_scaled: A reach tensor
-      value.
+    reach_scaled: A reach tensor after pre-modeling transformations including
+      population scaling and scaling by the median non-zero value.
     prior_reach_scaled_counterfactual: A tensor containing `reach_scaled` values
       corresponding to the counterfactual scenario required for the prior
       calculation. For ROI priors, the counterfactual scenario is where reach is
@@ -324,8 +325,9 @@ class OrganicRfTensors:
     organic_frequency: A tensor constructed from `InputData.organic_frequency`.
     organic_reach_transformer: A `MediaTransformer` to scale organic RF tensors
       using the model's organic RF data.
-    organic_reach_scaled: An organic reach tensor
-      by the median
+    organic_reach_scaled: An organic reach tensor after pre-modeling
+      transformations including population scaling and scaling by the median
+      non-zero value.
     """

   organic_reach: backend.Tensor | None = None