google_meridian-1.3.2-py3-none-any.whl → google_meridian-1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
meridian/model/eda/eda_engine.py
CHANGED
@@ -16,19 +16,23 @@

 from __future__ import annotations

+from collections.abc import Collection, Sequence
 import dataclasses
 import functools
 import typing
-from typing import
+from typing import Protocol
+import warnings

 from meridian import backend
 from meridian import constants
+from meridian.model import context
 from meridian.model import transformers
 from meridian.model.eda import constants as eda_constants
 from meridian.model.eda import eda_outcome
 from meridian.model.eda import eda_spec
 import numpy as np
 import pandas as pd
+from scipy import stats
 import statsmodels.api as sm
 from statsmodels.stats import outliers_influence
 import xarray as xr
@@ -39,25 +43,6 @@ if typing.TYPE_CHECKING:

 __all__ = ['EDAEngine', 'GeoLevelCheckOnNationalModelError']

-_DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
-_CORRELATION_COL_NAME = eda_constants.CORRELATION
-_STACK_VAR_COORD_NAME = eda_constants.VARIABLE
-_CORR_VAR1 = eda_constants.VARIABLE_1
-_CORR_VAR2 = eda_constants.VARIABLE_2
-_CORRELATION_MATRIX_NAME = 'correlation_matrix'
-_OVERALL_PAIRWISE_CORR_THRESHOLD = 0.999
-_GEO_PAIRWISE_CORR_THRESHOLD = 0.999
-_NATIONAL_PAIRWISE_CORR_THRESHOLD = 0.999
-_Q1_THRESHOLD = 0.25
-_Q3_THRESHOLD = 0.75
-_IQR_MULTIPLIER = 1.5
-_STD_WITH_OUTLIERS_VAR_NAME = 'std_with_outliers'
-_STD_WITHOUT_OUTLIERS_VAR_NAME = 'std_without_outliers'
-_STD_THRESHOLD = 1e-4
-_OUTLIERS_COL_NAME = 'outliers'
-_ABS_OUTLIERS_COL_NAME = 'abs_outliers'
-_VIF_COL_NAME = 'VIF'
-

 class _NamedEDACheckCallable(Protocol):
   """A callable that returns an EDAOutcome and has a __name__ attribute."""
@@ -149,6 +134,15 @@ class ReachFrequencyData:
   national_rf_impressions_raw_da: xr.DataArray


+def _get_vars_from_dataset(
+    base_ds: xr.Dataset,
+    variables_to_include: Collection[str],
+) -> xr.Dataset | None:
+  """Helper to get a subset of variables from a Dataset."""
+  variables = [v for v in base_ds.data_vars if v in variables_to_include]
+  return base_ds[variables].copy() if variables else None
+
+
 def _data_array_like(
     *, da: xr.DataArray, values: np.ndarray | backend.Tensor
 ) -> xr.DataArray:
@@ -173,7 +167,7 @@ def _data_array_like(


 def stack_variables(
-    ds: xr.Dataset, coord_name: str =
+    ds: xr.Dataset, coord_name: str = eda_constants.VARIABLE
 ) -> xr.DataArray:
   """Stacks data variables of a Dataset into a single DataArray.

@@ -219,12 +213,12 @@ def _compute_correlation_matrix(
     An xr.DataArray containing the correlation matrix.
   """
   # Create two versions for correlation
-  da1 = input_da.rename({
-  da2 = input_da.rename({
+  da1 = input_da.rename({eda_constants.VARIABLE: eda_constants.VARIABLE_1})
+  da2 = input_da.rename({eda_constants.VARIABLE: eda_constants.VARIABLE_2})

   # Compute pairwise correlation across dims. Other dims are broadcasted.
   corr_mat_da = xr.corr(da1, da2, dim=dims)
-  corr_mat_da.name =
+  corr_mat_da.name = eda_constants.CORRELATION_MATRIX_NAME
   return corr_mat_da


@@ -238,14 +232,14 @@ def _get_upper_triangle_corr_mat(corr_mat_da: xr.DataArray) -> xr.DataArray:
     An xr.DataArray containing only the elements in the upper triangle of the
     correlation matrix, with other elements masked as NaN.
   """
-  n_vars = corr_mat_da.sizes[
+  n_vars = corr_mat_da.sizes[eda_constants.VARIABLE_1]
   mask_np = np.triu(np.ones((n_vars, n_vars), dtype=bool), k=1)
   mask = xr.DataArray(
       mask_np,
-      dims=[
+      dims=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2],
       coords={
-
-
+          eda_constants.VARIABLE_1: corr_mat_da[eda_constants.VARIABLE_1],
+          eda_constants.VARIABLE_2: corr_mat_da[eda_constants.VARIABLE_2],
       },
   )
   return corr_mat_da.where(mask)
@@ -259,11 +253,11 @@ def _find_extreme_corr_pairs(
   extreme_corr_da = corr_tri.where(abs(corr_tri) > extreme_corr_threshold)

   return (
-      extreme_corr_da.to_dataframe(name=
+      extreme_corr_da.to_dataframe(name=eda_constants.CORRELATION)
       .dropna()
       .assign(**{
           eda_constants.ABS_CORRELATION_COL_NAME: (
-              lambda x: x[
+              lambda x: x[eda_constants.CORRELATION].abs()
           )
       })
       .sort_values(
@@ -286,11 +280,11 @@ def _get_outlier_bounds(
     A tuple containing the lower and upper bounds of outliers as DataArrays.
   """
   # TODO: Allow users to specify custom outlier definitions.
-  q1 = input_da.quantile(
-  q3 = input_da.quantile(
+  q1 = input_da.quantile(eda_constants.Q1_THRESHOLD, dim=constants.TIME)
+  q3 = input_da.quantile(eda_constants.Q3_THRESHOLD, dim=constants.TIME)
   iqr = q3 - q1
-  lower_bound = q1 -
-  upper_bound = q3 +
+  lower_bound = q1 - eda_constants.IQR_MULTIPLIER * iqr
+  upper_bound = q3 + eda_constants.IQR_MULTIPLIER * iqr
   return lower_bound, upper_bound


@@ -314,11 +308,10 @@ def _calculate_std(
   )
   std_without_outliers = da_no_outlier.std(dim=constants.TIME, ddof=1)

-
-
-
+  return xr.Dataset({
+      eda_constants.STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
+      eda_constants.STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
   })
-  return std_ds


 def _calculate_outliers(
@@ -334,23 +327,29 @@ def _calculate_outliers(
     outlier values.
   """
   lower_bound, upper_bound = _get_outlier_bounds(input_da)
-
-      (input_da < lower_bound) | (input_da > upper_bound)
-
-  outlier_df = (
-      outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME)
+  return (
+      input_da.where((input_da < lower_bound) | (input_da > upper_bound))
+      .to_dataframe(name=eda_constants.OUTLIERS_COL_NAME)
       .dropna()
-      .assign(
-
+      .assign(**{
+          eda_constants.ABS_OUTLIERS_COL_NAME: lambda x: np.abs(
+              x[eda_constants.OUTLIERS_COL_NAME]
+          )
+      })
+      .sort_values(
+          by=eda_constants.ABS_OUTLIERS_COL_NAME,
+          ascending=False,
+          inplace=False,
       )
-      .sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
   )
-  return outlier_df


 def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
   """Helper function to compute variance inflation factor.

+  The VIF calculation only focuses on multicollinearity among non-constant
+  variables. Any variable with constant values will result in a NaN VIF value.
+
   Args:
     input_da: A DataArray for which to calculate the VIF over sample dimensions
       (e.g. time and geo if applicable).
@@ -359,25 +358,28 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
   Returns:
     A DataArray containing the VIF for each variable in the variable dimension.
   """
+
   num_vars = input_da.sizes[var_dim]
   np_data = input_da.values.reshape(-1, num_vars)
-  np_data_with_const = sm.add_constant(
-      np_data, prepend=True, has_constant='add'
-  )

-
-
-
-
-
-
+  is_constant = np.std(np_data, axis=0) < eda_constants.STD_THRESHOLD
+  vif_values = np.full(num_vars, np.nan)
+  (non_constant_vars_indices,) = (~is_constant).nonzero()
+
+  if non_constant_vars_indices.size > 0:
+    design_matrix = sm.add_constant(
+        np_data[:, ~is_constant], prepend=True, has_constant='add'
+    )
+    for i, var_index in enumerate(non_constant_vars_indices):
+      vif_values[var_index] = outliers_influence.variance_inflation_factor(
+          design_matrix, i + 1
+      )

-
-
+  return xr.DataArray(
+      vif_values,
       coords={var_dim: input_da[var_dim].values},
       dims=[var_dim],
   )
-  return vif_da


 def _check_cost_media_unit_inconsistency(
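The rewritten `_calculate_vif` above drops constant columns before computing VIFs, so constant variables come back as NaN instead of breaking the regression. Below is a minimal, standalone sketch of the same pattern using only numpy and statsmodels; the toy array and variable names are illustrative and not part of the package.

import numpy as np
import statsmodels.api as sm
from statsmodels.stats import outliers_influence

rng = np.random.default_rng(0)
x0 = np.arange(20.0)
data = np.column_stack([
    x0,                                    # varying regressor
    np.ones(20),                           # constant column -> VIF left as NaN
    2.0 * x0 + rng.normal(0.0, 0.01, 20),  # nearly collinear with x0 -> large VIF
])

is_constant = np.std(data, axis=0) < 1e-4
vif = np.full(data.shape[1], np.nan)
(nonconst_idx,) = (~is_constant).nonzero()

if nonconst_idx.size > 0:
  # Prepend an intercept so each VIF regresses one column on the others plus a
  # constant; the i-th non-constant column sits at index i + 1 of the design.
  design = sm.add_constant(
      data[:, ~is_constant], prepend=True, has_constant='add'
  )
  for i, col in enumerate(nonconst_idx):
    vif[col] = outliers_influence.variance_inflation_factor(design, i + 1)

print(vif)  # NaN for the constant column, large values for the collinear pair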
@@ -392,9 +394,9 @@ def _check_cost_media_unit_inconsistency(

   Returns:
     A DataFrame of inconsistencies where either cost is zero and media units
-      are
-      positive, or cost is positive and media units are zero.
+      are positive, or cost is positive and media units are zero.
   """
+
   cost_media_units_ds = xr.merge([cost_da, media_units_da])

   # Condition 1: cost == 0 and media unit > 0
@@ -432,39 +434,69 @@ def _check_cost_per_media_unit(
       cost_da,
       media_units_da,
   )
+
+  # Calculate cost per media unit. Avoid division by zero by setting cost to
+  # NaN where media units are 0. Note that both (cost == media unit == 0) and
+  # (cost > 0 and media unit == 0) result in NaN, while the latter one is not
+  # desired.
+  cost_per_media_unit_da = xr.where(
+      media_units_da == 0,
+      np.nan,
+      cost_da / media_units_da,
+  )
+  cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
+  outlier_df = _calculate_outliers(cost_per_media_unit_da)
+
+  if not outlier_df.empty:
+    outlier_df = outlier_df.rename(
+        columns={
+            eda_constants.OUTLIERS_COL_NAME: eda_constants.COST_PER_MEDIA_UNIT,
+            eda_constants.ABS_OUTLIERS_COL_NAME: (
+                eda_constants.ABS_COST_PER_MEDIA_UNIT
+            ),
+        }
+    ).assign(**{
+        constants.SPEND: cost_da.to_series(),
+        constants.MEDIA_UNITS: media_units_da.to_series(),
+    })[[
+        constants.SPEND,
+        constants.MEDIA_UNITS,
+        eda_constants.COST_PER_MEDIA_UNIT,
+        eda_constants.ABS_COST_PER_MEDIA_UNIT,
+    ]]
+
+  artifact = eda_outcome.CostPerMediaUnitArtifact(
+      level=level,
+      cost_per_media_unit_da=cost_per_media_unit_da,
+      cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
+      outlier_df=outlier_df,
+  )
+
   if not cost_media_unit_inconsistency_df.empty:
     findings.append(
         eda_outcome.EDAFinding(
             severity=eda_outcome.EDASeverity.ATTENTION,
             explanation=(
-                'There are instances of inconsistent
-                ' This occurs when
-                ' or when
-                ' review the
+                'There are instances of inconsistent spend and media units.'
+                ' This occurs when spend is zero but media units are positive,'
+                ' or when spend is positive but media units are zero. Please'
+                ' review the data input for media units and spend.'
             ),
+            finding_cause=eda_outcome.FindingCause.INCONSISTENT_DATA,
+            associated_artifact=artifact,
         )
     )

-  # Calculate cost per media unit
-  # Avoid division by zero by setting cost to NaN where media units are 0.
-  # Note that both (cost == media unit == 0) and (cost > 0 and media unit ==
-  # 0) result in NaN, while the latter one is not desired.
-  cost_per_media_unit_da = xr.where(
-      media_units_da == 0,
-      np.nan,
-      cost_da / media_units_da,
-  )
-  cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
-
-  outlier_df = _calculate_outliers(cost_per_media_unit_da)
   if not outlier_df.empty:
     findings.append(
         eda_outcome.EDAFinding(
             severity=eda_outcome.EDASeverity.ATTENTION,
             explanation=(
                 'There are outliers in cost per media unit across time.'
-                ' Please
+                ' Please check for any possible data input error.'
             ),
+            finding_cause=eda_outcome.FindingCause.OUTLIER,
+            associated_artifact=artifact,
         )
     )

@@ -474,16 +506,10 @@ def _check_cost_per_media_unit(
         eda_outcome.EDAFinding(
             severity=eda_outcome.EDASeverity.INFO,
             explanation='Please review the cost per media unit data.',
+            finding_cause=eda_outcome.FindingCause.NONE,
         )
     )

-  artifact = eda_outcome.CostPerMediaUnitArtifact(
-      level=level,
-      cost_per_media_unit_da=cost_per_media_unit_da,
-      cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
-      outlier_df=outlier_df,
-  )
-
   return eda_outcome.EDAOutcome(
       check_type=eda_outcome.EDACheckType.COST_PER_MEDIA_UNIT,
       findings=findings,
@@ -491,42 +517,91 @@ def _check_cost_per_media_unit(
   )


+def _calc_adj_r2(da: xr.DataArray, regressor: str) -> xr.DataArray:
+  """Calculates adjusted R-squared for a DataArray against a regressor.
+
+  If the input DataArray `da` is constant, it returns NaN.
+
+  Args:
+    da: The input DataArray.
+    regressor: The regressor to use in the formula.
+
+  Returns:
+    An xr.DataArray containing the adjusted R-squared value or NaN if `da` is
+    constant.
+  """
+  if da.std(ddof=1) < eda_constants.STD_THRESHOLD:
+    return xr.DataArray(np.nan)
+  tmp_name = 'dep_var'
+  df = da.to_dataframe(name=tmp_name).reset_index()
+  formula = f'{tmp_name} ~ C({regressor})'
+  ols = sm.OLS.from_formula(formula, df).fit()
+  return xr.DataArray(ols.rsquared_adj)
+
+
+def _spearman_coeff(x, y):
+  """Computes spearman correlation coefficient between two ArrayLike objects."""
+
+  return stats.spearmanr(x, y, nan_policy='omit').statistic
+
+
 class EDAEngine:
   """Meridian EDA Engine."""

   def __init__(
       self,
-      meridian: model.Meridian,
+      meridian: model.Meridian | None = None,
       spec: eda_spec.EDASpec = eda_spec.EDASpec(),
+      *,
+      model_context: context.ModelContext | None = None,
   ):
-
+    if meridian is not None and model_context is not None:
+      raise ValueError(
+          'Only one of `meridian` or `model_context` can be provided.'
+      )
+    if meridian is not None:
+      warnings.warn(
+          'Initializing EDAEngine with a Meridian object is deprecated'
+          ' and will be removed in a future version. Please use'
+          ' `model_context` instead.',
+          DeprecationWarning,
+          stacklevel=2,
+      )
+      self._model_context = meridian.model_context
+    elif model_context is not None:
+      self._model_context = model_context
+    else:
+      raise ValueError('Either `meridian` or `model_context` must be provided.')
+
+    self._input_data = self._model_context.input_data
     self._spec = spec
     self._agg_config = self._spec.aggregation_config

   @property
   def spec(self) -> eda_spec.EDASpec:
+    """The EDA specification."""
     return self._spec

   @property
   def _is_national_data(self) -> bool:
-    return self.
+    return self._model_context.is_national

   @functools.cached_property
   def controls_scaled_da(self) -> xr.DataArray | None:
-    """
-    if self.
+    """The scaled controls data array."""
+    if self._input_data.controls is None:
       return None
     controls_scaled_da = _data_array_like(
-        da=self.
-        values=self.
+        da=self._input_data.controls,
+        values=self._model_context.controls_scaled,
     )
     controls_scaled_da.name = constants.CONTROLS_SCALED
     return controls_scaled_da

   @functools.cached_property
   def national_controls_scaled_da(self) -> xr.DataArray | None:
-    """
-    if self.
+    """The national scaled controls data array."""
+    if self._input_data.controls is None:
       return None
     if self._is_national_data:
       if self.controls_scaled_da is None:
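The hunk above makes the `meridian` argument optional and deprecated in favor of a keyword-only `model_context`. The following is a minimal sketch of both call styles under the new signature; it assumes an already-built `meridian.model.model.Meridian` instance (`mmm` here), whose `model_context` attribute is the one the deprecation path in this diff reads.

import warnings

from meridian.model import model
from meridian.model.eda import eda_engine


def build_eda_engine(mmm: model.Meridian) -> eda_engine.EDAEngine:
  # Preferred style per this diff: pass the ModelContext directly.
  return eda_engine.EDAEngine(model_context=mmm.model_context)


def build_eda_engine_legacy(mmm: model.Meridian) -> eda_engine.EDAEngine:
  # Legacy style: still accepted, but now emits a DeprecationWarning.
  with warnings.catch_warnings():
    warnings.simplefilter('ignore', DeprecationWarning)
    return eda_engine.EDAEngine(meridian=mmm)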
@@ -538,7 +613,7 @@ class EDAEngine:
|
|
|
538
613
|
national_da.name = constants.NATIONAL_CONTROLS_SCALED
|
|
539
614
|
else:
|
|
540
615
|
national_da = self._aggregate_and_scale_geo_da(
|
|
541
|
-
self.
|
|
616
|
+
self._input_data.controls,
|
|
542
617
|
constants.NATIONAL_CONTROLS_SCALED,
|
|
543
618
|
transformers.CenteringAndScalingTransformer,
|
|
544
619
|
constants.CONTROL_VARIABLE,
|
|
@@ -548,43 +623,43 @@ class EDAEngine:
|
|
|
548
623
|
|
|
549
624
|
@functools.cached_property
|
|
550
625
|
def media_raw_da(self) -> xr.DataArray | None:
|
|
551
|
-
"""
|
|
552
|
-
if self.
|
|
626
|
+
"""The raw media data array."""
|
|
627
|
+
if self._input_data.media is None:
|
|
553
628
|
return None
|
|
554
|
-
raw_media_da = self._truncate_media_time(self.
|
|
629
|
+
raw_media_da = self._truncate_media_time(self._input_data.media)
|
|
555
630
|
raw_media_da.name = constants.MEDIA
|
|
556
631
|
return raw_media_da
|
|
557
632
|
|
|
558
633
|
@functools.cached_property
|
|
559
634
|
def media_scaled_da(self) -> xr.DataArray | None:
|
|
560
|
-
"""
|
|
561
|
-
if self.
|
|
635
|
+
"""The scaled media data array."""
|
|
636
|
+
if self._input_data.media is None:
|
|
562
637
|
return None
|
|
563
638
|
media_scaled_da = _data_array_like(
|
|
564
|
-
da=self.
|
|
565
|
-
values=self.
|
|
639
|
+
da=self._input_data.media,
|
|
640
|
+
values=self._model_context.media_tensors.media_scaled,
|
|
566
641
|
)
|
|
567
642
|
media_scaled_da.name = constants.MEDIA_SCALED
|
|
568
643
|
return self._truncate_media_time(media_scaled_da)
|
|
569
644
|
|
|
570
645
|
@functools.cached_property
|
|
571
646
|
def media_spend_da(self) -> xr.DataArray | None:
|
|
572
|
-
"""
|
|
647
|
+
"""The media spend data.
|
|
573
648
|
|
|
574
649
|
If the input spend is aggregated, it is allocated across geo and time
|
|
575
650
|
proportionally to media units.
|
|
576
651
|
"""
|
|
577
652
|
# No need to truncate the media time for media spend.
|
|
578
|
-
|
|
579
|
-
if
|
|
653
|
+
allocated_media_spend = self._input_data.allocated_media_spend
|
|
654
|
+
if allocated_media_spend is None:
|
|
580
655
|
return None
|
|
581
|
-
da =
|
|
656
|
+
da = allocated_media_spend.copy()
|
|
582
657
|
da.name = constants.MEDIA_SPEND
|
|
583
658
|
return da
|
|
584
659
|
|
|
585
660
|
@functools.cached_property
|
|
586
661
|
def national_media_spend_da(self) -> xr.DataArray | None:
|
|
587
|
-
"""
|
|
662
|
+
"""The national media spend data array."""
|
|
588
663
|
media_spend = self.media_spend_da
|
|
589
664
|
if media_spend is None:
|
|
590
665
|
return None
|
|
@@ -593,7 +668,7 @@ class EDAEngine:
|
|
|
593
668
|
national_da.name = constants.NATIONAL_MEDIA_SPEND
|
|
594
669
|
else:
|
|
595
670
|
national_da = self._aggregate_and_scale_geo_da(
|
|
596
|
-
self.
|
|
671
|
+
self._input_data.allocated_media_spend,
|
|
597
672
|
constants.NATIONAL_MEDIA_SPEND,
|
|
598
673
|
None,
|
|
599
674
|
)
|
|
@@ -601,7 +676,7 @@ class EDAEngine:
|
|
|
601
676
|
|
|
602
677
|
@functools.cached_property
|
|
603
678
|
def national_media_raw_da(self) -> xr.DataArray | None:
|
|
604
|
-
"""
|
|
679
|
+
"""The national raw media data array."""
|
|
605
680
|
if self.media_raw_da is None:
|
|
606
681
|
return None
|
|
607
682
|
if self._is_national_data:
|
|
@@ -618,7 +693,7 @@ class EDAEngine:
|
|
|
618
693
|
|
|
619
694
|
@functools.cached_property
|
|
620
695
|
def national_media_scaled_da(self) -> xr.DataArray | None:
|
|
621
|
-
"""
|
|
696
|
+
"""The national scaled media data array."""
|
|
622
697
|
if self.media_scaled_da is None:
|
|
623
698
|
return None
|
|
624
699
|
if self._is_national_data:
|
|
@@ -635,30 +710,30 @@ class EDAEngine:
|
|
|
635
710
|
|
|
636
711
|
@functools.cached_property
|
|
637
712
|
def organic_media_raw_da(self) -> xr.DataArray | None:
|
|
638
|
-
"""
|
|
639
|
-
if self.
|
|
713
|
+
"""The raw organic media data array."""
|
|
714
|
+
if self._input_data.organic_media is None:
|
|
640
715
|
return None
|
|
641
716
|
raw_organic_media_da = self._truncate_media_time(
|
|
642
|
-
self.
|
|
717
|
+
self._input_data.organic_media
|
|
643
718
|
)
|
|
644
719
|
raw_organic_media_da.name = constants.ORGANIC_MEDIA
|
|
645
720
|
return raw_organic_media_da
|
|
646
721
|
|
|
647
722
|
@functools.cached_property
|
|
648
723
|
def organic_media_scaled_da(self) -> xr.DataArray | None:
|
|
649
|
-
"""
|
|
650
|
-
if self.
|
|
724
|
+
"""The scaled organic media data array."""
|
|
725
|
+
if self._input_data.organic_media is None:
|
|
651
726
|
return None
|
|
652
727
|
organic_media_scaled_da = _data_array_like(
|
|
653
|
-
da=self.
|
|
654
|
-
values=self.
|
|
728
|
+
da=self._input_data.organic_media,
|
|
729
|
+
values=self._model_context.organic_media_tensors.organic_media_scaled,
|
|
655
730
|
)
|
|
656
731
|
organic_media_scaled_da.name = constants.ORGANIC_MEDIA_SCALED
|
|
657
732
|
return self._truncate_media_time(organic_media_scaled_da)
|
|
658
733
|
|
|
659
734
|
@functools.cached_property
|
|
660
735
|
def national_organic_media_raw_da(self) -> xr.DataArray | None:
|
|
661
|
-
"""
|
|
736
|
+
"""The national raw organic media data array."""
|
|
662
737
|
if self.organic_media_raw_da is None:
|
|
663
738
|
return None
|
|
664
739
|
if self._is_national_data:
|
|
@@ -673,7 +748,7 @@ class EDAEngine:
|
|
|
673
748
|
|
|
674
749
|
@functools.cached_property
|
|
675
750
|
def national_organic_media_scaled_da(self) -> xr.DataArray | None:
|
|
676
|
-
"""
|
|
751
|
+
"""The national scaled organic media data array."""
|
|
677
752
|
if self.organic_media_scaled_da is None:
|
|
678
753
|
return None
|
|
679
754
|
if self._is_national_data:
|
|
@@ -692,20 +767,20 @@ class EDAEngine:
|
|
|
692
767
|
|
|
693
768
|
@functools.cached_property
|
|
694
769
|
def non_media_scaled_da(self) -> xr.DataArray | None:
|
|
695
|
-
"""
|
|
696
|
-
if self.
|
|
770
|
+
"""The scaled non-media treatments data array."""
|
|
771
|
+
if self._input_data.non_media_treatments is None:
|
|
697
772
|
return None
|
|
698
773
|
non_media_scaled_da = _data_array_like(
|
|
699
|
-
da=self.
|
|
700
|
-
values=self.
|
|
774
|
+
da=self._input_data.non_media_treatments,
|
|
775
|
+
values=self._model_context.non_media_treatments_normalized,
|
|
701
776
|
)
|
|
702
777
|
non_media_scaled_da.name = constants.NON_MEDIA_TREATMENTS_SCALED
|
|
703
778
|
return non_media_scaled_da
|
|
704
779
|
|
|
705
780
|
@functools.cached_property
|
|
706
781
|
def national_non_media_scaled_da(self) -> xr.DataArray | None:
|
|
707
|
-
"""
|
|
708
|
-
if self.
|
|
782
|
+
"""The national scaled non-media treatment data array."""
|
|
783
|
+
if self._input_data.non_media_treatments is None:
|
|
709
784
|
return None
|
|
710
785
|
if self._is_national_data:
|
|
711
786
|
if self.non_media_scaled_da is None:
|
|
@@ -717,7 +792,7 @@ class EDAEngine:
|
|
|
717
792
|
national_da.name = constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED
|
|
718
793
|
else:
|
|
719
794
|
national_da = self._aggregate_and_scale_geo_da(
|
|
720
|
-
self.
|
|
795
|
+
self._input_data.non_media_treatments,
|
|
721
796
|
constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
|
|
722
797
|
transformers.CenteringAndScalingTransformer,
|
|
723
798
|
constants.NON_MEDIA_CHANNEL,
|
|
@@ -727,12 +802,12 @@ class EDAEngine:
|
|
|
727
802
|
|
|
728
803
|
@functools.cached_property
|
|
729
804
|
def rf_spend_da(self) -> xr.DataArray | None:
|
|
730
|
-
"""
|
|
805
|
+
"""The RF spend data.
|
|
731
806
|
|
|
732
807
|
If the input spend is aggregated, it is allocated across geo and time
|
|
733
808
|
proportionally to RF impressions (reach * frequency).
|
|
734
809
|
"""
|
|
735
|
-
da = self.
|
|
810
|
+
da = self._input_data.allocated_rf_spend
|
|
736
811
|
if da is None:
|
|
737
812
|
return None
|
|
738
813
|
da = da.copy()
|
|
@@ -741,7 +816,7 @@ class EDAEngine:
|
|
|
741
816
|
|
|
742
817
|
@functools.cached_property
|
|
743
818
|
def national_rf_spend_da(self) -> xr.DataArray | None:
|
|
744
|
-
"""
|
|
819
|
+
"""The national RF spend data array."""
|
|
745
820
|
rf_spend = self.rf_spend_da
|
|
746
821
|
if rf_spend is None:
|
|
747
822
|
return None
|
|
@@ -750,7 +825,7 @@ class EDAEngine:
|
|
|
750
825
|
national_da.name = constants.NATIONAL_RF_SPEND
|
|
751
826
|
else:
|
|
752
827
|
national_da = self._aggregate_and_scale_geo_da(
|
|
753
|
-
self.
|
|
828
|
+
self._input_data.allocated_rf_spend,
|
|
754
829
|
constants.NATIONAL_RF_SPEND,
|
|
755
830
|
None,
|
|
756
831
|
)
|
|
@@ -758,182 +833,182 @@ class EDAEngine:
|
|
|
758
833
|
|
|
759
834
|
@functools.cached_property
|
|
760
835
|
def _rf_data(self) -> ReachFrequencyData | None:
|
|
761
|
-
if self.
|
|
836
|
+
if self._input_data.reach is None:
|
|
762
837
|
return None
|
|
763
838
|
return self._get_rf_data(
|
|
764
|
-
self.
|
|
765
|
-
self.
|
|
839
|
+
self._input_data.reach,
|
|
840
|
+
self._input_data.frequency,
|
|
766
841
|
is_organic=False,
|
|
767
842
|
)
|
|
768
843
|
|
|
769
844
|
@property
|
|
770
845
|
def reach_raw_da(self) -> xr.DataArray | None:
|
|
771
|
-
"""
|
|
846
|
+
"""The raw reach data array."""
|
|
772
847
|
if self._rf_data is None:
|
|
773
848
|
return None
|
|
774
|
-
return self._rf_data.reach_raw_da
|
|
849
|
+
return self._rf_data.reach_raw_da # pytype: disable=attribute-error
|
|
775
850
|
|
|
776
851
|
@property
|
|
777
852
|
def reach_scaled_da(self) -> xr.DataArray | None:
|
|
778
|
-
"""
|
|
853
|
+
"""The scaled reach data array."""
|
|
779
854
|
if self._rf_data is None:
|
|
780
855
|
return None
|
|
781
856
|
return self._rf_data.reach_scaled_da # pytype: disable=attribute-error
|
|
782
857
|
|
|
783
858
|
@property
|
|
784
859
|
def national_reach_raw_da(self) -> xr.DataArray | None:
|
|
785
|
-
"""
|
|
860
|
+
"""The national raw reach data array."""
|
|
786
861
|
if self._rf_data is None:
|
|
787
862
|
return None
|
|
788
863
|
return self._rf_data.national_reach_raw_da
|
|
789
864
|
|
|
790
865
|
@property
|
|
791
866
|
def national_reach_scaled_da(self) -> xr.DataArray | None:
|
|
792
|
-
"""
|
|
867
|
+
"""The national scaled reach data array."""
|
|
793
868
|
if self._rf_data is None:
|
|
794
869
|
return None
|
|
795
870
|
return self._rf_data.national_reach_scaled_da # pytype: disable=attribute-error
|
|
796
871
|
|
|
797
872
|
@property
|
|
798
873
|
def frequency_da(self) -> xr.DataArray | None:
|
|
799
|
-
"""
|
|
874
|
+
"""The frequency data array."""
|
|
800
875
|
if self._rf_data is None:
|
|
801
876
|
return None
|
|
802
877
|
return self._rf_data.frequency_da # pytype: disable=attribute-error
|
|
803
878
|
|
|
804
879
|
@property
|
|
805
880
|
def national_frequency_da(self) -> xr.DataArray | None:
|
|
806
|
-
"""
|
|
881
|
+
"""The national frequency data array."""
|
|
807
882
|
if self._rf_data is None:
|
|
808
883
|
return None
|
|
809
884
|
return self._rf_data.national_frequency_da # pytype: disable=attribute-error
|
|
810
885
|
|
|
811
886
|
@property
|
|
812
887
|
def rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
813
|
-
"""
|
|
888
|
+
"""The raw RF impressions data array."""
|
|
814
889
|
if self._rf_data is None:
|
|
815
890
|
return None
|
|
816
891
|
return self._rf_data.rf_impressions_raw_da # pytype: disable=attribute-error
|
|
817
892
|
|
|
818
893
|
@property
|
|
819
894
|
def national_rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
820
|
-
"""
|
|
895
|
+
"""The national raw RF impressions data array."""
|
|
821
896
|
if self._rf_data is None:
|
|
822
897
|
return None
|
|
823
898
|
return self._rf_data.national_rf_impressions_raw_da # pytype: disable=attribute-error
|
|
824
899
|
|
|
825
900
|
@property
|
|
826
901
|
def rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
827
|
-
"""
|
|
902
|
+
"""The scaled RF impressions data array."""
|
|
828
903
|
if self._rf_data is None:
|
|
829
904
|
return None
|
|
830
905
|
return self._rf_data.rf_impressions_scaled_da
|
|
831
906
|
|
|
832
907
|
@property
|
|
833
908
|
def national_rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
834
|
-
"""
|
|
909
|
+
"""The national scaled RF impressions data array."""
|
|
835
910
|
if self._rf_data is None:
|
|
836
911
|
return None
|
|
837
912
|
return self._rf_data.national_rf_impressions_scaled_da
|
|
838
913
|
|
|
839
914
|
@functools.cached_property
|
|
840
915
|
def _organic_rf_data(self) -> ReachFrequencyData | None:
|
|
841
|
-
if self.
|
|
916
|
+
if self._input_data.organic_reach is None:
|
|
842
917
|
return None
|
|
843
918
|
return self._get_rf_data(
|
|
844
|
-
self.
|
|
845
|
-
self.
|
|
919
|
+
self._input_data.organic_reach,
|
|
920
|
+
self._input_data.organic_frequency,
|
|
846
921
|
is_organic=True,
|
|
847
922
|
)
|
|
848
923
|
|
|
849
924
|
@property
|
|
850
925
|
def organic_reach_raw_da(self) -> xr.DataArray | None:
|
|
851
|
-
"""
|
|
926
|
+
"""The raw organic reach data array."""
|
|
852
927
|
if self._organic_rf_data is None:
|
|
853
928
|
return None
|
|
854
|
-
return self._organic_rf_data.reach_raw_da
|
|
929
|
+
return self._organic_rf_data.reach_raw_da # pytype: disable=attribute-error
|
|
855
930
|
|
|
856
931
|
@property
|
|
857
932
|
def organic_reach_scaled_da(self) -> xr.DataArray | None:
|
|
858
|
-
"""
|
|
933
|
+
"""The scaled organic reach data array."""
|
|
859
934
|
if self._organic_rf_data is None:
|
|
860
935
|
return None
|
|
861
936
|
return self._organic_rf_data.reach_scaled_da # pytype: disable=attribute-error
|
|
862
937
|
|
|
863
938
|
@property
|
|
864
939
|
def national_organic_reach_raw_da(self) -> xr.DataArray | None:
|
|
865
|
-
"""
|
|
940
|
+
"""The national raw organic reach data array."""
|
|
866
941
|
if self._organic_rf_data is None:
|
|
867
942
|
return None
|
|
868
943
|
return self._organic_rf_data.national_reach_raw_da
|
|
869
944
|
|
|
870
945
|
@property
|
|
871
946
|
def national_organic_reach_scaled_da(self) -> xr.DataArray | None:
|
|
872
|
-
"""
|
|
947
|
+
"""The national scaled organic reach data array."""
|
|
873
948
|
if self._organic_rf_data is None:
|
|
874
949
|
return None
|
|
875
950
|
return self._organic_rf_data.national_reach_scaled_da # pytype: disable=attribute-error
|
|
876
951
|
|
|
877
952
|
@property
|
|
878
953
|
def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
879
|
-
"""
|
|
954
|
+
"""The scaled organic RF impressions data array."""
|
|
880
955
|
if self._organic_rf_data is None:
|
|
881
956
|
return None
|
|
882
957
|
return self._organic_rf_data.rf_impressions_scaled_da
|
|
883
958
|
|
|
884
959
|
@property
|
|
885
960
|
def national_organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
886
|
-
"""
|
|
961
|
+
"""The national scaled organic RF impressions data array."""
|
|
887
962
|
if self._organic_rf_data is None:
|
|
888
963
|
return None
|
|
889
964
|
return self._organic_rf_data.national_rf_impressions_scaled_da
|
|
890
965
|
|
|
891
966
|
@property
|
|
892
967
|
def organic_frequency_da(self) -> xr.DataArray | None:
|
|
893
|
-
"""
|
|
968
|
+
"""The organic frequency data array."""
|
|
894
969
|
if self._organic_rf_data is None:
|
|
895
970
|
return None
|
|
896
971
|
return self._organic_rf_data.frequency_da # pytype: disable=attribute-error
|
|
897
972
|
|
|
898
973
|
@property
|
|
899
974
|
def national_organic_frequency_da(self) -> xr.DataArray | None:
|
|
900
|
-
"""
|
|
975
|
+
"""The national organic frequency data array."""
|
|
901
976
|
if self._organic_rf_data is None:
|
|
902
977
|
return None
|
|
903
978
|
return self._organic_rf_data.national_frequency_da # pytype: disable=attribute-error
|
|
904
979
|
|
|
905
980
|
@property
|
|
906
981
|
def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
907
|
-
"""
|
|
982
|
+
"""The raw organic RF impressions data array."""
|
|
908
983
|
if self._organic_rf_data is None:
|
|
909
984
|
return None
|
|
910
985
|
return self._organic_rf_data.rf_impressions_raw_da
|
|
911
986
|
|
|
912
987
|
@property
|
|
913
988
|
def national_organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
914
|
-
"""
|
|
989
|
+
"""The national raw organic RF impressions data array."""
|
|
915
990
|
if self._organic_rf_data is None:
|
|
916
991
|
return None
|
|
917
992
|
return self._organic_rf_data.national_rf_impressions_raw_da
|
|
918
993
|
|
|
919
994
|
@functools.cached_property
|
|
920
995
|
def geo_population_da(self) -> xr.DataArray | None:
|
|
921
|
-
"""
|
|
996
|
+
"""The geo population data array."""
|
|
922
997
|
if self._is_national_data:
|
|
923
998
|
return None
|
|
924
999
|
return xr.DataArray(
|
|
925
|
-
self.
|
|
926
|
-
coords={constants.GEO: self.
|
|
1000
|
+
self._model_context.population,
|
|
1001
|
+
coords={constants.GEO: self._input_data.geo.values},
|
|
927
1002
|
dims=[constants.GEO],
|
|
928
1003
|
name=constants.POPULATION,
|
|
929
1004
|
)
|
|
930
1005
|
|
|
931
1006
|
@functools.cached_property
|
|
932
1007
|
def kpi_scaled_da(self) -> xr.DataArray:
|
|
933
|
-
"""
|
|
1008
|
+
"""The scaled KPI data array."""
|
|
934
1009
|
scaled_kpi_da = _data_array_like(
|
|
935
|
-
da=self.
|
|
936
|
-
values=self.
|
|
1010
|
+
da=self._input_data.kpi,
|
|
1011
|
+
values=self._model_context.kpi_scaled,
|
|
937
1012
|
)
|
|
938
1013
|
scaled_kpi_da.name = constants.KPI_SCALED
|
|
939
1014
|
return scaled_kpi_da
|
|
@@ -942,7 +1017,7 @@ class EDAEngine:
|
|
|
942
1017
|
def _overall_scaled_kpi_invariability_artifact(
|
|
943
1018
|
self,
|
|
944
1019
|
) -> eda_outcome.KpiInvariabilityArtifact:
|
|
945
|
-
"""
|
|
1020
|
+
"""An artifact of overall scaled KPI invariability."""
|
|
946
1021
|
return eda_outcome.KpiInvariabilityArtifact(
|
|
947
1022
|
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
948
1023
|
kpi_da=self.kpi_scaled_da,
|
|
@@ -951,14 +1026,14 @@ class EDAEngine:
|
|
|
951
1026
|
|
|
952
1027
|
@functools.cached_property
|
|
953
1028
|
def national_kpi_scaled_da(self) -> xr.DataArray:
|
|
954
|
-
"""
|
|
1029
|
+
"""The national scaled KPI data array."""
|
|
955
1030
|
if self._is_national_data:
|
|
956
1031
|
national_da = self.kpi_scaled_da.squeeze(constants.GEO, drop=True)
|
|
957
1032
|
national_da.name = constants.NATIONAL_KPI_SCALED
|
|
958
1033
|
else:
|
|
959
1034
|
# Note that kpi is summable by assumption.
|
|
960
1035
|
national_da = self._aggregate_and_scale_geo_da(
|
|
961
|
-
self.
|
|
1036
|
+
self._input_data.kpi,
|
|
962
1037
|
constants.NATIONAL_KPI_SCALED,
|
|
963
1038
|
transformers.CenteringAndScalingTransformer,
|
|
964
1039
|
)
|
|
@@ -966,7 +1041,7 @@ class EDAEngine:
|
|
|
966
1041
|
|
|
967
1042
|
@functools.cached_property
|
|
968
1043
|
def treatment_control_scaled_ds(self) -> xr.Dataset:
|
|
969
|
-
"""
|
|
1044
|
+
"""A Dataset containing all scaled treatments and controls.
|
|
970
1045
|
|
|
971
1046
|
This includes media, RF impressions, organic media, organic RF impressions,
|
|
972
1047
|
non-media treatments, and control variables, all at the geo level.
|
|
@@ -987,7 +1062,7 @@ class EDAEngine:
|
|
|
987
1062
|
|
|
988
1063
|
@functools.cached_property
|
|
989
1064
|
def all_spend_ds(self) -> xr.Dataset:
|
|
990
|
-
"""
|
|
1065
|
+
"""A Dataset containing all spend data.
|
|
991
1066
|
|
|
992
1067
|
This includes media spend and rf spend.
|
|
993
1068
|
"""
|
|
@@ -1003,7 +1078,7 @@ class EDAEngine:
|
|
|
1003
1078
|
|
|
1004
1079
|
@functools.cached_property
|
|
1005
1080
|
def national_all_spend_ds(self) -> xr.Dataset:
|
|
1006
|
-
"""
|
|
1081
|
+
"""A Dataset containing all national spend data.
|
|
1007
1082
|
|
|
1008
1083
|
This includes media spend and rf spend.
|
|
1009
1084
|
"""
|
|
@@ -1019,14 +1094,14 @@ class EDAEngine:
|
|
|
1019
1094
|
|
|
1020
1095
|
@functools.cached_property
|
|
1021
1096
|
def _stacked_treatment_control_scaled_da(self) -> xr.DataArray:
|
|
1022
|
-
"""
|
|
1097
|
+
"""A stacked DataArray of treatment_control_scaled_ds."""
|
|
1023
1098
|
da = stack_variables(self.treatment_control_scaled_ds)
|
|
1024
1099
|
da.name = constants.TREATMENT_CONTROL_SCALED
|
|
1025
1100
|
return da
|
|
1026
1101
|
|
|
1027
1102
|
@functools.cached_property
|
|
1028
1103
|
def national_treatment_control_scaled_ds(self) -> xr.Dataset:
|
|
1029
|
-
"""
|
|
1104
|
+
"""A Dataset containing all scaled treatments and controls.
|
|
1030
1105
|
|
|
1031
1106
|
This includes media, RF impressions, organic media, organic RF impressions,
|
|
1032
1107
|
non-media treatments, and control variables, all at the national level.
|
|
@@ -1047,14 +1122,14 @@ class EDAEngine:
|
|
|
1047
1122
|
|
|
1048
1123
|
@functools.cached_property
|
|
1049
1124
|
def _stacked_national_treatment_control_scaled_da(self) -> xr.DataArray:
|
|
1050
|
-
"""
|
|
1125
|
+
"""A stacked DataArray of national_treatment_control_scaled_ds."""
|
|
1051
1126
|
da = stack_variables(self.national_treatment_control_scaled_ds)
|
|
1052
1127
|
da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
|
|
1053
1128
|
return da
|
|
1054
1129
|
|
|
1055
1130
|
@functools.cached_property
|
|
1056
1131
|
def treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
|
|
1057
|
-
"""
|
|
1132
|
+
"""A Dataset of scaled treatments excluding non-media."""
|
|
1058
1133
|
return self.treatment_control_scaled_ds.drop_dims(
|
|
1059
1134
|
[constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
|
|
1060
1135
|
errors='ignore',
|
|
@@ -1062,15 +1137,34 @@ class EDAEngine:
|
|
|
1062
1137
|
|
|
1063
1138
|
@functools.cached_property
|
|
1064
1139
|
def national_treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
|
|
1065
|
-
"""
|
|
1140
|
+
"""A Dataset of national scaled treatments excluding non-media."""
|
|
1066
1141
|
return self.national_treatment_control_scaled_ds.drop_dims(
|
|
1067
1142
|
[constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
|
|
1068
1143
|
errors='ignore',
|
|
1069
1144
|
)
|
|
1070
1145
|
|
|
1146
|
+
@functools.cached_property
|
|
1147
|
+
def controls_and_non_media_scaled_ds(self) -> xr.Dataset | None:
|
|
1148
|
+
"""A Dataset of scaled controls and non-media treatments."""
|
|
1149
|
+
return _get_vars_from_dataset(
|
|
1150
|
+
self.treatment_control_scaled_ds,
|
|
1151
|
+
[constants.CONTROLS_SCALED, constants.NON_MEDIA_TREATMENTS_SCALED],
|
|
1152
|
+
)
|
|
1153
|
+
|
|
1154
|
+
@functools.cached_property
|
|
1155
|
+
def national_controls_and_non_media_scaled_ds(self) -> xr.Dataset | None:
|
|
1156
|
+
"""A Dataset of national scaled controls and non-media treatments."""
|
|
1157
|
+
return _get_vars_from_dataset(
|
|
1158
|
+
self.national_treatment_control_scaled_ds,
|
|
1159
|
+
[
|
|
1160
|
+
constants.NATIONAL_CONTROLS_SCALED,
|
|
1161
|
+
constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
|
|
1162
|
+
],
|
|
1163
|
+
)
|
|
1164
|
+
|
|
1071
1165
|
@functools.cached_property
|
|
1072
1166
|
def all_reach_scaled_da(self) -> xr.DataArray | None:
|
|
1073
|
-
"""
|
|
1167
|
+
"""A DataArray containing all scaled reach data.
|
|
1074
1168
|
|
|
1075
1169
|
This includes both paid and organic reach, concatenated along the RF_CHANNEL
|
|
1076
1170
|
dimension.
|
|
@@ -1096,7 +1190,7 @@ class EDAEngine:
|
|
|
1096
1190
|
|
|
1097
1191
|
@functools.cached_property
|
|
1098
1192
|
def all_freq_da(self) -> xr.DataArray | None:
|
|
1099
|
-
"""
|
|
1193
|
+
"""A DataArray containing all frequency data.
|
|
1100
1194
|
|
|
1101
1195
|
This includes both paid and organic frequency, concatenated along the
|
|
1102
1196
|
RF_CHANNEL dimension.
|
|
@@ -1122,7 +1216,7 @@ class EDAEngine:
|
|
|
1122
1216
|
|
|
1123
1217
|
@functools.cached_property
|
|
1124
1218
|
def national_all_reach_scaled_da(self) -> xr.DataArray | None:
|
|
1125
|
-
"""
|
|
1219
|
+
"""A DataArray containing all national-level scaled reach data.
|
|
1126
1220
|
|
|
1127
1221
|
This includes both paid and organic reach, concatenated along the
|
|
1128
1222
|
RF_CHANNEL dimension.
|
|
@@ -1149,7 +1243,7 @@ class EDAEngine:
|
|
|
1149
1243
|
|
|
1150
1244
|
@functools.cached_property
|
|
1151
1245
|
def national_all_freq_da(self) -> xr.DataArray | None:
|
|
1152
|
-
"""
|
|
1246
|
+
"""A DataArray containing all national-level frequency data.
|
|
1153
1247
|
|
|
1154
1248
|
This includes both paid and organic frequency, concatenated along the
|
|
1155
1249
|
RF_CHANNEL dimension.
|
|
@@ -1202,7 +1296,7 @@ class EDAEngine:
|
|
|
1202
1296
|
def _critical_checks(
|
|
1203
1297
|
self,
|
|
1204
1298
|
) -> list[tuple[_NamedEDACheckCallable, eda_outcome.EDACheckType]]:
|
|
1205
|
-
"""
|
|
1299
|
+
"""A list of critical checks to be performed."""
|
|
1206
1300
|
checks = [
|
|
1207
1301
|
(
|
|
1208
1302
|
self.check_overall_kpi_invariability,
|
|
@@ -1221,19 +1315,21 @@ class EDAEngine:
|
|
|
1221
1315
|
# This should not happen. If it does, it means this function is mis-used.
|
|
1222
1316
|
if constants.MEDIA_TIME not in da.coords:
|
|
1223
1317
|
raise ValueError(
|
|
1224
|
-
f'Variable does not have a media time coordinate: {da.name}.'
|
|
1318
|
+
f'Variable does not have a media time coordinate: {da.name!r}.'
|
|
1225
1319
|
)
|
|
1226
1320
|
|
|
1227
|
-
start = self.
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1321
|
+
start = self._model_context.n_media_times - self._model_context.n_times
|
|
1322
|
+
return (
|
|
1323
|
+
da.copy()
|
|
1324
|
+
.isel({constants.MEDIA_TIME: slice(start, None)})
|
|
1325
|
+
.rename({constants.MEDIA_TIME: constants.TIME})
|
|
1326
|
+
)
|
|
1231
1327
|
|
|
1232
1328
|
def _scale_xarray(
|
|
1233
1329
|
self,
|
|
1234
1330
|
xarray: xr.DataArray,
|
|
1235
|
-
transformer_class:
|
|
1236
|
-
population:
|
|
1331
|
+
transformer_class: type[transformers.TensorTransformer] | None,
|
|
1332
|
+
population: backend.Tensor | None = None,
|
|
1237
1333
|
) -> xr.DataArray:
|
|
1238
1334
|
"""Scales xarray values with a TensorTransformer."""
|
|
1239
1335
|
da = xarray.copy()
|
|
@@ -1285,7 +1381,9 @@ class EDAEngine:
|
|
|
1285
1381
|
agg_results = []
|
|
1286
1382
|
for var_name in geo_da[channel_dim].values:
|
|
1287
1383
|
var_data = geo_da.sel({channel_dim: var_name})
|
|
1288
|
-
agg_func = da_var_agg_map.get(
|
|
1384
|
+
agg_func = da_var_agg_map.get(
|
|
1385
|
+
var_name, eda_constants.DEFAULT_DA_VAR_AGG_FUNCTION
|
|
1386
|
+
)
|
|
1289
1387
|
# Apply the aggregation function over the GEO dimension
|
|
1290
1388
|
aggregated_data = var_data.reduce(
|
|
1291
1389
|
agg_func, dim=constants.GEO, keepdims=keepdims
|
|
@@ -1299,9 +1397,9 @@ class EDAEngine:
|
|
|
1299
1397
|
self,
|
|
1300
1398
|
geo_da: xr.DataArray,
|
|
1301
1399
|
national_da_name: str,
|
|
1302
|
-
transformer_class:
|
|
1303
|
-
channel_dim:
|
|
1304
|
-
da_var_agg_map:
|
|
1400
|
+
transformer_class: type[transformers.TensorTransformer] | None,
|
|
1401
|
+
channel_dim: str | None = None,
|
|
1402
|
+
da_var_agg_map: eda_spec.AggregationMap | None = None,
|
|
1305
1403
|
) -> xr.DataArray:
|
|
1306
1404
|
"""Aggregate geo-level xr.DataArray to national level and then scale values.
|
|
1307
1405
|
|
|
@@ -1351,11 +1449,11 @@ class EDAEngine:
|
|
|
1351
1449
|
"""Get impressions and frequencies data arrays for RF channels."""
|
|
1352
1450
|
if is_organic:
|
|
1353
1451
|
scaled_reach_values = (
|
|
1354
|
-
self.
|
|
1452
|
+
self._model_context.organic_rf_tensors.organic_reach_scaled
|
|
1355
1453
|
)
|
|
1356
1454
|
names = _ORGANIC_RF_NAMES
|
|
1357
1455
|
else:
|
|
1358
|
-
scaled_reach_values = self.
|
|
1456
|
+
scaled_reach_values = self._model_context.rf_tensors.reach_scaled
|
|
1359
1457
|
names = _RF_NAMES
|
|
1360
1458
|
|
|
1361
1459
|
reach_scaled_da = _data_array_like(
|
|
@@ -1433,7 +1531,7 @@ class EDAEngine:
|
|
|
1433
1531
|
impressions_scaled_da = self._scale_xarray(
|
|
1434
1532
|
impressions_raw_da,
|
|
1435
1533
|
transformers.MediaTransformer,
|
|
1436
|
-
population=self.
|
|
1534
|
+
population=self._model_context.population,
|
|
1437
1535
|
)
|
|
1438
1536
|
impressions_scaled_da.name = names.impressions_scaled
|
|
1439
1537
|
|
|
@@ -1491,9 +1589,16 @@ class EDAEngine:
|
|
|
1491
1589
|
overall_corr_mat, overall_extreme_corr_var_pairs_df = (
|
|
1492
1590
|
self._pairwise_corr_for_geo_data(
|
|
1493
1591
|
dims=[constants.GEO, constants.TIME],
|
|
1494
|
-
extreme_corr_threshold=
|
|
1592
|
+
extreme_corr_threshold=eda_constants.OVERALL_PAIRWISE_CORR_THRESHOLD,
|
|
1495
1593
|
)
|
|
1496
1594
|
)
|
|
1595
|
+
overall_artifact = eda_outcome.PairwiseCorrArtifact(
|
|
1596
|
+
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
1597
|
+
corr_matrix=overall_corr_mat,
|
|
1598
|
+
extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
|
|
1599
|
+
extreme_corr_threshold=eda_constants.OVERALL_PAIRWISE_CORR_THRESHOLD,
|
|
1600
|
+
)
|
|
1601
|
+
|
|
1497
1602
|
if not overall_extreme_corr_var_pairs_df.empty:
|
|
1498
1603
|
var_pairs = overall_extreme_corr_var_pairs_df.index.to_list()
|
|
1499
1604
|
findings.append(
|
|
@@ -1505,21 +1610,33 @@ class EDAEngine:
|
|
|
1505
1610
|
' variables, please remove one of the variables from the'
|
|
1506
1611
|
f' model.\nPairs with perfect correlation: {var_pairs}'
|
|
1507
1612
|
),
|
|
1613
|
+
finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
|
|
1614
|
+
associated_artifact=overall_artifact,
|
|
1508
1615
|
)
|
|
1509
1616
|
)
|
|
1510
1617
|
|
|
1511
1618
|
geo_corr_mat, geo_extreme_corr_var_pairs_df = (
|
|
1512
1619
|
self._pairwise_corr_for_geo_data(
|
|
1513
1620
|
dims=constants.TIME,
|
|
1514
|
-
extreme_corr_threshold=
|
|
1621
|
+
extreme_corr_threshold=eda_constants.GEO_PAIRWISE_CORR_THRESHOLD,
|
|
1515
1622
|
)
|
|
1516
1623
|
)
|
|
1517
|
-
#
|
|
1518
|
-
#
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1624
|
+
# Pairs that cause overall level findings are very likely to cause geo
|
|
1625
|
+
# level findings as well, so we exclude them when determining geo-level
|
|
1626
|
+
# findings. This is to avoid over-reporting findings.
|
|
1627
|
+
overall_pairs_index = overall_extreme_corr_var_pairs_df.index
|
|
1628
|
+
is_in_overall = geo_extreme_corr_var_pairs_df.index.droplevel(
|
|
1629
|
+
constants.GEO
|
|
1630
|
+
).isin(overall_pairs_index)
|
|
1631
|
+
geo_df_for_attention = geo_extreme_corr_var_pairs_df[~is_in_overall]
|
|
1632
|
+
geo_artifact = eda_outcome.PairwiseCorrArtifact(
|
|
1633
|
+
level=eda_outcome.AnalysisLevel.GEO,
|
|
1634
|
+
corr_matrix=geo_corr_mat,
|
|
1635
|
+
extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
|
|
1636
|
+
extreme_corr_threshold=eda_constants.GEO_PAIRWISE_CORR_THRESHOLD,
|
|
1637
|
+
)
|
|
1638
|
+
|
|
1639
|
+
if not geo_df_for_attention.empty:
|
|
1523
1640
|
findings.append(
|
|
1524
1641
|
eda_outcome.EDAFinding(
|
|
1525
1642
|
severity=eda_outcome.EDASeverity.ATTENTION,
|
|
@@ -1529,6 +1646,8 @@ class EDAEngine:
|
|
|
1529
1646
|
' variables if they also have high pairwise correlations in'
|
|
1530
1647
|
' other geos.'
|
|
1531
1648
|
),
|
|
1649
|
+
finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
|
|
1650
|
+
associated_artifact=geo_artifact,
|
|
1532
1651
|
)
|
|
1533
1652
|
)
|
|
1534
1653
|
|
|
@@ -1538,34 +1657,15 @@ class EDAEngine:
|
|
|
1538
1657
|
findings.append(
|
|
1539
1658
|
eda_outcome.EDAFinding(
|
|
1540
1659
|
severity=eda_outcome.EDASeverity.INFO,
|
|
1541
|
-
explanation=(
|
|
1542
|
-
|
|
1543
|
-
' high pairwise correlation may cause model identifiability'
|
|
1544
|
-
' and convergence issues. Consider combining the variables if'
|
|
1545
|
-
' high correlation exists.'
|
|
1546
|
-
),
|
|
1660
|
+
explanation=(eda_constants.PAIRWISE_CORRELATION_CHECK_INFO),
|
|
1661
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
1547
1662
|
)
|
|
1548
1663
|
)
|
|
1549
1664
|
|
|
1550
|
-
pairwise_corr_artifacts = [
|
|
1551
|
-
eda_outcome.PairwiseCorrArtifact(
|
|
1552
|
-
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
1553
|
-
corr_matrix=overall_corr_mat,
|
|
1554
|
-
extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
|
|
1555
|
-
extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
|
|
1556
|
-
),
|
|
1557
|
-
eda_outcome.PairwiseCorrArtifact(
|
|
1558
|
-
level=eda_outcome.AnalysisLevel.GEO,
|
|
1559
|
-
corr_matrix=geo_corr_mat,
|
|
1560
|
-
extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
|
|
1561
|
-
extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
|
|
1562
|
-
),
|
|
1563
|
-
]
|
|
1564
|
-
|
|
1565
1665
|
return eda_outcome.EDAOutcome(
|
|
1566
1666
|
check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
|
|
1567
1667
|
findings=findings,
|
|
1568
|
-
analysis_artifacts=
|
|
1668
|
+
analysis_artifacts=[overall_artifact, geo_artifact],
|
|
1569
1669
|
)
|
|
1570
1670
|
|
|
1571
1671
|
def check_national_pairwise_corr(
|
|
@@ -1582,7 +1682,14 @@ class EDAEngine:
|
|
|
1582
1682
|
self._stacked_national_treatment_control_scaled_da, dims=constants.TIME
|
|
1583
1683
|
)
|
|
1584
1684
|
extreme_corr_var_pairs_df = _find_extreme_corr_pairs(
|
|
1585
|
-
corr_mat,
|
|
1685
|
+
corr_mat, eda_constants.NATIONAL_PAIRWISE_CORR_THRESHOLD
|
|
1686
|
+
)
|
|
1687
|
+
|
|
1688
|
+
artifact = eda_outcome.PairwiseCorrArtifact(
|
|
1689
|
+
level=eda_outcome.AnalysisLevel.NATIONAL,
|
|
1690
|
+
corr_matrix=corr_mat,
|
|
1691
|
+
extreme_corr_var_pairs=extreme_corr_var_pairs_df,
|
|
1692
|
+
extreme_corr_threshold=eda_constants.NATIONAL_PAIRWISE_CORR_THRESHOLD,
|
|
1586
1693
|
)
|
|
1587
1694
|
|
|
1588
1695
|
if not extreme_corr_var_pairs_df.empty:
|
|
@@ -1596,33 +1703,23 @@
                   ' variables, please remove one of the variables from the'
                   f' model.\nPairs with perfect correlation: {var_pairs}'
               ),
+              finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
+              associated_artifact=artifact,
          )
      )
     else:
       findings.append(
           eda_outcome.EDAFinding(
              severity=eda_outcome.EDASeverity.INFO,
-             explanation=(
-                 …
-                 ' high pairwise correlation may cause model identifiability'
-                 ' and convergence issues. Consider combining the variables if'
-                 ' high correlation exists.'
-             ),
+             explanation=(eda_constants.PAIRWISE_CORRELATION_CHECK_INFO),
+             finding_cause=eda_outcome.FindingCause.NONE,
          )
      )
 
-    pairwise_corr_artifacts = [
-        eda_outcome.PairwiseCorrArtifact(
-            level=eda_outcome.AnalysisLevel.NATIONAL,
-            corr_matrix=corr_mat,
-            extreme_corr_var_pairs=extreme_corr_var_pairs_df,
-            extreme_corr_threshold=_NATIONAL_PAIRWISE_CORR_THRESHOLD,
-        )
-    ]
-
     return eda_outcome.EDAOutcome(
         check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
         findings=findings,
-        analysis_artifacts=
+        analysis_artifacts=[artifact],
     )
 
   def check_pairwise_corr(
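Both the national and the geo/overall variants now pass an explicit threshold constant into `_find_extreme_corr_pairs` and keep the resulting `PairwiseCorrArtifact` for the findings. The helper's body is not part of this diff; the sketch below is an assumed pandas re-implementation of that kind of routine, shown only to illustrate what a thresholded pair extraction from a correlation matrix looks like:

```python
import numpy as np
import pandas as pd

def find_extreme_corr_pairs(corr: pd.DataFrame, threshold: float) -> pd.DataFrame:
  """Returns variable pairs whose absolute correlation meets the threshold."""
  mask = np.triu(np.ones(corr.shape, dtype=bool), k=1)  # strictly upper triangle
  pairs = corr.where(mask).stack()  # NaNs (diagonal, lower triangle) are dropped
  extreme = pairs[pairs.abs() >= threshold].rename('correlation')
  extreme.index = extreme.index.set_names(['var_1', 'var_2'])
  return extreme.reset_index()

# Example: find_extreme_corr_pairs(df[['tv', 'radio', 'search']].corr(), 0.9)
```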
@@ -1643,20 +1740,14 @@
       data: xr.DataArray,
       level: eda_outcome.AnalysisLevel,
       zero_std_message: str,
+      outlier_message: str,
   ) -> tuple[
-      eda_outcome.EDAFinding
+      list[eda_outcome.EDAFinding], eda_outcome.StandardDeviationArtifact
   ]:
     """Helper to check standard deviation."""
     std_ds = _calculate_std(data)
     outlier_df = _calculate_outliers(data)
 
-    finding = None
-    if (std_ds[_STD_WITHOUT_OUTLIERS_VAR_NAME] < _STD_THRESHOLD).any():
-      finding = eda_outcome.EDAFinding(
-          severity=eda_outcome.EDASeverity.ATTENTION,
-          explanation=zero_std_message,
-      )
-
     artifact = eda_outcome.StandardDeviationArtifact(
         variable=str(data.name),
         level=level,
@@ -1664,7 +1755,31 @@
         outlier_df=outlier_df,
     )
 
-    …
+    findings = []
+    if (
+        std_ds[eda_constants.STD_WITHOUT_OUTLIERS_VAR_NAME]
+        < eda_constants.STD_THRESHOLD
+    ).any():
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.ATTENTION,
+              explanation=zero_std_message,
+              finding_cause=eda_outcome.FindingCause.VARIABILITY,
+              associated_artifact=artifact,
+          )
+      )
+
+    if not outlier_df.empty:
+      findings.append(
+          eda_outcome.EDAFinding(
+              severity=eda_outcome.EDASeverity.ATTENTION,
+              explanation=outlier_message,
+              finding_cause=eda_outcome.FindingCause.OUTLIER,
+              associated_artifact=artifact,
+          )
+      )
+
+    return findings, artifact
 
   def check_geo_std(
       self,
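The refactor above changes `_check_std` from returning at most one finding to returning a list: one attention finding for near-zero variability and a second one for outliers, both pointing at the same `StandardDeviationArtifact`. A self-contained sketch of that shape using plain numpy; the 1.5 IQR outlier rule and the dictionary artifact are assumptions for illustration, since `_calculate_std` and `_calculate_outliers` are not shown in this diff:

```python
import numpy as np

def check_std(values: np.ndarray, zero_std_message: str, outlier_message: str,
              std_threshold: float = 1e-6):
  """Returns (findings, artifact); the outlier rule here is an assumption."""
  q1, q3 = np.percentile(values, [25, 75])
  iqr = q3 - q1
  outliers = values[(values < q1 - 1.5 * iqr) | (values > q3 + 1.5 * iqr)]
  artifact = {'std': float(values.std()), 'outliers': outliers}

  findings = []
  if artifact['std'] < std_threshold:
    findings.append(('ATTENTION', 'VARIABILITY', zero_std_message, artifact))
  if outliers.size:
    findings.append(('ATTENTION', 'OUTLIER', outlier_message, artifact))
  return findings, artifact
```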
@@ -1685,6 +1800,10 @@
                 ' variable for these geos. Please review the input data,'
                 ' and/or consider grouping these geos together.'
             ),
+            (
+                'There are outliers in the scaled KPI in certain geos.'
+                ' Please check for any possible data errors.'
+            ),
         ),
         (
             self._stacked_treatment_control_scaled_da,
@@ -1695,6 +1814,11 @@
                 ' consider combining them to mitigate potential model'
                 ' identifiability and convergence issues.'
             ),
+            (
+                'There are outliers in the scaled treatment or control'
+                ' variables in certain geos. Please check for any possible data'
+                ' errors.'
+            ),
         ),
         (
             self.all_reach_scaled_da,
@@ -1705,6 +1829,11 @@
                 ' geos, consider modeling them as impression-based channels'
                 ' instead by taking reach * frequency.'
             ),
+            (
+                'There are outliers in the scaled reach values of the RF or'
+                ' Organic RF channels in certain geos. Please check for any'
+                ' possible data errors.'
+            ),
         ),
         (
             self.all_freq_da,
@@ -1715,30 +1844,34 @@
                 ' geos, consider modeling them as impression-based channels'
                 ' instead by taking reach * frequency.'
             ),
+            (
+                'There are outliers in the scaled frequency values of the RF or'
+                ' Organic RF channels in certain geos. Please check for any'
+                ' possible data errors.'
+            ),
         ),
     ]
 
-    for data_da,
+    for data_da, std_message, outlier_message in checks:
       if data_da is None:
         continue
-      …
+      current_findings, artifact = self._check_std(
          level=eda_outcome.AnalysisLevel.GEO,
          data=data_da,
-          zero_std_message=
+          zero_std_message=std_message,
+          outlier_message=outlier_message,
      )
      artifacts.append(artifact)
-      if
-        findings.
+      if current_findings:
+        findings.extend(current_findings)
 
     # Add an INFO finding if no findings were added.
     if not findings:
       findings.append(
           eda_outcome.EDAFinding(
              severity=eda_outcome.EDASeverity.INFO,
-             explanation=
-                 …
-                 ' deviation.'
-             ),
+             explanation='Please review the computed standard deviations.',
+             finding_cause=eda_outcome.FindingCause.NONE,
          )
      )
 
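The driver loop above now carries two messages per dataset (one for low variability, one for outliers), skips entries that are absent from the model, and extends the running list of findings with whatever the helper returns. A self-contained sketch of that loop pattern, decoupled from the engine; `run_std_checks` and its arguments are hypothetical names, not part of the package:

```python
def run_std_checks(checks, check_std):
  """checks: iterable of (data, std_message, outlier_message); data may be None."""
  findings, artifacts = [], []
  for data, std_message, outlier_message in checks:
    if data is None:
      continue  # e.g. a model configured without RF channels
    current_findings, artifact = check_std(data, std_message, outlier_message)
    artifacts.append(artifact)
    findings.extend(current_findings)
  return findings, artifacts
```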
@@ -1765,6 +1898,10 @@
                 ' the input data, and/or reconsider the feasibility of model'
                 ' fitting with this dataset.'
             ),
+            (
+                'There are outliers in the scaled KPI.'
+                ' Please check for any possible data errors.'
+            ),
         ),
         (
             self._stacked_national_treatment_control_scaled_da,
@@ -1776,6 +1913,10 @@
                 ' Please review the input data, and/or consider combining these'
                 ' variables to mitigate sparsity.'
             ),
+            (
+                'There are outliers in the scaled treatment or control'
+                ' variables. Please check for any possible data errors.'
+            ),
         ),
         (
             self.national_all_reach_scaled_da,
@@ -1785,6 +1926,11 @@
                 ' Consider modeling these RF channels as impression-based'
                 ' channels instead.'
             ),
+            (
+                'There are outliers in the scaled reach values of the RF or'
+                ' Organic RF channels. Please check for any possible data'
+                ' errors.'
+            ),
         ),
         (
             self.national_all_freq_da,
@@ -1794,30 +1940,34 @@
                 ' Consider modeling these RF channels as impression-based'
                 ' channels instead.'
             ),
+            (
+                'There are outliers in the scaled frequency values of the RF or'
+                ' Organic RF channels. Please check for any possible data'
+                ' errors.'
+            ),
         ),
     ]
 
-    for data_da,
+    for data_da, std_message, outlier_message in checks:
       if data_da is None:
         continue
-      …
+      current_findings, artifact = self._check_std(
          data=data_da,
          level=eda_outcome.AnalysisLevel.NATIONAL,
-          zero_std_message=
+          zero_std_message=std_message,
+          outlier_message=outlier_message,
      )
      artifacts.append(artifact)
-      if
-        findings.
+      if current_findings:
+        findings.extend(current_findings)
 
     # Add an INFO finding if no findings were added.
     if not findings:
       findings.append(
           eda_outcome.EDAFinding(
              severity=eda_outcome.EDASeverity.INFO,
-             explanation=
-                 …
-                 ' deviation.'
-             ),
+             explanation='Please review the computed standard deviations.',
+             finding_cause=eda_outcome.FindingCause.NONE,
          )
      )
 
@@ -1841,7 +1991,15 @@
     return self.check_geo_std()
 
   def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
-    """
+    """Checks geo variance inflation factor among treatments and controls.
+
+    The VIF calculation only focuses on multicollinearity among non-constant
+    variables. Any variable with constant values will result in a NaN VIF value.
+
+    Returns:
+      An EDAOutcome object with findings and result values.
+    """
+
     if self._is_national_data:
       raise ValueError(
           'Geo-level VIF checks are not applicable for national models.'
@@ -1851,12 +2009,12 @@
     tc_da = self._stacked_treatment_control_scaled_da
     overall_threshold = self._spec.vif_spec.overall_threshold
 
-    overall_vif_da = _calculate_vif(tc_da,
+    overall_vif_da = _calculate_vif(tc_da, eda_constants.VARIABLE)
     extreme_overall_vif_da = overall_vif_da.where(
         overall_vif_da > overall_threshold
     )
     extreme_overall_vif_df = extreme_overall_vif_da.to_dataframe(
-        name=
+        name=eda_constants.VIF_COL_NAME
     ).dropna()
 
     overall_vif_artifact = eda_outcome.VIFArtifact(
@@ -1868,11 +2026,11 @@
     # Geo level VIF check.
     geo_threshold = self._spec.vif_spec.geo_threshold
     geo_vif_da = tc_da.groupby(constants.GEO).map(
-        lambda x: _calculate_vif(x,
+        lambda x: _calculate_vif(x, eda_constants.VARIABLE)
     )
     extreme_geo_vif_da = geo_vif_da.where(geo_vif_da > geo_threshold)
     extreme_geo_vif_df = extreme_geo_vif_da.to_dataframe(
-        name=
+        name=eda_constants.VIF_COL_NAME
     ).dropna()
 
     geo_vif_artifact = eda_outcome.VIFArtifact(
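The call sites above now pass the dimension name constant into `_calculate_vif` and label the DataFrame column with `eda_constants.VIF_COL_NAME`, but the VIF computation itself is outside this hunk. For orientation, a standard variance inflation factor sketch with plain numpy, where each column is regressed on the remaining ones and VIF_j = 1 / (1 - R²_j); this is an illustrative stand-in, not the library's helper:

```python
import numpy as np

def vif(design: np.ndarray) -> np.ndarray:
  """Computes VIF per column of a 2-D design matrix (rows = observations)."""
  n, k = design.shape
  out = np.empty(k)
  for j in range(k):
    y = design[:, j]
    x = np.delete(design, j, axis=1)
    x = np.column_stack([np.ones(n), x])  # add an intercept
    coef, *_ = np.linalg.lstsq(x, y, rcond=None)
    resid = y - x @ coef
    ss_res = float(resid @ resid)
    ss_tot = float(((y - y.mean()) ** 2).sum())
    # A constant column gives ss_tot == 0, i.e. an undefined (infinite) VIF.
    out[j] = np.inf if ss_tot == 0 else 1.0 / max(1.0 - ss_res / ss_tot, 1e-12)
  return out
```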
@@ -1883,33 +2041,47 @@
 
     findings = []
     if not extreme_overall_vif_df.empty:
-      …
+      high_vif_vars_message = (
+          '\nVariables with extreme VIF:'
+          f' {extreme_overall_vif_df.index.to_list()}'
+      )
       findings.append(
           eda_outcome.EDAFinding(
              severity=eda_outcome.EDASeverity.ERROR,
-             explanation=(
-                 …
-                 ' is a linear combination of other variables. Otherwise,'
-                 ' consider combining variables.\n'
-                 f'Variables with extreme VIF: {high_vif_vars}'
+             explanation=eda_constants.MULTICOLLINEARITY_ERROR.format(
+                 threshold=overall_threshold,
+                 aggregation='times and geos',
+                 additional_info=high_vif_vars_message,
              ),
+             finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
+             associated_artifact=overall_vif_artifact,
          )
      )
-    …
+
+    # Variables that cause overall level findings are very likely to cause
+    # geo-level findings as well, so we exclude them when determining
+    # geo-level findings. This is to avoid over-reporting findings.
+    overall_vars_index = extreme_overall_vif_df.index
+    is_in_overall = extreme_geo_vif_df.index.get_level_values(
+        eda_constants.VARIABLE
+    ).isin(overall_vars_index)
+    geo_df_for_attention = extreme_geo_vif_df[~is_in_overall]
+
+    if not geo_df_for_attention.empty:
       findings.append(
           eda_outcome.EDAFinding(
              severity=eda_outcome.EDASeverity.ATTENTION,
              explanation=(
-                 …
-                 ' high VIF in other geos.'
+                 eda_constants.MULTICOLLINEARITY_ATTENTION.format(
+                     threshold=geo_threshold, additional_info=''
+                 )
              ),
+             finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
+             associated_artifact=geo_vif_artifact,
          )
      )
-    …
+
+    if not findings:
       findings.append(
           eda_outcome.EDAFinding(
              severity=eda_outcome.EDASeverity.INFO,
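The new filtering step above drops geo-level rows whose variable already triggered an overall-level error, so the ATTENTION finding only reports genuinely new offenders. A small self-contained pandas sketch of that MultiIndex filtering; the index names and example values are assumptions for illustration only:

```python
import pandas as pd

overall = pd.DataFrame({'vif': [12.0]}, index=pd.Index(['tv'], name='variable'))
geo = pd.DataFrame(
    {'vif': [11.0, 13.0]},
    index=pd.MultiIndex.from_tuples(
        [('geo_1', 'tv'), ('geo_2', 'radio')], names=['geo', 'variable']
    ),
)

# Keep only geo-level rows whose variable did not already appear overall.
is_in_overall = geo.index.get_level_values('variable').isin(overall.index)
geo_for_attention = geo[~is_in_overall]  # only ('geo_2', 'radio') remains
```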
@@ -1919,6 +2091,7 @@
                  ' jeopardize model identifiability and model convergence.'
                  ' Consider combining the variables if high VIF occurs.'
              ),
+             finding_cause=eda_outcome.FindingCause.NONE,
          )
      )
 
@@ -1931,14 +2104,21 @@
   def check_national_vif(
       self,
   ) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
-    """
+    """Checks national variance inflation factor among treatments and controls.
+
+    The VIF calculation only focuses on multicollinearity among non-constant
+    variables. Any variable with constant values will result in a NaN VIF value.
+
+    Returns:
+      An EDAOutcome object with findings and result values.
+    """
     national_tc_da = self._stacked_national_treatment_control_scaled_da
     national_threshold = self._spec.vif_spec.national_threshold
-    national_vif_da = _calculate_vif(national_tc_da,
+    national_vif_da = _calculate_vif(national_tc_da, eda_constants.VARIABLE)
 
     extreme_national_vif_df = (
         national_vif_da.where(national_vif_da > national_threshold)
-        .to_dataframe(name=
+        .to_dataframe(name=eda_constants.VIF_COL_NAME)
         .dropna()
     )
     national_vif_artifact = eda_outcome.VIFArtifact(
@@ -1949,18 +2129,20 @@
 
     findings = []
     if not extreme_national_vif_df.empty:
-      …
+      high_vif_vars_message = (
+          '\nVariables with extreme VIF:'
+          f' {extreme_national_vif_df.index.to_list()}'
+      )
       findings.append(
           eda_outcome.EDAFinding(
              severity=eda_outcome.EDASeverity.ERROR,
-             explanation=(
-                 …
-                 ' linear combination of other variables. Otherwise, consider'
-                 ' combining variables.\n'
-                 f'Variables with extreme VIF: {high_vif_vars}'
+             explanation=eda_constants.MULTICOLLINEARITY_ERROR.format(
+                 threshold=national_threshold,
+                 aggregation='times',
+                 additional_info=high_vif_vars_message,
              ),
+             finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
+             associated_artifact=national_vif_artifact,
          )
      )
     else:
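Both the geo and national checks now build their error text from `eda_constants.MULTICOLLINEARITY_ERROR.format(...)` with `threshold`, `aggregation`, and `additional_info` placeholders. The actual wording of that constant is not included in this diff; the stand-in below only illustrates the placeholder contract the call sites rely on:

```python
# Hypothetical template; the real constant lives in meridian/model/eda/constants.py.
MULTICOLLINEARITY_ERROR = (
    'Some variables have a VIF above {threshold} when computed across'
    ' {aggregation}.{additional_info}'
)

message = MULTICOLLINEARITY_ERROR.format(
    threshold=10,
    aggregation='times',
    additional_info="\nVariables with extreme VIF: ['tv']",
)
```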
@@ -1973,6 +2155,7 @@
                  ' jeopardize model identifiability and model convergence.'
                  ' Consider combining the variables if high VIF occurs.'
              ),
+             finding_cause=eda_outcome.FindingCause.NONE,
          )
      )
     return eda_outcome.EDAOutcome(
@@ -1984,6 +2167,9 @@
   def check_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
     """Computes variance inflation factor among treatments and controls.
 
+    The VIF calculation only focuses on multicollinearity among non-constant
+    variables. Any variable with constant values will result in a NaN VIF value.
+
     Returns:
       An EDAOutcome object with findings and result values.
     """
@@ -1994,15 +2180,18 @@
 
   @property
   def kpi_has_variability(self) -> bool:
-    """
+    """Whether the KPI has variability across geos and times."""
     return (
         self._overall_scaled_kpi_invariability_artifact.kpi_stdev.item()
-        >=
+        >= eda_constants.STD_THRESHOLD
     )
 
-  def check_overall_kpi_invariability(
+  def check_overall_kpi_invariability(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.KpiInvariabilityArtifact]:
     """Checks if the KPI is constant across all geos and times."""
-    …
+    artifact = self._overall_scaled_kpi_invariability_artifact
+    kpi = artifact.kpi_da.name
     geo_text = '' if self._is_national_data else 'geos and '
 
     if not self.kpi_has_variability:
@@ -2012,6 +2201,8 @@
              f'`{kpi}` is constant across all {geo_text}times, indicating no'
              ' signal in the data. Please fix this data error.'
          ),
+         finding_cause=eda_outcome.FindingCause.VARIABILITY,
+         associated_artifact=artifact,
      )
     else:
       eda_finding = eda_outcome.EDAFinding(
@@ -2019,12 +2210,13 @@
          explanation=(
              f'The {kpi} has variability across {geo_text}times in the data.'
          ),
+         finding_cause=eda_outcome.FindingCause.NONE,
      )
 
     return eda_outcome.EDAOutcome(
         check_type=eda_outcome.EDACheckType.KPI_INVARIABILITY,
         findings=[eda_finding],
-        analysis_artifacts=[
+        analysis_artifacts=[artifact],
     )
 
   def check_geo_cost_per_media_unit(
@@ -2081,30 +2273,249 @@
 
     return self.check_geo_cost_per_media_unit()
 
-  def run_all_critical_checks(self) ->
+  def run_all_critical_checks(self) -> eda_outcome.CriticalCheckEDAOutcomes:
     """Runs all critical EDA checks.
 
     Critical checks are those that can result in EDASeverity.ERROR findings.
 
     Returns:
-      A
+      A CriticalCheckEDAOutcomes object containing the results of all critical
+      checks.
     """
-    outcomes =
+    outcomes = {}
     for check, check_type in self._critical_checks:
       try:
-        outcomes
+        outcomes[check_type] = check()
       except Exception as e:  # pylint: disable=broad-except
         error_finding = eda_outcome.EDAFinding(
             severity=eda_outcome.EDASeverity.ERROR,
             explanation=(
                 f'An error occurred during running {check.__name__}: {e!r}'
             ),
+            finding_cause=eda_outcome.FindingCause.RUNTIME_ERROR,
+        )
+        outcomes[check_type] = eda_outcome.EDAOutcome(
+            check_type=check_type,
+            findings=[error_finding],
+            analysis_artifacts=[],
         )
-    …
+
+    return eda_outcome.CriticalCheckEDAOutcomes(
+        kpi_invariability=outcomes[eda_outcome.EDACheckType.KPI_INVARIABILITY],
+        multicollinearity=outcomes[eda_outcome.EDACheckType.MULTICOLLINEARITY],
+        pairwise_correlation=outcomes[
+            eda_outcome.EDACheckType.PAIRWISE_CORRELATION
+        ],
+    )
+
|
+
def check_variable_geo_time_collinearity(
|
|
2312
|
+
self,
|
|
2313
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.VariableGeoTimeCollinearityArtifact]:
|
|
2314
|
+
"""Compute adjusted R-squared for treatments and controls vs geo and time.
|
|
2315
|
+
|
|
2316
|
+
These checks are applied to geo-level dataset only.
|
|
2317
|
+
|
|
2318
|
+
Returns:
|
|
2319
|
+
An EDAOutcome object containing a VariableGeoTimeCollinearityArtifact.
|
|
2320
|
+
The artifact includes a Dataset with 'rsquared_geo' and 'rsquared_time',
|
|
2321
|
+
showing the adjusted R-squared values for each treatment/control variable
|
|
2322
|
+
when regressed against 'geo' and 'time', respectively. If a variable is
|
|
2323
|
+
constant across geos or times, the corresponding 'rsquared_geo' or
|
|
2324
|
+
'rsquared_time' value will be NaN.
|
|
2325
|
+
"""
|
|
2326
|
+
if self._is_national_data:
|
|
2327
|
+
raise ValueError(
|
|
2328
|
+
'check_variable_geo_time_collinearity is not supported for national'
|
|
2329
|
+
' models.'
|
|
2330
|
+
)
|
|
2331
|
+
|
|
2332
|
+
grouped_da = self._stacked_treatment_control_scaled_da.groupby(
|
|
2333
|
+
eda_constants.VARIABLE
|
|
2334
|
+
)
|
|
2335
|
+
rsq_geo = grouped_da.map(_calc_adj_r2, args=(constants.GEO,))
|
|
2336
|
+
rsq_time = grouped_da.map(_calc_adj_r2, args=(constants.TIME,))
|
|
2337
|
+
|
|
2338
|
+
rsquared_ds = xr.Dataset({
|
|
2339
|
+
eda_constants.RSQUARED_GEO: rsq_geo,
|
|
2340
|
+
eda_constants.RSQUARED_TIME: rsq_time,
|
|
2341
|
+
})
|
|
2342
|
+
|
|
2343
|
+
artifact = eda_outcome.VariableGeoTimeCollinearityArtifact(
|
|
2344
|
+
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
2345
|
+
rsquared_ds=rsquared_ds,
|
|
2346
|
+
)
|
|
2347
|
+
findings = [
|
|
2348
|
+
eda_outcome.EDAFinding(
|
|
2349
|
+
severity=eda_outcome.EDASeverity.INFO,
|
|
2350
|
+
explanation=eda_constants.R_SQUARED_TIME_INFO,
|
|
2351
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
2352
|
+
),
|
|
2353
|
+
eda_outcome.EDAFinding(
|
|
2354
|
+
severity=eda_outcome.EDASeverity.INFO,
|
|
2355
|
+
explanation=eda_constants.R_SQUARED_GEO_INFO,
|
|
2356
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
2357
|
+
),
|
|
2358
|
+
]
|
|
2359
|
+
|
|
2360
|
+
return eda_outcome.EDAOutcome(
|
|
2361
|
+
check_type=eda_outcome.EDACheckType.VARIABLE_GEO_TIME_COLLINEARITY,
|
|
2362
|
+
findings=findings,
|
|
2363
|
+
analysis_artifacts=[artifact],
|
|
2364
|
+
)
|
|
2365
|
+
|
|
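The new check above summarises, per treatment or control variable, how much of its variation is explained by geo alone and by time alone, using `_calc_adj_r2`, which is defined elsewhere in the module. Under the stated docstring semantics, one standard way to get an adjusted R-squared for a regression of a variable on group indicators looks like the sketch below; treat it as an assumption-level stand-in rather than the library's helper:

```python
import pandas as pd

def adjusted_r2_on_groups(values: pd.Series, groups: pd.Series) -> float:
  """Adjusted R-squared of regressing `values` on indicator dummies of `groups`."""
  y = values.to_numpy(dtype=float)
  fitted = values.groupby(groups).transform('mean').to_numpy(dtype=float)
  ss_res = float(((y - fitted) ** 2).sum())
  ss_tot = float(((y - y.mean()) ** 2).sum())
  if ss_tot == 0:
    return float('nan')  # constant variable, matching the NaN noted in the docstring
  n, p = len(y), groups.nunique()
  if n <= p:
    return float('nan')
  r2 = 1.0 - ss_res / ss_tot
  return 1.0 - (1.0 - r2) * (n - 1) / (n - p)
```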
+  def _calculate_population_corr(
+      self, ds: xr.Dataset, *, explanation: str, check_name: str
+  ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
+    """Calculates Spearman correlation between population and data variables.
+
+    Args:
+      ds: An xr.Dataset containing the data variables for which to calculate the
+        correlation with population. The Dataset is expected to have a 'geo'
+        dimension.
+      explanation: A string providing an explanation for the EDA finding.
+      check_name: A string representing the name of the calling check function,
+        used in error messages.
+
+    Returns:
+      An EDAOutcome object containing a PopulationCorrelationArtifact. The
+      artifact includes a Dataset with the Spearman correlation coefficients
+      between each variable in `ds` and the geo population.
+
+    Raises:
+      GeoLevelCheckOnNationalModelError: If the model is national or if
+        `self.geo_population_da` is None.
+    """
+
+    # self.geo_population_da can never be None if the model is geo-level. Adding
+    # this check to make pytype happy.
+    if self._is_national_data or self.geo_population_da is None:
+      raise GeoLevelCheckOnNationalModelError(
+          f'{check_name} is not supported for national models.'
+      )
+
+    corr_ds: xr.Dataset = xr.apply_ufunc(
+        _spearman_coeff,
+        ds.mean(dim=constants.TIME),
+        self.geo_population_da,
+        input_core_dims=[[constants.GEO], [constants.GEO]],
+        vectorize=True,
+        output_dtypes=[float],
+    )
+
+    artifact = eda_outcome.PopulationCorrelationArtifact(
+        level=eda_outcome.AnalysisLevel.OVERALL,
+        correlation_ds=corr_ds,
+    )
+
+    return eda_outcome.EDAOutcome(
+        check_type=eda_outcome.EDACheckType.POPULATION_CORRELATION,
+        findings=[
+            eda_outcome.EDAFinding(
+                severity=eda_outcome.EDASeverity.INFO,
+                explanation=explanation,
+                finding_cause=eda_outcome.FindingCause.NONE,
+                associated_artifact=artifact,
             )
+        ],
+        analysis_artifacts=[artifact],
+    )
+
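`_calculate_population_corr` averages each variable over time and lets `xr.apply_ufunc` apply `_spearman_coeff` along the geo dimension. The body of `_spearman_coeff` is not shown in this diff; a minimal stand-in, using the fact that the Spearman coefficient is the Pearson correlation of the ranks, could look like this:

```python
import pandas as pd

def spearman_coeff(x, y) -> float:
  """Spearman rank correlation of two 1-D arrays."""
  x_ranks = pd.Series(x).rank()
  y_ranks = pd.Series(y).rank()
  return float(x_ranks.corr(y_ranks))  # Pearson correlation of ranks == Spearman
```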
+  def check_population_corr_scaled_treatment_control(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
+    """Checks Spearman correlation between population and treatments/controls.
+
+    Calculates correlation between population and time-averaged
+    treatments and controls. High correlation for controls or non-media
+    channels may indicate a need for population-scaling. High
+    correlation for other media channels may indicate double-scaling.
+
+    Returns:
+      An EDAOutcome object with findings and result values.
+
+    Raises:
+      GeoLevelCheckOnNationalModelError: If the model is national or geo
+        population data is missing.
+    """
+    return self._calculate_population_corr(
+        ds=self.treatment_control_scaled_ds,
+        explanation=eda_constants.POPULATION_CORRELATION_SCALED_TREATMENT_CONTROL_INFO,
+        check_name='check_population_corr_scaled_treatment_control',
+    )
+
+  def check_population_corr_raw_media(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
+    """Checks Spearman correlation between population and raw media executions.
+
+    Calculates correlation between population and time-averaged raw
+    media executions (paid/organic impressions/reach). These are
+    expected to have reasonably high correlation with population.
+
+    Returns:
+      An EDAOutcome object with findings and result values.
+
+    Raises:
+      GeoLevelCheckOnNationalModelError: If the model is national or geo
+        population data is missing.
+    """
+    to_merge = (
+        da
+        for da in [
+            self.media_raw_da,
+            self.organic_media_raw_da,
+            self.reach_raw_da,
+            self.organic_reach_raw_da,
+        ]
+        if da is not None
+    )
+    # Handle the case where there are no media channels.
+
+    return self._calculate_population_corr(
+        ds=xr.merge(to_merge, join='inner'),
+        explanation=eda_constants.POPULATION_CORRELATION_RAW_MEDIA_INFO,
+        check_name='check_population_corr_raw_media',
+    )
+
+  def _check_prior_probability(
+      self,
+  ) -> eda_outcome.EDAOutcome[eda_outcome.PriorProbabilityArtifact]:
+    """Checks the prior probability of a negative baseline.
+
+    Returns:
+      An EDAOutcome object containing a PriorProbabilityArtifact. The artifact
+      includes a mock prior negative baseline probability and a DataArray of
+      mock mean prior contributions per channel.
+    """
+    # TODO: b/476128592 - currently, this check is blocked. for the meantime,
+    # we will return mock data for the report.
+    channel_names = self._model_context.input_data.get_all_channels()
+    mean_prior_contribution = np.random.uniform(
+        size=len(channel_names), low=0.0, high=0.05
+    )
+    mean_prior_contribution_da = xr.DataArray(
+        mean_prior_contribution,
+        coords={constants.CHANNEL: channel_names},
+        dims=[constants.CHANNEL],
+    )
+
+    artifact = eda_outcome.PriorProbabilityArtifact(
+        level=eda_outcome.AnalysisLevel.OVERALL,
+        prior_negative_baseline_prob=0.123,
+        mean_prior_contribution_da=mean_prior_contribution_da,
+    )
+
+    findings = [
+        eda_outcome.EDAFinding(
+            severity=eda_outcome.EDASeverity.INFO,
+            explanation=eda_constants.PRIOR_PROBABILITY_REPORT_INFO,
+            finding_cause=eda_outcome.FindingCause.NONE,
+            associated_artifact=artifact,
         )
-    …
+    ]
+
+    return eda_outcome.EDAOutcome(
+        check_type=eda_outcome.EDACheckType.PRIOR_PROBABILITY,
+        findings=findings,
+        analysis_artifacts=[artifact],
+    )