google-meridian 1.2.1-py3-none-any.whl → 1.3.1-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
meridian/model/eda/eda_engine.py
CHANGED
@@ -14,19 +14,111 @@
 
 """Meridian EDA Engine."""
 
+from __future__ import annotations
+
 import dataclasses
 import functools
-
+import typing
+from typing import Optional, Protocol, Sequence
+
+from meridian import backend
 from meridian import constants
-from meridian.model import model
 from meridian.model import transformers
+from meridian.model.eda import constants as eda_constants
+from meridian.model.eda import eda_outcome
+from meridian.model.eda import eda_spec
 import numpy as np
-import
+import pandas as pd
+import statsmodels.api as sm
+from statsmodels.stats import outliers_influence
 import xarray as xr
 
 
+if typing.TYPE_CHECKING:
+  from meridian.model import model  # pylint: disable=g-bad-import-order,g-import-not-at-top
+
+__all__ = ['EDAEngine', 'GeoLevelCheckOnNationalModelError']
+
 _DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
-
+_CORRELATION_COL_NAME = eda_constants.CORRELATION
+_STACK_VAR_COORD_NAME = eda_constants.VARIABLE
+_CORR_VAR1 = eda_constants.VARIABLE_1
+_CORR_VAR2 = eda_constants.VARIABLE_2
+_CORRELATION_MATRIX_NAME = 'correlation_matrix'
+_OVERALL_PAIRWISE_CORR_THRESHOLD = 0.999
+_GEO_PAIRWISE_CORR_THRESHOLD = 0.999
+_NATIONAL_PAIRWISE_CORR_THRESHOLD = 0.999
+_EMPTY_DF_FOR_EXTREME_CORR_PAIRS = pd.DataFrame(
+    columns=[_CORR_VAR1, _CORR_VAR2, _CORRELATION_COL_NAME]
+)
+_Q1_THRESHOLD = 0.25
+_Q3_THRESHOLD = 0.75
+_IQR_MULTIPLIER = 1.5
+_STD_WITH_OUTLIERS_VAR_NAME = 'std_with_outliers'
+_STD_WITHOUT_OUTLIERS_VAR_NAME = 'std_without_outliers'
+_STD_THRESHOLD = 1e-4
+_OUTLIERS_COL_NAME = 'outliers'
+_ABS_OUTLIERS_COL_NAME = 'abs_outliers'
+_VIF_COL_NAME = 'VIF'
+
+
+class _NamedEDACheckCallable(Protocol):
+  """A callable that returns an EDAOutcome and has a __name__ attribute."""
+
+  __name__: str
+
+  def __call__(self) -> eda_outcome.EDAOutcome:
+    ...
+
+
+class GeoLevelCheckOnNationalModelError(Exception):
+  """Raised when a geo-level check is called on a national model."""
+
+  pass
+
+
+@dataclasses.dataclass(frozen=True)
+class _RFNames:
+  """Holds constant names for reach and frequency data arrays."""
+
+  reach: str
+  reach_scaled: str
+  frequency: str
+  impressions: str
+  impressions_scaled: str
+  national_reach: str
+  national_reach_scaled: str
+  national_frequency: str
+  national_impressions: str
+  national_impressions_scaled: str
+
+
+_ORGANIC_RF_NAMES = _RFNames(
+    reach=constants.ORGANIC_REACH,
+    reach_scaled=constants.ORGANIC_REACH_SCALED,
+    frequency=constants.ORGANIC_FREQUENCY,
+    impressions=constants.ORGANIC_RF_IMPRESSIONS,
+    impressions_scaled=constants.ORGANIC_RF_IMPRESSIONS_SCALED,
+    national_reach=constants.NATIONAL_ORGANIC_REACH,
+    national_reach_scaled=constants.NATIONAL_ORGANIC_REACH_SCALED,
+    national_frequency=constants.NATIONAL_ORGANIC_FREQUENCY,
+    national_impressions=constants.NATIONAL_ORGANIC_RF_IMPRESSIONS,
+    national_impressions_scaled=constants.NATIONAL_ORGANIC_RF_IMPRESSIONS_SCALED,
+)
+
+
+_RF_NAMES = _RFNames(
+    reach=constants.REACH,
+    reach_scaled=constants.REACH_SCALED,
+    frequency=constants.FREQUENCY,
+    impressions=constants.RF_IMPRESSIONS,
+    impressions_scaled=constants.RF_IMPRESSIONS_SCALED,
+    national_reach=constants.NATIONAL_REACH,
+    national_reach_scaled=constants.NATIONAL_REACH_SCALED,
+    national_frequency=constants.NATIONAL_FREQUENCY,
+    national_impressions=constants.NATIONAL_RF_IMPRESSIONS,
+    national_impressions_scaled=constants.NATIONAL_RF_IMPRESSIONS_SCALED,
+)
 
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -36,45 +128,220 @@ class ReachFrequencyData:
   Attributes:
     reach_raw_da: Raw reach data.
     reach_scaled_da: Scaled reach data.
-
-
+    national_reach_raw_da: National raw reach data.
+    national_reach_scaled_da: National scaled reach data.
     frequency_da: Frequency data.
-
+    national_frequency_da: National frequency data.
     rf_impressions_scaled_da: Scaled reach * frequency impressions data.
-
+    national_rf_impressions_scaled_da: National scaled reach * frequency
       impressions data.
     rf_impressions_raw_da: Raw reach * frequency impressions data.
-
+    national_rf_impressions_raw_da: National raw reach * frequency impressions
       data.
   """
 
   reach_raw_da: xr.DataArray
   reach_scaled_da: xr.DataArray
-
-
+  national_reach_raw_da: xr.DataArray
+  national_reach_scaled_da: xr.DataArray
   frequency_da: xr.DataArray
-
+  national_frequency_da: xr.DataArray
   rf_impressions_scaled_da: xr.DataArray
-
+  national_rf_impressions_scaled_da: xr.DataArray
   rf_impressions_raw_da: xr.DataArray
-
+  national_rf_impressions_raw_da: xr.DataArray
 
 
-
-
-
+def _data_array_like(
+    *, da: xr.DataArray, values: np.ndarray | backend.Tensor
+) -> xr.DataArray:
+  """Returns a DataArray from `values` with the same structure as `da`.
 
-
-
-
-
-
-
-
+  Args:
+    da: The DataArray whose structure (dimensions, coordinates, name, and attrs)
+      will be used for the new DataArray.
+    values: The numpy array or backend tensor to use as the values for the new
+      DataArray.
+
+  Returns:
+    A new DataArray with the provided `values` and the same structure as `da`.
   """
+  return xr.DataArray(
+      values,
+      coords=da.coords,
+      dims=da.dims,
+      name=da.name,
+      attrs=da.attrs,
+  )
+
+
+def stack_variables(
+    ds: xr.Dataset, coord_name: str = _STACK_VAR_COORD_NAME
+) -> xr.DataArray:
+  """Stacks data variables of a Dataset into a single DataArray.
 
-
-
+  This function is designed to work with Datasets that have 'time' or 'geo'
+  dimensions, which are preserved. Other dimensions are stacked into a new
+  dimension.
+
+  Args:
+    ds: The input xarray.Dataset to stack.
+    coord_name: The name of the new coordinate for the stacked dimension.
+
+  Returns:
+    An xarray.DataArray with the specified dimensions stacked.
+  """
+  dims = []
+  coords = []
+  sample_dims = []
+  # Dimensions have the same names as the coordinates.
+  for dim in ds.dims:
+    if dim in [constants.TIME, constants.GEO]:
+      sample_dims.append(dim)
+      continue
+    dims.append(dim)
+    coords.extend(ds.coords[dim].values.tolist())
+
+  da = ds.to_stacked_array(coord_name, sample_dims=sample_dims)
+  da = da.reset_index(dims, drop=True).assign_coords({coord_name: coords})
+  return da
+
+
+def _compute_correlation_matrix(
+    input_da: xr.DataArray, dims: str | Sequence[str]
+) -> xr.DataArray:
+  """Computes the correlation matrix for variables in a DataArray.
+
+  Args:
+    input_da: An xr.DataArray containing variables for which to compute
+      correlations.
+    dims: Dimensions along which to compute correlations. Can only be TIME or
+      GEO.
+
+  Returns:
+    An xr.DataArray containing the correlation matrix.
+  """
+  # Create two versions for correlation
+  da1 = input_da.rename({_STACK_VAR_COORD_NAME: _CORR_VAR1})
+  da2 = input_da.rename({_STACK_VAR_COORD_NAME: _CORR_VAR2})
+
+  # Compute pairwise correlation across dims. Other dims are broadcasted.
+  corr_mat_da = xr.corr(da1, da2, dim=dims)
+  corr_mat_da.name = _CORRELATION_MATRIX_NAME
+  return corr_mat_da
+
+
+def _get_upper_triangle_corr_mat(corr_mat_da: xr.DataArray) -> xr.DataArray:
+  """Gets the upper triangle of a correlation matrix.
+
+  Args:
+    corr_mat_da: An xr.DataArray containing the correlation matrix.
+
+  Returns:
+    An xr.DataArray containing only the elements in the upper triangle of the
+    correlation matrix, with other elements masked as NaN.
+  """
+  n_vars = corr_mat_da.sizes[_CORR_VAR1]
+  mask_np = np.triu(np.ones((n_vars, n_vars), dtype=bool), k=1)
+  mask = xr.DataArray(
+      mask_np,
+      dims=[_CORR_VAR1, _CORR_VAR2],
+      coords={
+          _CORR_VAR1: corr_mat_da[_CORR_VAR1],
+          _CORR_VAR2: corr_mat_da[_CORR_VAR2],
+      },
+  )
+  return corr_mat_da.where(mask)
+
+
+def _find_extreme_corr_pairs(
+    extreme_corr_da: xr.DataArray, extreme_corr_threshold: float
+) -> pd.DataFrame:
+  """Finds extreme correlation pairs in a correlation matrix."""
+  corr_tri = _get_upper_triangle_corr_mat(extreme_corr_da)
+  extreme_corr_da = corr_tri.where(abs(corr_tri) > extreme_corr_threshold)
+
+  df = extreme_corr_da.to_dataframe(name=_CORRELATION_COL_NAME).dropna()
+  if df.empty:
+    return _EMPTY_DF_FOR_EXTREME_CORR_PAIRS.copy()
+  return df.sort_values(
+      by=_CORRELATION_COL_NAME, ascending=False, inplace=False
+  )
+
+
+def _calculate_std(
+    input_da: xr.DataArray,
+) -> tuple[xr.Dataset, pd.DataFrame]:
+  """Helper function to compute std with and without outliers.
+
+  Args:
+    input_da: A DataArray for which to calculate the std.
+
+  Returns:
+    A tuple where the first element is a Dataset with two data variables:
+    'std_incl_outliers' and 'std_excl_outliers'. The second element is a
+    DataFrame with columns for variables, geo (if applicable), time, and
+    outlier values.
+  """
+  std_with_outliers = input_da.std(dim=constants.TIME, ddof=1)
+
+  # TODO: Allow users to specify custom outlier definitions.
+  q1 = input_da.quantile(_Q1_THRESHOLD, dim=constants.TIME)
+  q3 = input_da.quantile(_Q3_THRESHOLD, dim=constants.TIME)
+  iqr = q3 - q1
+  lower_bound = q1 - _IQR_MULTIPLIER * iqr
+  upper_bound = q3 + _IQR_MULTIPLIER * iqr
+
+  da_no_outlier = input_da.where(
+      (input_da >= lower_bound) & (input_da <= upper_bound)
+  )
+  std_without_outliers = da_no_outlier.std(dim=constants.TIME, ddof=1)
+
+  std_ds = xr.Dataset({
+      _STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
+      _STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
+  })
+
+  outlier_da = input_da.where(
+      (input_da < lower_bound) | (input_da > upper_bound)
+  )
+
+  outlier_df = outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME).dropna()
+  outlier_df = outlier_df.assign(
+      **{_ABS_OUTLIERS_COL_NAME: np.abs(outlier_df[_OUTLIERS_COL_NAME])}
+  ).sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
+
+  return std_ds, outlier_df
+
+
+def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
+  """Helper function to compute variance inflation factor.
+
+  Args:
+    input_da: A DataArray for which to calculate the VIF over sample dimensions
+      (e.g. time and geo if applicable).
+    var_dim: The dimension name of the variable to compute VIF for.
+
+  Returns:
+    A DataArray containing the VIF for each variable in the variable dimension.
+  """
+  num_vars = input_da.sizes[var_dim]
+  np_data = input_da.values.reshape(-1, num_vars)
+  np_data_with_const = sm.add_constant(np_data, prepend=True)
+
+  # Compute VIF for each variable excluding const which is the first one in the
+  # 'variable' dimension.
+  vifs = [
+      outliers_influence.variance_inflation_factor(np_data_with_const, i)
+      for i in range(1, num_vars + 1)
+  ]
+
+  vif_da = xr.DataArray(
+      vifs,
+      coords={var_dim: input_da[var_dim].values},
+      dims=[var_dim],
+  )
+  return vif_da
 
 
 class EDAEngine:
@@ -83,10 +350,19 @@ class EDAEngine:
   def __init__(
       self,
       meridian: model.Meridian,
-
+      spec: eda_spec.EDASpec = eda_spec.EDASpec(),
   ):
     self._meridian = meridian
-    self.
+    self._spec = spec
+    self._agg_config = self._spec.aggregation_config
+
+  @property
+  def spec(self) -> eda_spec.EDASpec:
+    return self._spec
+
+  @property
+  def _is_national_data(self) -> bool:
+    return self._meridian.is_national
 
   @functools.cached_property
   def controls_scaled_da(self) -> xr.DataArray | None:
@@ -96,33 +372,39 @@ class EDAEngine:
         da=self._meridian.input_data.controls,
         values=self._meridian.controls_scaled,
     )
+    controls_scaled_da.name = constants.CONTROLS_SCALED
     return controls_scaled_da
 
   @functools.cached_property
-  def
-    """Returns the national controls data array."""
+  def national_controls_scaled_da(self) -> xr.DataArray | None:
+    """Returns the national scaled controls data array."""
     if self._meridian.input_data.controls is None:
       return None
-    if self.
+    if self._is_national_data:
       if self.controls_scaled_da is None:
         # This case should be impossible given the check above.
         raise RuntimeError(
            'controls_scaled_da is None when controls is not None.'
        )
-
+      national_da = self.controls_scaled_da.squeeze(constants.GEO, drop=True)
+      national_da.name = constants.NATIONAL_CONTROLS_SCALED
     else:
-
+      national_da = self._aggregate_and_scale_geo_da(
          self._meridian.input_data.controls,
+          constants.NATIONAL_CONTROLS_SCALED,
          transformers.CenteringAndScalingTransformer,
          constants.CONTROL_VARIABLE,
          self._agg_config.control_variables,
      )
+    return national_da
 
   @functools.cached_property
   def media_raw_da(self) -> xr.DataArray | None:
     if self._meridian.input_data.media is None:
       return None
-
+    raw_media_da = self._truncate_media_time(self._meridian.input_data.media)
+    raw_media_da.name = constants.MEDIA
+    return raw_media_da
 
   @functools.cached_property
   def media_scaled_da(self) -> xr.DataArray | None:
@@ -132,68 +414,84 @@ class EDAEngine:
        da=self._meridian.input_data.media,
        values=self._meridian.media_tensors.media_scaled,
    )
+    media_scaled_da.name = constants.MEDIA_SCALED
    return self._truncate_media_time(media_scaled_da)
 
   @functools.cached_property
   def media_spend_da(self) -> xr.DataArray | None:
-
-
-
-
-
-    )
+    """Returns media spend.
+
+    If the input spend is aggregated, it is allocated across geo and time
+    proportionally to media units.
+    """
    # No need to truncate the media time for media spend.
-
+    da = self._meridian.input_data.allocated_media_spend
+    if da is None:
+      return None
+    da = da.copy()
+    da.name = constants.MEDIA_SPEND
+    return da
 
   @functools.cached_property
-  def
+  def national_media_spend_da(self) -> xr.DataArray | None:
    """Returns the national media spend data array."""
-
+    media_spend = self.media_spend_da
+    if media_spend is None:
      return None
-    if self.
-
-
-      raise RuntimeError(
-          'media_spend_da is None when media_spend is not None.'
-      )
-      return self.media_spend_da.squeeze(constants.GEO)
+    if self._is_national_data:
+      national_da = media_spend.squeeze(constants.GEO, drop=True)
+      national_da.name = constants.NATIONAL_MEDIA_SPEND
    else:
-
-      self._meridian.input_data.
+      national_da = self._aggregate_and_scale_geo_da(
+          self._meridian.input_data.allocated_media_spend,
+          constants.NATIONAL_MEDIA_SPEND,
          None,
      )
+    return national_da
 
   @functools.cached_property
-  def
+  def national_media_raw_da(self) -> xr.DataArray | None:
+    """Returns the national raw media data array."""
    if self.media_raw_da is None:
      return None
-    if self.
-
+    if self._is_national_data:
+      national_da = self.media_raw_da.squeeze(constants.GEO, drop=True)
+      national_da.name = constants.NATIONAL_MEDIA
    else:
      # Note that media is summable by assumption.
-
+      national_da = self._aggregate_and_scale_geo_da(
          self.media_raw_da,
+          constants.NATIONAL_MEDIA,
          None,
      )
+    return national_da
 
   @functools.cached_property
-  def
+  def national_media_scaled_da(self) -> xr.DataArray | None:
+    """Returns the national scaled media data array."""
    if self.media_scaled_da is None:
      return None
-    if self.
-
+    if self._is_national_data:
+      national_da = self.media_scaled_da.squeeze(constants.GEO, drop=True)
+      national_da.name = constants.NATIONAL_MEDIA_SCALED
    else:
      # Note that media is summable by assumption.
-
+      national_da = self._aggregate_and_scale_geo_da(
          self.media_raw_da,
+          constants.NATIONAL_MEDIA_SCALED,
          transformers.MediaTransformer,
      )
+    return national_da
 
   @functools.cached_property
   def organic_media_raw_da(self) -> xr.DataArray | None:
    if self._meridian.input_data.organic_media is None:
      return None
-
+    raw_organic_media_da = self._truncate_media_time(
+        self._meridian.input_data.organic_media
+    )
+    raw_organic_media_da.name = constants.ORGANIC_MEDIA
+    return raw_organic_media_da
 
   @functools.cached_property
   def organic_media_scaled_da(self) -> xr.DataArray | None:
@@ -203,30 +501,42 @@ class EDAEngine:
        da=self._meridian.input_data.organic_media,
        values=self._meridian.organic_media_tensors.organic_media_scaled,
    )
+    organic_media_scaled_da.name = constants.ORGANIC_MEDIA_SCALED
    return self._truncate_media_time(organic_media_scaled_da)
 
   @functools.cached_property
-  def
+  def national_organic_media_raw_da(self) -> xr.DataArray | None:
+    """Returns the national raw organic media data array."""
    if self.organic_media_raw_da is None:
      return None
-    if self.
-
+    if self._is_national_data:
+      national_da = self.organic_media_raw_da.squeeze(constants.GEO, drop=True)
+      national_da.name = constants.NATIONAL_ORGANIC_MEDIA
    else:
      # Note that organic media is summable by assumption.
-
+      national_da = self._aggregate_and_scale_geo_da(
+          self.organic_media_raw_da, constants.NATIONAL_ORGANIC_MEDIA, None
+      )
+    return national_da
 
   @functools.cached_property
-  def
+  def national_organic_media_scaled_da(self) -> xr.DataArray | None:
+    """Returns the national scaled organic media data array."""
    if self.organic_media_scaled_da is None:
      return None
-    if self.
-
+    if self._is_national_data:
+      national_da = self.organic_media_scaled_da.squeeze(
+          constants.GEO, drop=True
+      )
+      national_da.name = constants.NATIONAL_ORGANIC_MEDIA_SCALED
    else:
      # Note that organic media is summable by assumption.
-
+      national_da = self._aggregate_and_scale_geo_da(
          self.organic_media_raw_da,
+          constants.NATIONAL_ORGANIC_MEDIA_SCALED,
          transformers.MediaTransformer,
      )
+    return national_da
 
   @functools.cached_property
   def non_media_scaled_da(self) -> xr.DataArray | None:
@@ -236,51 +546,62 @@ class EDAEngine:
        da=self._meridian.input_data.non_media_treatments,
        values=self._meridian.non_media_treatments_normalized,
    )
+    non_media_scaled_da.name = constants.NON_MEDIA_TREATMENTS_SCALED
    return non_media_scaled_da
 
   @functools.cached_property
-  def
-    """Returns the national non-media treatment data array."""
+  def national_non_media_scaled_da(self) -> xr.DataArray | None:
+    """Returns the national scaled non-media treatment data array."""
    if self._meridian.input_data.non_media_treatments is None:
      return None
-    if self.
+    if self._is_national_data:
      if self.non_media_scaled_da is None:
        # This case should be impossible given the check above.
        raise RuntimeError(
            'non_media_scaled_da is None when non_media_treatments is not None.'
        )
-
+      national_da = self.non_media_scaled_da.squeeze(constants.GEO, drop=True)
+      national_da.name = constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED
    else:
-
+      national_da = self._aggregate_and_scale_geo_da(
          self._meridian.input_data.non_media_treatments,
+          constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
          transformers.CenteringAndScalingTransformer,
          constants.NON_MEDIA_CHANNEL,
          self._agg_config.non_media_treatments,
      )
+    return national_da
 
   @functools.cached_property
   def rf_spend_da(self) -> xr.DataArray | None:
-
+    """Returns RF spend.
+
+    If the input spend is aggregated, it is allocated across geo and time
+    proportionally to RF impressions (reach * frequency).
+    """
+    da = self._meridian.input_data.allocated_rf_spend
+    if da is None:
      return None
-
-
-
-    )
-    return rf_spend_da
+    da = da.copy()
+    da.name = constants.RF_SPEND
+    return da
 
   @functools.cached_property
-  def
-
+  def national_rf_spend_da(self) -> xr.DataArray | None:
+    """Returns the national RF spend data array."""
+    rf_spend = self.rf_spend_da
+    if rf_spend is None:
      return None
-    if self.
-
-
-      raise RuntimeError('rf_spend_da is None when rf_spend is not None.')
-      return self.rf_spend_da.squeeze(constants.GEO)
+    if self._is_national_data:
+      national_da = rf_spend.squeeze(constants.GEO, drop=True)
+      national_da.name = constants.NATIONAL_RF_SPEND
    else:
-
-      self._meridian.input_data.
+      national_da = self._aggregate_and_scale_geo_da(
+          self._meridian.input_data.allocated_rf_spend,
+          constants.NATIONAL_RF_SPEND,
+          None,
      )
+    return national_da
 
   @functools.cached_property
   def _rf_data(self) -> ReachFrequencyData | None:
@@ -302,31 +623,34 @@ class EDAEngine:
   def reach_scaled_da(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
-    return self._rf_data.reach_scaled_da
+    return self._rf_data.reach_scaled_da  # pytype: disable=attribute-error
 
   @property
-  def
+  def national_reach_raw_da(self) -> xr.DataArray | None:
+    """Returns the national raw reach data array."""
    if self._rf_data is None:
      return None
-    return self._rf_data.
+    return self._rf_data.national_reach_raw_da
 
   @property
-  def
+  def national_reach_scaled_da(self) -> xr.DataArray | None:
+    """Returns the national scaled reach data array."""
    if self._rf_data is None:
      return None
-    return self._rf_data.
+    return self._rf_data.national_reach_scaled_da  # pytype: disable=attribute-error
 
   @property
   def frequency_da(self) -> xr.DataArray | None:
    if self._rf_data is None:
      return None
-    return self._rf_data.frequency_da
+    return self._rf_data.frequency_da  # pytype: disable=attribute-error
 
   @property
-  def
+  def national_frequency_da(self) -> xr.DataArray | None:
+    """Returns the national frequency data array."""
    if self._rf_data is None:
      return None
-    return self._rf_data.
+    return self._rf_data.national_frequency_da  # pytype: disable=attribute-error
 
   @property
   def rf_impressions_raw_da(self) -> xr.DataArray | None:
@@ -335,10 +659,11 @@ class EDAEngine:
    return self._rf_data.rf_impressions_raw_da
 
   @property
-  def
+  def national_rf_impressions_raw_da(self) -> xr.DataArray | None:
+    """Returns the national raw RF impressions data array."""
    if self._rf_data is None:
      return None
-    return self._rf_data.
+    return self._rf_data.national_rf_impressions_raw_da
 
   @property
   def rf_impressions_scaled_da(self) -> xr.DataArray | None:
@@ -347,10 +672,11 @@ class EDAEngine:
    return self._rf_data.rf_impressions_scaled_da
 
   @property
-  def
+  def national_rf_impressions_scaled_da(self) -> xr.DataArray | None:
+    """Returns the national scaled RF impressions data array."""
    if self._rf_data is None:
      return None
-    return self._rf_data.
+    return self._rf_data.national_rf_impressions_scaled_da
 
   @functools.cached_property
   def _organic_rf_data(self) -> ReachFrequencyData | None:
@@ -372,19 +698,21 @@ class EDAEngine:
   def organic_reach_scaled_da(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
-    return self._organic_rf_data.reach_scaled_da
+    return self._organic_rf_data.reach_scaled_da  # pytype: disable=attribute-error
 
   @property
-  def
+  def national_organic_reach_raw_da(self) -> xr.DataArray | None:
+    """Returns the national raw organic reach data array."""
    if self._organic_rf_data is None:
      return None
-    return self._organic_rf_data.
+    return self._organic_rf_data.national_reach_raw_da
 
   @property
-  def
+  def national_organic_reach_scaled_da(self) -> xr.DataArray | None:
+    """Returns the national scaled organic reach data array."""
    if self._organic_rf_data is None:
      return None
-    return self._organic_rf_data.
+    return self._organic_rf_data.national_reach_scaled_da  # pytype: disable=attribute-error
 
   @property
   def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
@@ -393,22 +721,24 @@ class EDAEngine:
    return self._organic_rf_data.rf_impressions_scaled_da
 
   @property
-  def
+  def national_organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
+    """Returns the national scaled organic RF impressions data array."""
    if self._organic_rf_data is None:
      return None
-    return self._organic_rf_data.
+    return self._organic_rf_data.national_rf_impressions_scaled_da
 
   @property
   def organic_frequency_da(self) -> xr.DataArray | None:
    if self._organic_rf_data is None:
      return None
-    return self._organic_rf_data.frequency_da
+    return self._organic_rf_data.frequency_da  # pytype: disable=attribute-error
 
   @property
-  def
+  def national_organic_frequency_da(self) -> xr.DataArray | None:
+    """Returns the national organic frequency data array."""
    if self._organic_rf_data is None:
      return None
-    return self._organic_rf_data.
+    return self._organic_rf_data.national_frequency_da  # pytype: disable=attribute-error
 
   @property
   def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
@@ -417,14 +747,15 @@ class EDAEngine:
    return self._organic_rf_data.rf_impressions_raw_da
 
   @property
-  def
+  def national_organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
+    """Returns the national raw organic RF impressions data array."""
    if self._organic_rf_data is None:
      return None
-    return self._organic_rf_data.
+    return self._organic_rf_data.national_rf_impressions_raw_da
 
   @functools.cached_property
   def geo_population_da(self) -> xr.DataArray | None:
-    if self.
+    if self._is_national_data:
      return None
    return xr.DataArray(
        self._meridian.population,
@@ -435,21 +766,38 @@ class EDAEngine:
 
   @functools.cached_property
   def kpi_scaled_da(self) -> xr.DataArray:
-
+    scaled_kpi_da = _data_array_like(
        da=self._meridian.input_data.kpi,
        values=self._meridian.kpi_scaled,
    )
+    scaled_kpi_da.name = constants.KPI_SCALED
+    return scaled_kpi_da
 
   @functools.cached_property
-  def
-
-
+  def _overall_scaled_kpi_invariability_artifact(
+      self,
+  ) -> eda_outcome.KpiInvariabilityArtifact:
+    """Returns an artifact of overall scaled KPI invariability."""
+    return eda_outcome.KpiInvariabilityArtifact(
+        level=eda_outcome.AnalysisLevel.OVERALL,
+        kpi_da=self.kpi_scaled_da,
+        kpi_stdev=self.kpi_scaled_da.std(ddof=1),
+    )
+
+  @functools.cached_property
+  def national_kpi_scaled_da(self) -> xr.DataArray:
+    """Returns the national scaled KPI data array."""
+    if self._is_national_data:
+      national_da = self.kpi_scaled_da.squeeze(constants.GEO, drop=True)
+      national_da.name = constants.NATIONAL_KPI_SCALED
    else:
      # Note that kpi is summable by assumption.
-
+      national_da = self._aggregate_and_scale_geo_da(
          self._meridian.input_data.kpi,
+          constants.NATIONAL_KPI_SCALED,
          transformers.CenteringAndScalingTransformer,
      )
+    return national_da
 
   @functools.cached_property
   def treatment_control_scaled_ds(self) -> xr.Dataset:
@@ -473,7 +821,46 @@ class EDAEngine:
    return xr.merge(to_merge, join='inner')
 
   @functools.cached_property
-  def
+  def all_spend_ds(self) -> xr.Dataset:
+    """Returns a Dataset containing all spend data.
+
+    This includes media spend and rf spend.
+    """
+    to_merge = [
+        da
+        for da in [
+            self.media_spend_da,
+            self.rf_spend_da,
+        ]
+        if da is not None
+    ]
+    return xr.merge(to_merge, join='inner')
+
+  @functools.cached_property
+  def national_all_spend_ds(self) -> xr.Dataset:
+    """Returns a Dataset containing all national spend data.
+
+    This includes media spend and rf spend.
+    """
+    to_merge = [
+        da
+        for da in [
+            self.national_media_spend_da,
+            self.national_rf_spend_da,
+        ]
+        if da is not None
+    ]
+    return xr.merge(to_merge, join='inner')
+
+  @functools.cached_property
+  def _stacked_treatment_control_scaled_da(self) -> xr.DataArray:
+    """Returns a stacked DataArray of treatment_control_scaled_ds."""
+    da = stack_variables(self.treatment_control_scaled_ds)
+    da.name = constants.TREATMENT_CONTROL_SCALED
+    return da
+
+  @functools.cached_property
+  def national_treatment_control_scaled_ds(self) -> xr.Dataset:
    """Returns a Dataset containing all scaled treatments and controls.
 
    This includes media, RF impressions, organic media, organic RF impressions,
@@ -482,17 +869,148 @@ class EDAEngine:
    to_merge_national = [
        da
        for da in [
-            self.
-            self.
-            self.
-            self.
-            self.
-            self.
+            self.national_media_scaled_da,
+            self.national_rf_impressions_scaled_da,
+            self.national_organic_media_scaled_da,
+            self.national_organic_rf_impressions_scaled_da,
+            self.national_controls_scaled_da,
+            self.national_non_media_scaled_da,
        ]
        if da is not None
    ]
    return xr.merge(to_merge_national, join='inner')
 
+  @functools.cached_property
+  def _stacked_national_treatment_control_scaled_da(self) -> xr.DataArray:
+    """Returns a stacked DataArray of national_treatment_control_scaled_ds."""
+    da = stack_variables(self.national_treatment_control_scaled_ds)
+    da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
+    return da
+
+  @functools.cached_property
+  def all_reach_scaled_da(self) -> xr.DataArray | None:
+    """Returns a DataArray containing all scaled reach data.
+
+    This includes both paid and organic reach, concatenated along the RF_CHANNEL
+    dimension.
+
+    Returns:
+      A DataArray containing all scaled reach data, or None if no RF or organic
+      RF channels are present.
+    """
+    reach_das = []
+    if self.reach_scaled_da is not None:
+      reach_das.append(self.reach_scaled_da)
+    if self.organic_reach_scaled_da is not None:
+      reach_das.append(
+          self.organic_reach_scaled_da.rename(
+              {constants.ORGANIC_RF_CHANNEL: constants.RF_CHANNEL}
+          )
+      )
+    if not reach_das:
+      return None
+    da = xr.concat(reach_das, dim=constants.RF_CHANNEL)
+    da.name = constants.ALL_REACH_SCALED
+    return da
+
+  @functools.cached_property
+  def all_freq_da(self) -> xr.DataArray | None:
+    """Returns a DataArray containing all frequency data.
+
+    This includes both paid and organic frequency, concatenated along the
+    RF_CHANNEL dimension.
+
+    Returns:
+      A DataArray containing all frequency data, or None if no RF or organic
+      RF channels are present.
+    """
+    freq_das = []
+    if self.frequency_da is not None:
+      freq_das.append(self.frequency_da)
+    if self.organic_frequency_da is not None:
+      freq_das.append(
+          self.organic_frequency_da.rename(
+              {constants.ORGANIC_RF_CHANNEL: constants.RF_CHANNEL}
+          )
+      )
+    if not freq_das:
+      return None
+    da = xr.concat(freq_das, dim=constants.RF_CHANNEL)
+    da.name = constants.ALL_FREQUENCY
+    return da
+
+  @functools.cached_property
+  def national_all_reach_scaled_da(self) -> xr.DataArray | None:
+    """Returns a DataArray containing all national-level scaled reach data.
+
+    This includes both paid and organic reach, concatenated along the
+    RF_CHANNEL dimension.
+
+    Returns:
+      A DataArray containing all national-level scaled reach data, or None if
+      no RF or organic RF channels are present.
+    """
+    national_reach_das = []
+    if self.national_reach_scaled_da is not None:
+      national_reach_das.append(self.national_reach_scaled_da)
+    national_organic_reach_scaled_da = self.national_organic_reach_scaled_da
+    if national_organic_reach_scaled_da is not None:
+      national_reach_das.append(
+          national_organic_reach_scaled_da.rename(
+              {constants.ORGANIC_RF_CHANNEL: constants.RF_CHANNEL}
+          )
+      )
+    if not national_reach_das:
+      return None
+    da = xr.concat(national_reach_das, dim=constants.RF_CHANNEL)
+    da.name = constants.NATIONAL_ALL_REACH_SCALED
+    return da
+
+  @functools.cached_property
+  def national_all_freq_da(self) -> xr.DataArray | None:
+    """Returns a DataArray containing all national-level frequency data.
+
+    This includes both paid and organic frequency, concatenated along the
+    RF_CHANNEL dimension.
+
+    Returns:
+      A DataArray containing all national-level frequency data, or None if no
+      RF or organic RF channels are present.
+    """
+    national_freq_das = []
+    if self.national_frequency_da is not None:
+      national_freq_das.append(self.national_frequency_da)
+    national_organic_frequency_da = self.national_organic_frequency_da
+    if national_organic_frequency_da is not None:
+      national_freq_das.append(
+          national_organic_frequency_da.rename(
+              {constants.ORGANIC_RF_CHANNEL: constants.RF_CHANNEL}
+          )
+      )
+    if not national_freq_das:
+      return None
+    da = xr.concat(national_freq_das, dim=constants.RF_CHANNEL)
+    da.name = constants.NATIONAL_ALL_FREQUENCY
+    return da
+
+  @property
+  def _critical_checks(
+      self,
+  ) -> list[tuple[_NamedEDACheckCallable, eda_outcome.EDACheckType]]:
+    """Returns a list of critical checks to be performed."""
+    checks = [
+        (
+            self.check_overall_kpi_invariability,
+            eda_outcome.EDACheckType.KPI_INVARIABILITY,
+        ),
+        (self.check_vif, eda_outcome.EDACheckType.MULTICOLLINEARITY),
+        (
+            self.check_pairwise_corr,
+            eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
+        ),
+    ]
+    return checks
+
   def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray:
    """Truncates the first `start` elements of the media time of a variable."""
    # This should not happen. If it does, it means this function is mis-used.
@@ -510,14 +1028,16 @@ class EDAEngine:
      self,
      xarray: xr.DataArray,
      transformer_class: Optional[type[transformers.TensorTransformer]],
-      population:
-  ):
+      population: Optional[backend.Tensor] = None,
+  ) -> xr.DataArray:
    """Scales xarray values with a TensorTransformer."""
    da = xarray.copy()
 
    if transformer_class is None:
      return da
-
+    if population is None:
+      population = backend.ones([1], dtype=backend.float32)
+    if transformer_class is transformers.CenteringAndScalingTransformer:
      xarray_transformer = transformers.CenteringAndScalingTransformer(
          tensor=da.values, population=population
      )
@@ -537,15 +1057,15 @@ class EDAEngine:
 
   def _aggregate_variables(
      self,
-
+      geo_da: xr.DataArray,
      channel_dim: str,
-      da_var_agg_map: AggregationMap,
+      da_var_agg_map: eda_spec.AggregationMap,
      keepdims: bool = True,
  ) -> xr.DataArray:
    """Aggregates variables within a DataArray based on user-defined functions.
 
    Args:
-
+      geo_da: The geo-level DataArray containing multiple variables along
        channel_dim.
      channel_dim: The name of the dimension coordinate to aggregate over (e.g.,
        constants.CONTROL_VARIABLE).
@@ -558,8 +1078,8 @@ class EDAEngine:
      aggregated according to the da_var_agg_map.
    """
    agg_results = []
-    for var_name in
-      var_data =
+    for var_name in geo_da[channel_dim].values:
+      var_data = geo_da.sel({channel_dim: var_name})
      agg_func = da_var_agg_map.get(var_name, _DEFAULT_DA_VAR_AGG_FUNCTION)
      # Apply the aggregation function over the GEO dimension
      aggregated_data = var_data.reduce(
@@ -572,15 +1092,17 @@ class EDAEngine:
 
   def _aggregate_and_scale_geo_da(
      self,
-
+      geo_da: xr.DataArray,
+      national_da_name: str,
      transformer_class: Optional[type[transformers.TensorTransformer]],
      channel_dim: Optional[str] = None,
-      da_var_agg_map: Optional[AggregationMap] = None,
+      da_var_agg_map: Optional[eda_spec.AggregationMap] = None,
  ) -> xr.DataArray:
    """Aggregate geo-level xr.DataArray to national level and then scale values.
 
    Args:
-
+      geo_da: The geo-level DataArray to convert.
+      national_da_name: The name for the returned national DataArray.
      transformer_class: The TensorTransformer class to apply after summing to
        national level. Must be None, CenteringAndScalingTransformer, or
        MediaTransformer.
@@ -599,20 +1121,21 @@ class EDAEngine:
      da_var_agg_map = {}
 
    if channel_dim is not None:
-
-
+      national_da = self._aggregate_variables(
+          geo_da, channel_dim, da_var_agg_map
      )
    else:
-
-      da_national = da_geo.sum(
+      national_da = geo_da.sum(
          dim=constants.GEO, keepdims=True, skipna=False, keep_attrs=True
      )
 
-
-
-
+    national_da = national_da.assign_coords({constants.GEO: [temp_geo_dim]})
+    national_da.values = backend.cast(national_da.values, dtype=backend.float32)
+    national_da = self._scale_xarray(national_da, transformer_class)
 
-
+    national_da = national_da.sel({constants.GEO: temp_geo_dim}, drop=True)
+    national_da.name = national_da_name
+    return national_da
 
   def _get_rf_data(
      self,
@@ -625,64 +1148,80 @@ class EDAEngine:
      scaled_reach_values = (
          self._meridian.organic_rf_tensors.organic_reach_scaled
      )
+      names = _ORGANIC_RF_NAMES
    else:
      scaled_reach_values = self._meridian.rf_tensors.reach_scaled
+      names = _RF_NAMES
+
    reach_scaled_da = _data_array_like(
        da=reach_raw_da, values=scaled_reach_values
    )
+    reach_scaled_da.name = names.reach_scaled
    # Truncate the media time for reach and scaled reach.
    reach_raw_da = self._truncate_media_time(reach_raw_da)
+    reach_raw_da.name = names.reach
    reach_scaled_da = self._truncate_media_time(reach_scaled_da)
 
    # The geo level frequency
    frequency_da = self._truncate_media_time(freq_raw_da)
+    frequency_da.name = names.frequency
 
    # The raw geo level impression
    # It's equal to reach * frequency.
    impressions_raw_da = reach_raw_da * frequency_da
-    impressions_raw_da.name =
-
-
-        else constants.RF_IMPRESSIONS
+    impressions_raw_da.name = names.impressions
+    impressions_raw_da.values = backend.cast(
+        impressions_raw_da.values, dtype=backend.float32
    )
-    impressions_raw_da.values = tf.cast(impressions_raw_da.values, tf.float32)
 
-    if self.
-
-
-
-
+    if self._is_national_data:
+      national_reach_raw_da = reach_raw_da.squeeze(constants.GEO, drop=True)
+      national_reach_raw_da.name = names.national_reach
+      national_reach_scaled_da = reach_scaled_da.squeeze(
+          constants.GEO, drop=True
+      )
+      national_reach_scaled_da.name = names.national_reach_scaled
+      national_impressions_raw_da = impressions_raw_da.squeeze(
+          constants.GEO, drop=True
+      )
+      national_impressions_raw_da.name = names.national_impressions
+      national_frequency_da = frequency_da.squeeze(constants.GEO, drop=True)
+      national_frequency_da.name = names.national_frequency
 
      # Scaled impressions
      impressions_scaled_da = self._scale_xarray(
          impressions_raw_da, transformers.MediaTransformer
      )
-
-
+      impressions_scaled_da.name = names.impressions_scaled
+      national_impressions_scaled_da = impressions_scaled_da.squeeze(
+          constants.GEO, drop=True
      )
+      national_impressions_scaled_da.name = names.national_impressions_scaled
    else:
-
-          reach_raw_da, None
+      national_reach_raw_da = self._aggregate_and_scale_geo_da(
+          reach_raw_da, names.national_reach, None
      )
-
-          reach_raw_da,
+      national_reach_scaled_da = self._aggregate_and_scale_geo_da(
+          reach_raw_da,
+          names.national_reach_scaled,
+          transformers.MediaTransformer,
      )
-
-          impressions_raw_da,
+      national_impressions_raw_da = self._aggregate_and_scale_geo_da(
+          impressions_raw_da,
+          names.national_impressions,
+          None,
      )
 
      # National frequency is a weighted average of geo frequencies,
      # weighted by reach.
-
-
+      national_frequency_da = xr.where(
+          national_reach_raw_da == 0.0,
          0.0,
-
+          national_impressions_raw_da / national_reach_raw_da,
      )
-
-
-
-      frequency_da_national.values = tf.cast(
-          frequency_da_national.values, tf.float32
+      national_frequency_da.name = names.national_frequency
+      national_frequency_da.values = backend.cast(
+          national_frequency_da.values, dtype=backend.float32
      )
 
      # Scale the impressions by population
@@ -691,45 +1230,619 @@ class EDAEngine:
          transformers.MediaTransformer,
          population=self._meridian.population,
      )
+      impressions_scaled_da.name = names.impressions_scaled
 
      # Scale the national impressions
-
+      national_impressions_scaled_da = self._aggregate_and_scale_geo_da(
          impressions_raw_da,
+          names.national_impressions_scaled,
          transformers.MediaTransformer,
      )
 
    return ReachFrequencyData(
        reach_raw_da=reach_raw_da,
        reach_scaled_da=reach_scaled_da,
-
-
+        national_reach_raw_da=national_reach_raw_da,
+        national_reach_scaled_da=national_reach_scaled_da,
        frequency_da=frequency_da,
-
+        national_frequency_da=national_frequency_da,
        rf_impressions_scaled_da=impressions_scaled_da,
-
+        national_rf_impressions_scaled_da=national_impressions_scaled_da,
        rf_impressions_raw_da=impressions_raw_da,
-
+        national_rf_impressions_raw_da=national_impressions_raw_da,
    )
 
+  def _pairwise_corr_for_geo_data(
+      self, dims: str | Sequence[str], extreme_corr_threshold: float
+  ) -> tuple[xr.DataArray, pd.DataFrame]:
+    """Get pairwise correlation among treatments and controls for geo data."""
+    corr_mat = _compute_correlation_matrix(
+        self._stacked_treatment_control_scaled_da, dims=dims
+    )
+    extreme_corr_var_pairs_df = _find_extreme_corr_pairs(
+        corr_mat, extreme_corr_threshold
+    )
+    return corr_mat, extreme_corr_var_pairs_df
 
-  def
-
-  ) ->
-
+  def check_geo_pairwise_corr(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.PairwiseCorrArtifact]:
+    """Checks pairwise correlation among treatments and controls for geo data.
 
-
-
-      will be used for the new DataArray.
-    values: The numpy array or tensorflow tensor to use as the values for the
-      new DataArray.
+    Returns:
+      An EDAOutcome object with findings and result values.
 
-
-
-
-
-
-
-
-
-
-
+    Raises:
+      GeoLevelCheckOnNationalModelError: If the model is national.
+    """
+    # If the model is national, raise an error.
+    if self._is_national_data:
+      raise GeoLevelCheckOnNationalModelError(
+          'check_geo_pairwise_corr is not supported for national models.'
+      )
+
+    findings = []
+
+    overall_corr_mat, overall_extreme_corr_var_pairs_df = (
+        self._pairwise_corr_for_geo_data(
+            dims=[constants.GEO, constants.TIME],
+            extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
+        )
+    )
+    if not overall_extreme_corr_var_pairs_df.empty:
+      var_pairs = overall_extreme_corr_var_pairs_df.index.to_list()
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.ERROR,
+              explanation=(
+                  'Some variables have perfect pairwise correlation across all'
+                  ' times and geos. For each pair of perfectly-correlated'
+                  ' variables, please remove one of the variables from the'
+                  f' model.\nPairs with perfect correlation: {var_pairs}'
+              ),
+          )
+      )
+
+    geo_corr_mat, geo_extreme_corr_var_pairs_df = (
+        self._pairwise_corr_for_geo_data(
+            dims=constants.TIME,
+            extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
+        )
+    )
+    # Overall correlation and per-geo correlation findings are mutually
+    # exclusive, and overall correlation finding takes precedence.
+    if (
+        overall_extreme_corr_var_pairs_df.empty
+        and not geo_extreme_corr_var_pairs_df.empty
+    ):
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.ATTENTION,
+              explanation=(
+                  'Some variables have perfect pairwise correlation in certain'
+                  ' geo(s). Consider checking your data, and/or combining these'
+                  ' variables if they also have high pairwise correlations in'
+                  ' other geos.'
+              ),
+          )
+      )
+
+    # If there are no findings, add a INFO level finding indicating that no
+    # severe correlations were found and what it means for user's data.
+    if not findings:
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.INFO,
+              explanation=(
+                  'Please review the computed pairwise correlations. Note that'
+                  ' high pairwise correlation may cause model identifiability'
+                  ' and convergence issues. Consider combining the variables if'
+                  ' high correlation exists.'
+              ),
+          )
+      )
+
+    pairwise_corr_artifacts = [
+        eda_outcome.PairwiseCorrArtifact(
+            level=eda_outcome.AnalysisLevel.OVERALL,
+            corr_matrix=overall_corr_mat,
+            extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
+            extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
+        ),
+        eda_outcome.PairwiseCorrArtifact(
+            level=eda_outcome.AnalysisLevel.GEO,
+            corr_matrix=geo_corr_mat,
+            extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
+            extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
+        ),
+    ]
+
+    return eda_outcome.EDAOutcome(
+        check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
+        findings=findings,
+        analysis_artifacts=pairwise_corr_artifacts,
+    )
+
+  def check_national_pairwise_corr(
|
|
1367
|
+
self,
|
|
1368
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.PairwiseCorrArtifact]:
|
|
1369
|
+
"""Checks pairwise correlation among treatments and controls for national data.
|
|
1370
|
+
|
|
1371
|
+
Returns:
|
|
1372
|
+
An EDAOutcome object with findings and result values.
|
|
1373
|
+
"""
|
|
1374
|
+
findings = []
|
|
1375
|
+
|
|
1376
|
+
corr_mat = _compute_correlation_matrix(
|
|
1377
|
+
self._stacked_national_treatment_control_scaled_da, dims=constants.TIME
|
|
1378
|
+
)
|
|
1379
|
+
extreme_corr_var_pairs_df = _find_extreme_corr_pairs(
|
|
1380
|
+
corr_mat, _NATIONAL_PAIRWISE_CORR_THRESHOLD
|
|
1381
|
+
)
|
|
1382
|
+
|
|
1383
|
+
if not extreme_corr_var_pairs_df.empty:
|
|
1384
|
+
var_pairs = extreme_corr_var_pairs_df.index.to_list()
|
|
1385
|
+
findings.append(
|
|
1386
|
+
eda_outcome.EDAFinding(
|
|
1387
|
+
severity=eda_outcome.EDASeverity.ERROR,
|
|
1388
|
+
explanation=(
|
|
1389
|
+
'Some variables have perfect pairwise correlation across all'
|
|
1390
|
+
' times. For each pair of perfectly-correlated'
|
|
1391
|
+
' variables, please remove one of the variables from the'
|
|
1392
|
+
f' model.\nPairs with perfect correlation: {var_pairs}'
|
|
1393
|
+
),
|
|
1394
|
+
)
|
|
1395
|
+
)
|
|
1396
|
+
else:
|
|
1397
|
+
findings.append(
|
|
1398
|
+
eda_outcome.EDAFinding(
|
|
1399
|
+
severity=eda_outcome.EDASeverity.INFO,
|
|
1400
|
+
explanation=(
|
|
1401
|
+
'Please review the computed pairwise correlations. Note that'
|
|
1402
|
+
' high pairwise correlation may cause model identifiability'
|
|
1403
|
+
' and convergence issues. Consider combining the variables if'
|
|
1404
|
+
' high correlation exists.'
|
|
1405
|
+
),
|
|
1406
|
+
)
|
|
1407
|
+
)
|
|
1408
|
+
|
|
1409
|
+
pairwise_corr_artifacts = [
|
|
1410
|
+
eda_outcome.PairwiseCorrArtifact(
|
|
1411
|
+
level=eda_outcome.AnalysisLevel.NATIONAL,
|
|
1412
|
+
corr_matrix=corr_mat,
|
|
1413
|
+
extreme_corr_var_pairs=extreme_corr_var_pairs_df,
|
|
1414
|
+
extreme_corr_threshold=_NATIONAL_PAIRWISE_CORR_THRESHOLD,
|
|
1415
|
+
)
|
|
1416
|
+
]
|
|
1417
|
+
return eda_outcome.EDAOutcome(
|
|
1418
|
+
check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
|
|
1419
|
+
findings=findings,
|
|
1420
|
+
analysis_artifacts=pairwise_corr_artifacts,
|
|
1421
|
+
)
|
|
1422
|
+
|
|
1423
|
+
def check_pairwise_corr(
|
|
1424
|
+
self,
|
|
1425
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.PairwiseCorrArtifact]:
|
|
1426
|
+
"""Checks pairwise correlation among treatments and controls.
|
|
1427
|
+
|
|
1428
|
+
Returns:
|
|
1429
|
+
An EDAOutcome object with findings and result values.
|
|
1430
|
+
"""
|
|
1431
|
+
if self._is_national_data:
|
|
1432
|
+
return self.check_national_pairwise_corr()
|
|
1433
|
+
else:
|
|
1434
|
+
return self.check_geo_pairwise_corr()
|
|
1435
|
+
|
|
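The methods added above screen the scaled, stacked treatment and control variables for extreme pairwise correlation (overall across geos and times, per geo, or across time for national data) and report flagged pairs as ERROR or ATTENTION findings. The following is a minimal, self-contained sketch of that kind of screen in plain pandas, with made-up column names and a hypothetical 0.999 threshold; it is not Meridian's `_compute_correlation_matrix` or `_find_extreme_corr_pairs`.

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame({
    'media_a': rng.normal(size=100),
    'control_x': rng.normal(size=100),
})
df['media_b'] = 2.0 * df['media_a']  # perfectly correlated with media_a

corr = df.corr()
threshold = 0.999  # hypothetical; the module uses its own constants
flagged = [
    (corr.index[i], corr.columns[j], corr.iloc[i, j])
    for i in range(len(corr))
    for j in range(i + 1, len(corr))
    if abs(corr.iloc[i, j]) >= threshold
]
print(flagged)  # roughly [('media_a', 'media_b', 1.0)] -> the pair to drop or combine

Here `media_b` is an exact multiple of `media_a`, so the pair is flagged, which is the situation the ERROR finding above describes.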
+  def _check_std(
+      self,
+      data: xr.DataArray,
+      level: eda_outcome.AnalysisLevel,
+      zero_std_message: str,
+  ) -> tuple[
+      Optional[eda_outcome.EDAFinding], eda_outcome.StandardDeviationArtifact
+  ]:
+    """Helper to check standard deviation."""
+    std_ds, outlier_df = _calculate_std(data)
+
+    finding = None
+    if (std_ds[_STD_WITHOUT_OUTLIERS_VAR_NAME] < _STD_THRESHOLD).any():
+      finding = eda_outcome.EDAFinding(
+          severity=eda_outcome.EDASeverity.ATTENTION,
+          explanation=zero_std_message,
+      )
+
+    artifact = eda_outcome.StandardDeviationArtifact(
+        variable=str(data.name),
+        level=level,
+        std_ds=std_ds,
+        outlier_df=outlier_df,
+    )
+
+    return finding, artifact
+
+  def check_geo_std(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.StandardDeviationArtifact]:
+    """Checks std for geo-level KPI, treatments, R&F, and controls."""
+    if self._is_national_data:
+      raise ValueError('check_geo_std is not applicable for national models.')
+
+    findings = []
+    artifacts = []
+
+    checks = [
+        (
+            self.kpi_scaled_da,
+            (
+                'KPI has zero standard deviation after removing outliers'
+                ' in certain geos, indicating weak or no signal in the response'
+                ' variable for these geos. Please review the input data,'
+                ' and/or consider grouping these geos together.'
+            ),
+        ),
+        (
+            self._stacked_treatment_control_scaled_da,
+            (
+                'Some treatment or control variables have zero standard'
+                ' deviation after removing outliers in certain geo(s). Please'
+                ' review the input data. If these variables are sparse,'
+                ' consider combining them to mitigate potential model'
+                ' identifiability and convergence issues.'
+            ),
+        ),
+        (
+            self.all_reach_scaled_da,
+            (
+                'There are RF or Organic RF channels with zero variation of'
+                ' reach across time at a geo after outliers are removed. If'
+                ' these channels also have low variation of reach in other'
+                ' geos, consider modeling them as impression-based channels'
+                ' instead by taking reach * frequency.'
+            ),
+        ),
+        (
+            self.all_freq_da,
+            (
+                'There are RF or Organic RF channels with zero variation of'
+                ' frequency across time at a geo after outliers are removed. If'
+                ' these channels also have low variation of frequency in other'
+                ' geos, consider modeling them as impression-based channels'
+                ' instead by taking reach * frequency.'
+            ),
+        ),
+    ]
+
+    for data_da, message in checks:
+      if data_da is None:
+        continue
+      finding, artifact = self._check_std(
+          level=eda_outcome.AnalysisLevel.GEO,
+          data=data_da,
+          zero_std_message=message,
+      )
+      artifacts.append(artifact)
+      if finding:
+        findings.append(finding)
+
+    # Add an INFO finding if no findings were added.
+    if not findings:
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.INFO,
+              explanation=(
+                  'Please review any identified outliers and the standard'
+                  ' deviation.'
+              ),
+          )
+      )
+
+    return eda_outcome.EDAOutcome(
+        check_type=eda_outcome.EDACheckType.STANDARD_DEVIATION,
+        findings=findings,
+        analysis_artifacts=artifacts,
+    )
+
+  def check_national_std(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.StandardDeviationArtifact]:
+    """Checks std for national-level KPI, treatments, R&F, and controls."""
+    findings = []
+    artifacts = []
+
+    checks = [
+        (
+            self.national_kpi_scaled_da,
+            (
+                'The standard deviation of the scaled KPI drops from positive'
+                ' to zero after removing outliers, indicating sparsity of KPI'
+                ' i.e. lack of signal in the response variable. Please review'
+                ' the input data, and/or reconsider the feasibility of model'
+                ' fitting with this dataset.'
+            ),
+        ),
+        (
+            self._stacked_national_treatment_control_scaled_da,
+            (
+                'The standard deviation of these scaled treatment or control'
+                ' variables drops from positive to zero after removing'
+                ' outliers. This indicates sparsity of these variables, which'
+                ' may cause model identifiability and convergence issues.'
+                ' Please review the input data, and/or consider combining these'
+                ' variables to mitigate sparsity.'
+            ),
+        ),
+        (
+            self.national_all_reach_scaled_da,
+            (
+                'There are RF channels with totally zero variation of reach'
+                ' across time at the national level after outliers are removed.'
+                ' Consider modeling these RF channels as impression-based'
+                ' channels instead.'
+            ),
+        ),
+        (
+            self.national_all_freq_da,
+            (
+                'There are RF channels with totally zero variation of frequency'
+                ' across time at the national level after outliers are removed.'
+                ' Consider modeling these RF channels as impression-based'
+                ' channels instead.'
+            ),
+        ),
+    ]
+
+    for data_da, message in checks:
+      if data_da is None:
+        continue
+      finding, artifact = self._check_std(
+          data=data_da,
+          level=eda_outcome.AnalysisLevel.NATIONAL,
+          zero_std_message=message,
+      )
+      artifacts.append(artifact)
+      if finding:
+        findings.append(finding)
+
+    # Add an INFO finding if no findings were added.
+    if not findings:
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.INFO,
+              explanation=(
+                  'Please review any identified outliers and the standard'
+                  ' deviation.'
+              ),
+          )
+      )
+
+    return eda_outcome.EDAOutcome(
+        check_type=eda_outcome.EDACheckType.STANDARD_DEVIATION,
+        findings=findings,
+        analysis_artifacts=artifacts,
+    )
+
+  def check_std(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.StandardDeviationArtifact]:
+    """Checks standard deviation for treatments and controls.
+
+    Returns:
+      An EDAOutcome object with findings and result values.
+    """
+    if self._is_national_data:
+      return self.check_national_std()
+    else:
+      return self.check_geo_std()
+
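The standard-deviation checks added above flag variables whose scaled values have (near-)zero standard deviation once outliers are removed, i.e. sparse or effectively constant series. The outlier rule inside `_calculate_std` is not shown in this diff; the sketch below uses a simple IQR cutoff purely for illustration of why the "drops to zero after removing outliers" condition signals sparsity.

import pandas as pd

values = pd.Series([0.0] * 50 + [5.0])  # constant series with one spike

# Simple IQR-based outlier mask (illustrative assumption only).
q1, q3 = values.quantile([0.25, 0.75])
iqr = q3 - q1
mask = (values >= q1 - 1.5 * iqr) & (values <= q3 + 1.5 * iqr)

print(values.std())        # > 0 while the spike is included
print(values[mask].std())  # 0.0 once the spike is removed -> sparse signal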
+  def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
+    """Computes geo-level variance inflation factor among treatments and controls."""
+    if self._is_national_data:
+      raise ValueError(
+          'Geo-level VIF checks are not applicable for national models.'
+      )
+
+    # Overall level VIF check for geo data.
+    tc_da = self._stacked_treatment_control_scaled_da
+    overall_threshold = self._spec.vif_spec.overall_threshold
+
+    overall_vif_da = _calculate_vif(tc_da, _STACK_VAR_COORD_NAME)
+    extreme_overall_vif_da = overall_vif_da.where(
+        overall_vif_da > overall_threshold
+    )
+    extreme_overall_vif_df = extreme_overall_vif_da.to_dataframe(
+        name=_VIF_COL_NAME
+    ).dropna()
+
+    overall_vif_artifact = eda_outcome.VIFArtifact(
+        level=eda_outcome.AnalysisLevel.OVERALL,
+        vif_da=overall_vif_da,
+        outlier_df=extreme_overall_vif_df,
+    )
+
+    # Geo level VIF check.
+    geo_threshold = self._spec.vif_spec.geo_threshold
+    geo_vif_da = tc_da.groupby(constants.GEO).map(
+        lambda x: _calculate_vif(x, _STACK_VAR_COORD_NAME)
+    )
+    extreme_geo_vif_da = geo_vif_da.where(geo_vif_da > geo_threshold)
+    extreme_geo_vif_df = extreme_geo_vif_da.to_dataframe(
+        name=_VIF_COL_NAME
+    ).dropna()
+
+    geo_vif_artifact = eda_outcome.VIFArtifact(
+        level=eda_outcome.AnalysisLevel.GEO,
+        vif_da=geo_vif_da,
+        outlier_df=extreme_geo_vif_df,
+    )
+
+    findings = []
+    if not extreme_overall_vif_df.empty:
+      high_vif_vars = extreme_overall_vif_df.index.to_list()
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.ERROR,
+              explanation=(
+                  'Some variables have extreme multicollinearity (VIF'
+                  f' >{overall_threshold}) across all times and geos. To'
+                  ' address multicollinearity, please drop any variable that'
+                  ' is a linear combination of other variables. Otherwise,'
+                  ' consider combining variables.\n'
+                  f'Variables with extreme VIF: {high_vif_vars}'
+              ),
+          )
+      )
+    elif not extreme_geo_vif_df.empty:
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.ATTENTION,
+              explanation=(
+                  'Some variables have extreme multicollinearity (with VIF >'
+                  f' {geo_threshold}) in certain geo(s). Consider checking your'
+                  ' data, and/or combining these variables if they also have'
+                  ' high VIF in other geos.'
+              ),
+          )
+      )
+    else:
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.INFO,
+              explanation=(
+                  'Please review the computed VIFs. Note that high VIF suggests'
+                  ' multicollinearity issues in the dataset, which may'
+                  ' jeopardize model identifiability and model convergence.'
+                  ' Consider combining the variables if high VIF occurs.'
+              ),
+          )
+      )
+
+    return eda_outcome.EDAOutcome(
+        check_type=eda_outcome.EDACheckType.MULTICOLLINEARITY,
+        findings=findings,
+        analysis_artifacts=[overall_vif_artifact, geo_vif_artifact],
+    )
+
+  def check_national_vif(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
+    """Computes national-level variance inflation factor among treatments and controls."""
+    national_tc_da = self._stacked_national_treatment_control_scaled_da
+    national_threshold = self._spec.vif_spec.national_threshold
+    national_vif_da = _calculate_vif(national_tc_da, _STACK_VAR_COORD_NAME)
+
+    extreme_national_vif_df = (
+        national_vif_da.where(national_vif_da > national_threshold)
+        .to_dataframe(name=_VIF_COL_NAME)
+        .dropna()
+    )
+    national_vif_artifact = eda_outcome.VIFArtifact(
+        level=eda_outcome.AnalysisLevel.NATIONAL,
+        vif_da=national_vif_da,
+        outlier_df=extreme_national_vif_df,
+    )
+
+    findings = []
+    if not extreme_national_vif_df.empty:
+      high_vif_vars = extreme_national_vif_df.index.to_list()
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.ERROR,
+              explanation=(
+                  'Some variables have extreme multicollinearity (with VIF >'
+                  f' {national_threshold}) across all times. To address'
+                  ' multicollinearity, please drop any variable that is a'
+                  ' linear combination of other variables. Otherwise, consider'
+                  ' combining variables.\n'
+                  f'Variables with extreme VIF: {high_vif_vars}'
+              ),
+          )
+      )
+    else:
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.INFO,
+              explanation=(
+                  'Please review the computed VIFs. Note that high VIF suggests'
+                  ' multicollinearity issues in the dataset, which may'
+                  ' jeopardize model identifiability and model convergence.'
+                  ' Consider combining the variables if high VIF occurs.'
+              ),
+          )
+      )
+    return eda_outcome.EDAOutcome(
+        check_type=eda_outcome.EDACheckType.MULTICOLLINEARITY,
+        findings=findings,
+        analysis_artifacts=[national_vif_artifact],
+    )
+
+  def check_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
+    """Computes variance inflation factor among treatments and controls.
+
+    Returns:
+      An EDAOutcome object with findings and result values.
+    """
+    if self._is_national_data:
+      return self.check_national_vif()
+    else:
+      return self.check_geo_vif()
+
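The VIF checks added above compute variance inflation factors over the stacked, scaled treatments and controls and flag anything above the thresholds in the EDA spec. A minimal sketch of the quantity itself, on synthetic data, using the standard identity that the diagonal of the inverse correlation matrix gives VIF_j = 1 / (1 - R_j^2); this is an illustration, not Meridian's `_calculate_vif`.

import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
x1 = rng.normal(size=200)
x2 = rng.normal(size=200)
df = pd.DataFrame({
    'x1': x1,
    'x2': x2,
    'x3': x1 + 0.1 * rng.normal(size=200),  # nearly collinear with x1
})

corr = df.corr().to_numpy()
vif = pd.Series(np.diag(np.linalg.inv(corr)), index=df.columns)
print(vif)  # x1 and x3 get large VIFs; x2 stays near 1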
+  @property
+  def kpi_has_variability(self) -> bool:
+    """Returns True if the KPI has variability across geos and times."""
+    return (
+        self._overall_scaled_kpi_invariability_artifact.kpi_stdev.item()
+        >= _STD_THRESHOLD
+    )
+
+  def check_overall_kpi_invariability(self) -> eda_outcome.EDAOutcome:
+    """Checks if the KPI is constant across all geos and times."""
+    kpi = self._overall_scaled_kpi_invariability_artifact.kpi_da.name
+    geo_text = '' if self._is_national_data else 'geos and '
+
+    if not self.kpi_has_variability:
+      eda_finding = eda_outcome.EDAFinding(
+          severity=eda_outcome.EDASeverity.ERROR,
+          explanation=(
+              f'`{kpi}` is constant across all {geo_text}times, indicating no'
+              ' signal in the data. Please fix this data error.'
+          ),
+      )
+    else:
+      eda_finding = eda_outcome.EDAFinding(
+          severity=eda_outcome.EDASeverity.INFO,
+          explanation=(
+              f'The {kpi} has variability across {geo_text}times in the data.'
+          ),
+      )
+
+    return eda_outcome.EDAOutcome(
+        check_type=eda_outcome.EDACheckType.KPI_INVARIABILITY,
+        findings=[eda_finding],
+        analysis_artifacts=[self._overall_scaled_kpi_invariability_artifact],
+    )
+
+  def run_all_critical_checks(self) -> list[eda_outcome.EDAOutcome]:
+    """Runs all critical EDA checks.
+
+    Critical checks are those that can result in EDASeverity.ERROR findings.
+
+    Returns:
+      A list of EDA outcomes, one for each check.
+    """
+    outcomes = []
+    for check, check_type in self._critical_checks:
+      try:
+        outcomes.append(check())
+      except Exception as e:  # pylint: disable=broad-except
+        error_finding = eda_outcome.EDAFinding(
+            severity=eda_outcome.EDASeverity.ERROR,
+            explanation=f'An error occurred during check {check.__name__}: {e}',
+        )
+        outcomes.append(
+            eda_outcome.EDAOutcome(
+                check_type=check_type,
+                findings=[error_finding],
+                analysis_artifacts=[],
+            )
+        )
+    return outcomes
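`run_all_critical_checks` above wraps each registered check so that an exception becomes an ERROR finding instead of aborting the whole run. A stripped-down sketch of that pattern with plain Python stand-ins (the names `run_checks`, `ok_check`, and `failing_check` are hypothetical; Meridian's version returns `EDAOutcome` objects carrying a `check_type`).

def run_checks(checks):
  results = []
  for check, label in checks:
    try:
      results.append((label, check()))
    except Exception as e:  # broad except, mirroring the diff's error handling
      results.append((label, f'ERROR in {check.__name__}: {e}'))
  return results


def ok_check():
  return 'no issues found'


def failing_check():
  raise ValueError('bad input')


print(run_checks([(ok_check, 'std'), (failing_check, 'vif')]))
# [('std', 'no issues found'), ('vif', 'ERROR in failing_check: bad input')]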