google-meridian 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +14 -11
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +50 -46
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +1 -1
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/checks.py +118 -116
- meridian/analysis/review/constants.py +3 -3
- meridian/analysis/review/results.py +131 -68
- meridian/analysis/review/reviewer.py +8 -23
- meridian/analysis/summarizer.py +6 -1
- meridian/analysis/test_utils.py +2898 -2538
- meridian/analysis/visualizer.py +28 -9
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +1 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +25 -41
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +134 -0
- meridian/model/eda/constants.py +334 -4
- meridian/model/eda/eda_engine.py +724 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/model.py +159 -110
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/linkingapi/constants.py +1 -1
- scenarioplanner/mmm_ui_proto_generator.py +1 -0
- schema/processors/marketing_processor.py +11 -10
- schema/processors/model_processor.py +4 -1
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +12 -3
- schema/utils/__init__.py +1 -0
- schema/utils/proto_enum_converter.py +127 -0
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
meridian/model/eda/eda_engine.py
CHANGED
|
@@ -16,19 +16,23 @@
|
|
|
16
16
|
|
|
17
17
|
from __future__ import annotations
|
|
18
18
|
|
|
19
|
+
from collections.abc import Collection, Sequence
|
|
19
20
|
import dataclasses
|
|
20
21
|
import functools
|
|
21
22
|
import typing
|
|
22
|
-
from typing import
|
|
23
|
+
from typing import Protocol
|
|
24
|
+
import warnings
|
|
23
25
|
|
|
24
26
|
from meridian import backend
|
|
25
27
|
from meridian import constants
|
|
28
|
+
from meridian.model import context
|
|
26
29
|
from meridian.model import transformers
|
|
27
30
|
from meridian.model.eda import constants as eda_constants
|
|
28
31
|
from meridian.model.eda import eda_outcome
|
|
29
32
|
from meridian.model.eda import eda_spec
|
|
30
33
|
import numpy as np
|
|
31
34
|
import pandas as pd
|
|
35
|
+
from scipy import stats
|
|
32
36
|
import statsmodels.api as sm
|
|
33
37
|
from statsmodels.stats import outliers_influence
|
|
34
38
|
import xarray as xr
|
|
@@ -39,25 +43,6 @@ if typing.TYPE_CHECKING:
|
|
|
39
43
|
|
|
40
44
|
__all__ = ['EDAEngine', 'GeoLevelCheckOnNationalModelError']
|
|
41
45
|
|
|
42
|
-
_DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
|
|
43
|
-
_CORRELATION_COL_NAME = eda_constants.CORRELATION
|
|
44
|
-
_STACK_VAR_COORD_NAME = eda_constants.VARIABLE
|
|
45
|
-
_CORR_VAR1 = eda_constants.VARIABLE_1
|
|
46
|
-
_CORR_VAR2 = eda_constants.VARIABLE_2
|
|
47
|
-
_CORRELATION_MATRIX_NAME = 'correlation_matrix'
|
|
48
|
-
_OVERALL_PAIRWISE_CORR_THRESHOLD = 0.999
|
|
49
|
-
_GEO_PAIRWISE_CORR_THRESHOLD = 0.999
|
|
50
|
-
_NATIONAL_PAIRWISE_CORR_THRESHOLD = 0.999
|
|
51
|
-
_Q1_THRESHOLD = 0.25
|
|
52
|
-
_Q3_THRESHOLD = 0.75
|
|
53
|
-
_IQR_MULTIPLIER = 1.5
|
|
54
|
-
_STD_WITH_OUTLIERS_VAR_NAME = 'std_with_outliers'
|
|
55
|
-
_STD_WITHOUT_OUTLIERS_VAR_NAME = 'std_without_outliers'
|
|
56
|
-
_STD_THRESHOLD = 1e-4
|
|
57
|
-
_OUTLIERS_COL_NAME = 'outliers'
|
|
58
|
-
_ABS_OUTLIERS_COL_NAME = 'abs_outliers'
|
|
59
|
-
_VIF_COL_NAME = 'VIF'
|
|
60
|
-
|
|
61
46
|
|
|
62
47
|
class _NamedEDACheckCallable(Protocol):
|
|
63
48
|
"""A callable that returns an EDAOutcome and has a __name__ attribute."""
|
|
@@ -149,6 +134,15 @@ class ReachFrequencyData:
|
|
|
149
134
|
national_rf_impressions_raw_da: xr.DataArray
|
|
150
135
|
|
|
151
136
|
|
|
137
|
+
def _get_vars_from_dataset(
|
|
138
|
+
base_ds: xr.Dataset,
|
|
139
|
+
variables_to_include: Collection[str],
|
|
140
|
+
) -> xr.Dataset | None:
|
|
141
|
+
"""Helper to get a subset of variables from a Dataset."""
|
|
142
|
+
variables = [v for v in base_ds.data_vars if v in variables_to_include]
|
|
143
|
+
return base_ds[variables].copy() if variables else None
|
|
144
|
+
|
|
145
|
+
|
|
152
146
|
def _data_array_like(
|
|
153
147
|
*, da: xr.DataArray, values: np.ndarray | backend.Tensor
|
|
154
148
|
) -> xr.DataArray:
|
|
@@ -173,7 +167,7 @@ def _data_array_like(
|
|
|
173
167
|
|
|
174
168
|
|
|
175
169
|
def stack_variables(
|
|
176
|
-
ds: xr.Dataset, coord_name: str =
|
|
170
|
+
ds: xr.Dataset, coord_name: str = eda_constants.VARIABLE
|
|
177
171
|
) -> xr.DataArray:
|
|
178
172
|
"""Stacks data variables of a Dataset into a single DataArray.
|
|
179
173
|
|
|
@@ -219,12 +213,12 @@ def _compute_correlation_matrix(
|
|
|
219
213
|
An xr.DataArray containing the correlation matrix.
|
|
220
214
|
"""
|
|
221
215
|
# Create two versions for correlation
|
|
222
|
-
da1 = input_da.rename({
|
|
223
|
-
da2 = input_da.rename({
|
|
216
|
+
da1 = input_da.rename({eda_constants.VARIABLE: eda_constants.VARIABLE_1})
|
|
217
|
+
da2 = input_da.rename({eda_constants.VARIABLE: eda_constants.VARIABLE_2})
|
|
224
218
|
|
|
225
219
|
# Compute pairwise correlation across dims. Other dims are broadcasted.
|
|
226
220
|
corr_mat_da = xr.corr(da1, da2, dim=dims)
|
|
227
|
-
corr_mat_da.name =
|
|
221
|
+
corr_mat_da.name = eda_constants.CORRELATION_MATRIX_NAME
|
|
228
222
|
return corr_mat_da
|
|
229
223
|
|
|
230
224
|
|
|
@@ -238,14 +232,14 @@ def _get_upper_triangle_corr_mat(corr_mat_da: xr.DataArray) -> xr.DataArray:
|
|
|
238
232
|
An xr.DataArray containing only the elements in the upper triangle of the
|
|
239
233
|
correlation matrix, with other elements masked as NaN.
|
|
240
234
|
"""
|
|
241
|
-
n_vars = corr_mat_da.sizes[
|
|
235
|
+
n_vars = corr_mat_da.sizes[eda_constants.VARIABLE_1]
|
|
242
236
|
mask_np = np.triu(np.ones((n_vars, n_vars), dtype=bool), k=1)
|
|
243
237
|
mask = xr.DataArray(
|
|
244
238
|
mask_np,
|
|
245
|
-
dims=[
|
|
239
|
+
dims=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2],
|
|
246
240
|
coords={
|
|
247
|
-
|
|
248
|
-
|
|
241
|
+
eda_constants.VARIABLE_1: corr_mat_da[eda_constants.VARIABLE_1],
|
|
242
|
+
eda_constants.VARIABLE_2: corr_mat_da[eda_constants.VARIABLE_2],
|
|
249
243
|
},
|
|
250
244
|
)
|
|
251
245
|
return corr_mat_da.where(mask)
|
|
@@ -259,11 +253,11 @@ def _find_extreme_corr_pairs(
|
|
|
259
253
|
extreme_corr_da = corr_tri.where(abs(corr_tri) > extreme_corr_threshold)
|
|
260
254
|
|
|
261
255
|
return (
|
|
262
|
-
extreme_corr_da.to_dataframe(name=
|
|
256
|
+
extreme_corr_da.to_dataframe(name=eda_constants.CORRELATION)
|
|
263
257
|
.dropna()
|
|
264
258
|
.assign(**{
|
|
265
259
|
eda_constants.ABS_CORRELATION_COL_NAME: (
|
|
266
|
-
lambda x: x[
|
|
260
|
+
lambda x: x[eda_constants.CORRELATION].abs()
|
|
267
261
|
)
|
|
268
262
|
})
|
|
269
263
|
.sort_values(
|
|
@@ -286,11 +280,11 @@ def _get_outlier_bounds(
|
|
|
286
280
|
A tuple containing the lower and upper bounds of outliers as DataArrays.
|
|
287
281
|
"""
|
|
288
282
|
# TODO: Allow users to specify custom outlier definitions.
|
|
289
|
-
q1 = input_da.quantile(
|
|
290
|
-
q3 = input_da.quantile(
|
|
283
|
+
q1 = input_da.quantile(eda_constants.Q1_THRESHOLD, dim=constants.TIME)
|
|
284
|
+
q3 = input_da.quantile(eda_constants.Q3_THRESHOLD, dim=constants.TIME)
|
|
291
285
|
iqr = q3 - q1
|
|
292
|
-
lower_bound = q1 -
|
|
293
|
-
upper_bound = q3 +
|
|
286
|
+
lower_bound = q1 - eda_constants.IQR_MULTIPLIER * iqr
|
|
287
|
+
upper_bound = q3 + eda_constants.IQR_MULTIPLIER * iqr
|
|
294
288
|
return lower_bound, upper_bound
|
|
295
289
|
|
|
296
290
|
|
|
@@ -314,11 +308,10 @@ def _calculate_std(
|
|
|
314
308
|
)
|
|
315
309
|
std_without_outliers = da_no_outlier.std(dim=constants.TIME, ddof=1)
|
|
316
310
|
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
311
|
+
return xr.Dataset({
|
|
312
|
+
eda_constants.STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
|
|
313
|
+
eda_constants.STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
|
|
320
314
|
})
|
|
321
|
-
return std_ds
|
|
322
315
|
|
|
323
316
|
|
|
324
317
|
def _calculate_outliers(
|
|
@@ -334,23 +327,29 @@ def _calculate_outliers(
|
|
|
334
327
|
outlier values.
|
|
335
328
|
"""
|
|
336
329
|
lower_bound, upper_bound = _get_outlier_bounds(input_da)
|
|
337
|
-
|
|
338
|
-
(input_da < lower_bound) | (input_da > upper_bound)
|
|
339
|
-
|
|
340
|
-
outlier_df = (
|
|
341
|
-
outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME)
|
|
330
|
+
return (
|
|
331
|
+
input_da.where((input_da < lower_bound) | (input_da > upper_bound))
|
|
332
|
+
.to_dataframe(name=eda_constants.OUTLIERS_COL_NAME)
|
|
342
333
|
.dropna()
|
|
343
|
-
.assign(
|
|
344
|
-
|
|
334
|
+
.assign(**{
|
|
335
|
+
eda_constants.ABS_OUTLIERS_COL_NAME: lambda x: np.abs(
|
|
336
|
+
x[eda_constants.OUTLIERS_COL_NAME]
|
|
337
|
+
)
|
|
338
|
+
})
|
|
339
|
+
.sort_values(
|
|
340
|
+
by=eda_constants.ABS_OUTLIERS_COL_NAME,
|
|
341
|
+
ascending=False,
|
|
342
|
+
inplace=False,
|
|
345
343
|
)
|
|
346
|
-
.sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
|
|
347
344
|
)
|
|
348
|
-
return outlier_df
|
|
349
345
|
|
|
350
346
|
|
|
351
347
|
def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
|
|
352
348
|
"""Helper function to compute variance inflation factor.
|
|
353
349
|
|
|
350
|
+
The VIF calculation only focuses on multicollinearity among non-constant
|
|
351
|
+
variables. Any variable with constant values will result in a NaN VIF value.
|
|
352
|
+
|
|
354
353
|
Args:
|
|
355
354
|
input_da: A DataArray for which to calculate the VIF over sample dimensions
|
|
356
355
|
(e.g. time and geo if applicable).
|
|
@@ -359,25 +358,28 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
|
|
|
359
358
|
Returns:
|
|
360
359
|
A DataArray containing the VIF for each variable in the variable dimension.
|
|
361
360
|
"""
|
|
361
|
+
|
|
362
362
|
num_vars = input_da.sizes[var_dim]
|
|
363
363
|
np_data = input_da.values.reshape(-1, num_vars)
|
|
364
|
-
np_data_with_const = sm.add_constant(
|
|
365
|
-
np_data, prepend=True, has_constant='add'
|
|
366
|
-
)
|
|
367
364
|
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
365
|
+
is_constant = np.std(np_data, axis=0) < eda_constants.STD_THRESHOLD
|
|
366
|
+
vif_values = np.full(num_vars, np.nan)
|
|
367
|
+
(non_constant_vars_indices,) = (~is_constant).nonzero()
|
|
368
|
+
|
|
369
|
+
if non_constant_vars_indices.size > 0:
|
|
370
|
+
design_matrix = sm.add_constant(
|
|
371
|
+
np_data[:, ~is_constant], prepend=True, has_constant='add'
|
|
372
|
+
)
|
|
373
|
+
for i, var_index in enumerate(non_constant_vars_indices):
|
|
374
|
+
vif_values[var_index] = outliers_influence.variance_inflation_factor(
|
|
375
|
+
design_matrix, i + 1
|
|
376
|
+
)
|
|
374
377
|
|
|
375
|
-
|
|
376
|
-
|
|
378
|
+
return xr.DataArray(
|
|
379
|
+
vif_values,
|
|
377
380
|
coords={var_dim: input_da[var_dim].values},
|
|
378
381
|
dims=[var_dim],
|
|
379
382
|
)
|
|
380
|
-
return vif_da
|
|
381
383
|
|
|
382
384
|
|
|
383
385
|
def _check_cost_media_unit_inconsistency(
|
|
@@ -392,9 +394,9 @@ def _check_cost_media_unit_inconsistency(
|
|
|
392
394
|
|
|
393
395
|
Returns:
|
|
394
396
|
A DataFrame of inconsistencies where either cost is zero and media units
|
|
395
|
-
are
|
|
396
|
-
positive, or cost is positive and media units are zero.
|
|
397
|
+
are positive, or cost is positive and media units are zero.
|
|
397
398
|
"""
|
|
399
|
+
|
|
398
400
|
cost_media_units_ds = xr.merge([cost_da, media_units_da])
|
|
399
401
|
|
|
400
402
|
# Condition 1: cost == 0 and media unit > 0
|
|
@@ -432,39 +434,69 @@ def _check_cost_per_media_unit(
|
|
|
432
434
|
cost_da,
|
|
433
435
|
media_units_da,
|
|
434
436
|
)
|
|
437
|
+
|
|
438
|
+
# Calculate cost per media unit. Avoid division by zero by setting cost to
|
|
439
|
+
# NaN where media units are 0. Note that both (cost == media unit == 0) and
|
|
440
|
+
# (cost > 0 and media unit == 0) result in NaN, while the latter one is not
|
|
441
|
+
# desired.
|
|
442
|
+
cost_per_media_unit_da = xr.where(
|
|
443
|
+
media_units_da == 0,
|
|
444
|
+
np.nan,
|
|
445
|
+
cost_da / media_units_da,
|
|
446
|
+
)
|
|
447
|
+
cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
|
|
448
|
+
outlier_df = _calculate_outliers(cost_per_media_unit_da)
|
|
449
|
+
|
|
450
|
+
if not outlier_df.empty:
|
|
451
|
+
outlier_df = outlier_df.rename(
|
|
452
|
+
columns={
|
|
453
|
+
eda_constants.OUTLIERS_COL_NAME: eda_constants.COST_PER_MEDIA_UNIT,
|
|
454
|
+
eda_constants.ABS_OUTLIERS_COL_NAME: (
|
|
455
|
+
eda_constants.ABS_COST_PER_MEDIA_UNIT
|
|
456
|
+
),
|
|
457
|
+
}
|
|
458
|
+
).assign(**{
|
|
459
|
+
constants.SPEND: cost_da.to_series(),
|
|
460
|
+
constants.MEDIA_UNITS: media_units_da.to_series(),
|
|
461
|
+
})[[
|
|
462
|
+
constants.SPEND,
|
|
463
|
+
constants.MEDIA_UNITS,
|
|
464
|
+
eda_constants.COST_PER_MEDIA_UNIT,
|
|
465
|
+
eda_constants.ABS_COST_PER_MEDIA_UNIT,
|
|
466
|
+
]]
|
|
467
|
+
|
|
468
|
+
artifact = eda_outcome.CostPerMediaUnitArtifact(
|
|
469
|
+
level=level,
|
|
470
|
+
cost_per_media_unit_da=cost_per_media_unit_da,
|
|
471
|
+
cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
|
|
472
|
+
outlier_df=outlier_df,
|
|
473
|
+
)
|
|
474
|
+
|
|
435
475
|
if not cost_media_unit_inconsistency_df.empty:
|
|
436
476
|
findings.append(
|
|
437
477
|
eda_outcome.EDAFinding(
|
|
438
478
|
severity=eda_outcome.EDASeverity.ATTENTION,
|
|
439
479
|
explanation=(
|
|
440
|
-
'There are instances of inconsistent
|
|
441
|
-
' This occurs when
|
|
442
|
-
' or when
|
|
443
|
-
' review the
|
|
480
|
+
'There are instances of inconsistent spend and media units.'
|
|
481
|
+
' This occurs when spend is zero but media units are positive,'
|
|
482
|
+
' or when spend is positive but media units are zero. Please'
|
|
483
|
+
' review the data input for media units and spend.'
|
|
444
484
|
),
|
|
485
|
+
finding_cause=eda_outcome.FindingCause.INCONSISTENT_DATA,
|
|
486
|
+
associated_artifact=artifact,
|
|
445
487
|
)
|
|
446
488
|
)
|
|
447
489
|
|
|
448
|
-
# Calculate cost per media unit
|
|
449
|
-
# Avoid division by zero by setting cost to NaN where media units are 0.
|
|
450
|
-
# Note that both (cost == media unit == 0) and (cost > 0 and media unit ==
|
|
451
|
-
# 0) result in NaN, while the latter one is not desired.
|
|
452
|
-
cost_per_media_unit_da = xr.where(
|
|
453
|
-
media_units_da == 0,
|
|
454
|
-
np.nan,
|
|
455
|
-
cost_da / media_units_da,
|
|
456
|
-
)
|
|
457
|
-
cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
|
|
458
|
-
|
|
459
|
-
outlier_df = _calculate_outliers(cost_per_media_unit_da)
|
|
460
490
|
if not outlier_df.empty:
|
|
461
491
|
findings.append(
|
|
462
492
|
eda_outcome.EDAFinding(
|
|
463
493
|
severity=eda_outcome.EDASeverity.ATTENTION,
|
|
464
494
|
explanation=(
|
|
465
495
|
'There are outliers in cost per media unit across time.'
|
|
466
|
-
' Please
|
|
496
|
+
' Please check for any possible data input error.'
|
|
467
497
|
),
|
|
498
|
+
finding_cause=eda_outcome.FindingCause.OUTLIER,
|
|
499
|
+
associated_artifact=artifact,
|
|
468
500
|
)
|
|
469
501
|
)
|
|
470
502
|
|
|
@@ -474,16 +506,10 @@ def _check_cost_per_media_unit(
|
|
|
474
506
|
eda_outcome.EDAFinding(
|
|
475
507
|
severity=eda_outcome.EDASeverity.INFO,
|
|
476
508
|
explanation='Please review the cost per media unit data.',
|
|
509
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
477
510
|
)
|
|
478
511
|
)
|
|
479
512
|
|
|
480
|
-
artifact = eda_outcome.CostPerMediaUnitArtifact(
|
|
481
|
-
level=level,
|
|
482
|
-
cost_per_media_unit_da=cost_per_media_unit_da,
|
|
483
|
-
cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
|
|
484
|
-
outlier_df=outlier_df,
|
|
485
|
-
)
|
|
486
|
-
|
|
487
513
|
return eda_outcome.EDAOutcome(
|
|
488
514
|
check_type=eda_outcome.EDACheckType.COST_PER_MEDIA_UNIT,
|
|
489
515
|
findings=findings,
|
|
@@ -491,42 +517,92 @@ def _check_cost_per_media_unit(
|
|
|
491
517
|
)
|
|
492
518
|
|
|
493
519
|
|
|
520
|
+
def _calc_adj_r2(da: xr.DataArray, regressor: str) -> xr.DataArray:
|
|
521
|
+
"""Calculates adjusted R-squared for a DataArray against a regressor.
|
|
522
|
+
|
|
523
|
+
If the input DataArray `da` is constant, it returns NaN.
|
|
524
|
+
|
|
525
|
+
Args:
|
|
526
|
+
da: The input DataArray.
|
|
527
|
+
regressor: The regressor to use in the formula.
|
|
528
|
+
|
|
529
|
+
Returns:
|
|
530
|
+
An xr.DataArray containing the adjusted R-squared value or NaN if `da` is
|
|
531
|
+
constant.
|
|
532
|
+
"""
|
|
533
|
+
if da.std(ddof=1) < eda_constants.STD_THRESHOLD:
|
|
534
|
+
return xr.DataArray(np.nan)
|
|
535
|
+
tmp_name = 'dep_var'
|
|
536
|
+
df = da.to_dataframe(name=tmp_name).reset_index()
|
|
537
|
+
formula = f'{tmp_name} ~ C({regressor})'
|
|
538
|
+
ols = sm.OLS.from_formula(formula, df).fit()
|
|
539
|
+
return xr.DataArray(ols.rsquared_adj)
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def _spearman_coeff(x, y):
|
|
543
|
+
"""Computes spearman correlation coefficient between two ArrayLike objects."""
|
|
544
|
+
|
|
545
|
+
return stats.spearmanr(x, y, nan_policy='omit').statistic
|
|
546
|
+
|
|
547
|
+
|
|
494
548
|
class EDAEngine:
|
|
495
549
|
"""Meridian EDA Engine."""
|
|
496
550
|
|
|
497
551
|
def __init__(
|
|
498
552
|
self,
|
|
499
|
-
|
|
553
|
+
# TODO: b/476230365 - Remove meridian arg.
|
|
554
|
+
meridian: model.Meridian | None = None,
|
|
500
555
|
spec: eda_spec.EDASpec = eda_spec.EDASpec(),
|
|
556
|
+
*,
|
|
557
|
+
model_context: context.ModelContext | None = None,
|
|
501
558
|
):
|
|
502
|
-
|
|
559
|
+
if meridian is not None and model_context is not None:
|
|
560
|
+
raise ValueError(
|
|
561
|
+
'Only one of `meridian` or `model_context` can be provided.'
|
|
562
|
+
)
|
|
563
|
+
if meridian is not None:
|
|
564
|
+
warnings.warn(
|
|
565
|
+
'Initializing EDAEngine with a Meridian object is deprecated'
|
|
566
|
+
' and will be removed in a future version. Please use'
|
|
567
|
+
' `model_context` instead.',
|
|
568
|
+
DeprecationWarning,
|
|
569
|
+
stacklevel=2,
|
|
570
|
+
)
|
|
571
|
+
self._model_context = meridian.model_context
|
|
572
|
+
elif model_context is not None:
|
|
573
|
+
self._model_context = model_context
|
|
574
|
+
else:
|
|
575
|
+
raise ValueError('Either `meridian` or `model_context` must be provided.')
|
|
576
|
+
|
|
577
|
+
self._input_data = self._model_context.input_data
|
|
503
578
|
self._spec = spec
|
|
504
579
|
self._agg_config = self._spec.aggregation_config
|
|
505
580
|
|
|
506
581
|
@property
|
|
507
582
|
def spec(self) -> eda_spec.EDASpec:
|
|
583
|
+
"""The EDA specification."""
|
|
508
584
|
return self._spec
|
|
509
585
|
|
|
510
586
|
@property
|
|
511
587
|
def _is_national_data(self) -> bool:
|
|
512
|
-
return self.
|
|
588
|
+
return self._model_context.is_national
|
|
513
589
|
|
|
514
590
|
@functools.cached_property
|
|
515
591
|
def controls_scaled_da(self) -> xr.DataArray | None:
|
|
516
|
-
"""
|
|
517
|
-
if self.
|
|
592
|
+
"""The scaled controls data array."""
|
|
593
|
+
if self._input_data.controls is None:
|
|
518
594
|
return None
|
|
519
595
|
controls_scaled_da = _data_array_like(
|
|
520
|
-
da=self.
|
|
521
|
-
values=self.
|
|
596
|
+
da=self._input_data.controls,
|
|
597
|
+
values=self._model_context.controls_scaled,
|
|
522
598
|
)
|
|
523
599
|
controls_scaled_da.name = constants.CONTROLS_SCALED
|
|
524
600
|
return controls_scaled_da
|
|
525
601
|
|
|
526
602
|
@functools.cached_property
|
|
527
603
|
def national_controls_scaled_da(self) -> xr.DataArray | None:
|
|
528
|
-
"""
|
|
529
|
-
if self.
|
|
604
|
+
"""The national scaled controls data array."""
|
|
605
|
+
if self._input_data.controls is None:
|
|
530
606
|
return None
|
|
531
607
|
if self._is_national_data:
|
|
532
608
|
if self.controls_scaled_da is None:
|
|
@@ -538,7 +614,7 @@ class EDAEngine:
|
|
|
538
614
|
national_da.name = constants.NATIONAL_CONTROLS_SCALED
|
|
539
615
|
else:
|
|
540
616
|
national_da = self._aggregate_and_scale_geo_da(
|
|
541
|
-
self.
|
|
617
|
+
self._input_data.controls,
|
|
542
618
|
constants.NATIONAL_CONTROLS_SCALED,
|
|
543
619
|
transformers.CenteringAndScalingTransformer,
|
|
544
620
|
constants.CONTROL_VARIABLE,
|
|
@@ -548,43 +624,43 @@ class EDAEngine:
|
|
|
548
624
|
|
|
549
625
|
@functools.cached_property
|
|
550
626
|
def media_raw_da(self) -> xr.DataArray | None:
|
|
551
|
-
"""
|
|
552
|
-
if self.
|
|
627
|
+
"""The raw media data array."""
|
|
628
|
+
if self._input_data.media is None:
|
|
553
629
|
return None
|
|
554
|
-
raw_media_da = self._truncate_media_time(self.
|
|
630
|
+
raw_media_da = self._truncate_media_time(self._input_data.media)
|
|
555
631
|
raw_media_da.name = constants.MEDIA
|
|
556
632
|
return raw_media_da
|
|
557
633
|
|
|
558
634
|
@functools.cached_property
|
|
559
635
|
def media_scaled_da(self) -> xr.DataArray | None:
|
|
560
|
-
"""
|
|
561
|
-
if self.
|
|
636
|
+
"""The scaled media data array."""
|
|
637
|
+
if self._input_data.media is None:
|
|
562
638
|
return None
|
|
563
639
|
media_scaled_da = _data_array_like(
|
|
564
|
-
da=self.
|
|
565
|
-
values=self.
|
|
640
|
+
da=self._input_data.media,
|
|
641
|
+
values=self._model_context.media_tensors.media_scaled,
|
|
566
642
|
)
|
|
567
643
|
media_scaled_da.name = constants.MEDIA_SCALED
|
|
568
644
|
return self._truncate_media_time(media_scaled_da)
|
|
569
645
|
|
|
570
646
|
@functools.cached_property
|
|
571
647
|
def media_spend_da(self) -> xr.DataArray | None:
|
|
572
|
-
"""
|
|
648
|
+
"""The media spend data.
|
|
573
649
|
|
|
574
650
|
If the input spend is aggregated, it is allocated across geo and time
|
|
575
651
|
proportionally to media units.
|
|
576
652
|
"""
|
|
577
653
|
# No need to truncate the media time for media spend.
|
|
578
|
-
|
|
579
|
-
if
|
|
654
|
+
allocated_media_spend = self._input_data.allocated_media_spend
|
|
655
|
+
if allocated_media_spend is None:
|
|
580
656
|
return None
|
|
581
|
-
da =
|
|
657
|
+
da = allocated_media_spend.copy()
|
|
582
658
|
da.name = constants.MEDIA_SPEND
|
|
583
659
|
return da
|
|
584
660
|
|
|
585
661
|
@functools.cached_property
|
|
586
662
|
def national_media_spend_da(self) -> xr.DataArray | None:
|
|
587
|
-
"""
|
|
663
|
+
"""The national media spend data array."""
|
|
588
664
|
media_spend = self.media_spend_da
|
|
589
665
|
if media_spend is None:
|
|
590
666
|
return None
|
|
@@ -593,7 +669,7 @@ class EDAEngine:
|
|
|
593
669
|
national_da.name = constants.NATIONAL_MEDIA_SPEND
|
|
594
670
|
else:
|
|
595
671
|
national_da = self._aggregate_and_scale_geo_da(
|
|
596
|
-
self.
|
|
672
|
+
self._input_data.allocated_media_spend,
|
|
597
673
|
constants.NATIONAL_MEDIA_SPEND,
|
|
598
674
|
None,
|
|
599
675
|
)
|
|
@@ -601,7 +677,7 @@ class EDAEngine:
|
|
|
601
677
|
|
|
602
678
|
@functools.cached_property
|
|
603
679
|
def national_media_raw_da(self) -> xr.DataArray | None:
|
|
604
|
-
"""
|
|
680
|
+
"""The national raw media data array."""
|
|
605
681
|
if self.media_raw_da is None:
|
|
606
682
|
return None
|
|
607
683
|
if self._is_national_data:
|
|
@@ -618,7 +694,7 @@ class EDAEngine:
|
|
|
618
694
|
|
|
619
695
|
@functools.cached_property
|
|
620
696
|
def national_media_scaled_da(self) -> xr.DataArray | None:
|
|
621
|
-
"""
|
|
697
|
+
"""The national scaled media data array."""
|
|
622
698
|
if self.media_scaled_da is None:
|
|
623
699
|
return None
|
|
624
700
|
if self._is_national_data:
|
|
@@ -635,30 +711,30 @@ class EDAEngine:
|
|
|
635
711
|
|
|
636
712
|
@functools.cached_property
|
|
637
713
|
def organic_media_raw_da(self) -> xr.DataArray | None:
|
|
638
|
-
"""
|
|
639
|
-
if self.
|
|
714
|
+
"""The raw organic media data array."""
|
|
715
|
+
if self._input_data.organic_media is None:
|
|
640
716
|
return None
|
|
641
717
|
raw_organic_media_da = self._truncate_media_time(
|
|
642
|
-
self.
|
|
718
|
+
self._input_data.organic_media
|
|
643
719
|
)
|
|
644
720
|
raw_organic_media_da.name = constants.ORGANIC_MEDIA
|
|
645
721
|
return raw_organic_media_da
|
|
646
722
|
|
|
647
723
|
@functools.cached_property
|
|
648
724
|
def organic_media_scaled_da(self) -> xr.DataArray | None:
|
|
649
|
-
"""
|
|
650
|
-
if self.
|
|
725
|
+
"""The scaled organic media data array."""
|
|
726
|
+
if self._input_data.organic_media is None:
|
|
651
727
|
return None
|
|
652
728
|
organic_media_scaled_da = _data_array_like(
|
|
653
|
-
da=self.
|
|
654
|
-
values=self.
|
|
729
|
+
da=self._input_data.organic_media,
|
|
730
|
+
values=self._model_context.organic_media_tensors.organic_media_scaled,
|
|
655
731
|
)
|
|
656
732
|
organic_media_scaled_da.name = constants.ORGANIC_MEDIA_SCALED
|
|
657
733
|
return self._truncate_media_time(organic_media_scaled_da)
|
|
658
734
|
|
|
659
735
|
@functools.cached_property
|
|
660
736
|
def national_organic_media_raw_da(self) -> xr.DataArray | None:
|
|
661
|
-
"""
|
|
737
|
+
"""The national raw organic media data array."""
|
|
662
738
|
if self.organic_media_raw_da is None:
|
|
663
739
|
return None
|
|
664
740
|
if self._is_national_data:
|
|
@@ -673,7 +749,7 @@ class EDAEngine:
|
|
|
673
749
|
|
|
674
750
|
@functools.cached_property
|
|
675
751
|
def national_organic_media_scaled_da(self) -> xr.DataArray | None:
|
|
676
|
-
"""
|
|
752
|
+
"""The national scaled organic media data array."""
|
|
677
753
|
if self.organic_media_scaled_da is None:
|
|
678
754
|
return None
|
|
679
755
|
if self._is_national_data:
|
|
@@ -692,20 +768,20 @@ class EDAEngine:
|
|
|
692
768
|
|
|
693
769
|
@functools.cached_property
|
|
694
770
|
def non_media_scaled_da(self) -> xr.DataArray | None:
|
|
695
|
-
"""
|
|
696
|
-
if self.
|
|
771
|
+
"""The scaled non-media treatments data array."""
|
|
772
|
+
if self._input_data.non_media_treatments is None:
|
|
697
773
|
return None
|
|
698
774
|
non_media_scaled_da = _data_array_like(
|
|
699
|
-
da=self.
|
|
700
|
-
values=self.
|
|
775
|
+
da=self._input_data.non_media_treatments,
|
|
776
|
+
values=self._model_context.non_media_treatments_normalized,
|
|
701
777
|
)
|
|
702
778
|
non_media_scaled_da.name = constants.NON_MEDIA_TREATMENTS_SCALED
|
|
703
779
|
return non_media_scaled_da
|
|
704
780
|
|
|
705
781
|
@functools.cached_property
|
|
706
782
|
def national_non_media_scaled_da(self) -> xr.DataArray | None:
|
|
707
|
-
"""
|
|
708
|
-
if self.
|
|
783
|
+
"""The national scaled non-media treatment data array."""
|
|
784
|
+
if self._input_data.non_media_treatments is None:
|
|
709
785
|
return None
|
|
710
786
|
if self._is_national_data:
|
|
711
787
|
if self.non_media_scaled_da is None:
|
|
@@ -717,7 +793,7 @@ class EDAEngine:
|
|
|
717
793
|
national_da.name = constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED
|
|
718
794
|
else:
|
|
719
795
|
national_da = self._aggregate_and_scale_geo_da(
|
|
720
|
-
self.
|
|
796
|
+
self._input_data.non_media_treatments,
|
|
721
797
|
constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
|
|
722
798
|
transformers.CenteringAndScalingTransformer,
|
|
723
799
|
constants.NON_MEDIA_CHANNEL,
|
|
@@ -727,12 +803,12 @@ class EDAEngine:
|
|
|
727
803
|
|
|
728
804
|
@functools.cached_property
|
|
729
805
|
def rf_spend_da(self) -> xr.DataArray | None:
|
|
730
|
-
"""
|
|
806
|
+
"""The RF spend data.
|
|
731
807
|
|
|
732
808
|
If the input spend is aggregated, it is allocated across geo and time
|
|
733
809
|
proportionally to RF impressions (reach * frequency).
|
|
734
810
|
"""
|
|
735
|
-
da = self.
|
|
811
|
+
da = self._input_data.allocated_rf_spend
|
|
736
812
|
if da is None:
|
|
737
813
|
return None
|
|
738
814
|
da = da.copy()
|
|
@@ -741,7 +817,7 @@ class EDAEngine:
|
|
|
741
817
|
|
|
742
818
|
@functools.cached_property
|
|
743
819
|
def national_rf_spend_da(self) -> xr.DataArray | None:
|
|
744
|
-
"""
|
|
820
|
+
"""The national RF spend data array."""
|
|
745
821
|
rf_spend = self.rf_spend_da
|
|
746
822
|
if rf_spend is None:
|
|
747
823
|
return None
|
|
@@ -750,7 +826,7 @@ class EDAEngine:
|
|
|
750
826
|
national_da.name = constants.NATIONAL_RF_SPEND
|
|
751
827
|
else:
|
|
752
828
|
national_da = self._aggregate_and_scale_geo_da(
|
|
753
|
-
self.
|
|
829
|
+
self._input_data.allocated_rf_spend,
|
|
754
830
|
constants.NATIONAL_RF_SPEND,
|
|
755
831
|
None,
|
|
756
832
|
)
|
|
@@ -758,182 +834,182 @@ class EDAEngine:
|
|
|
758
834
|
|
|
759
835
|
@functools.cached_property
|
|
760
836
|
def _rf_data(self) -> ReachFrequencyData | None:
|
|
761
|
-
if self.
|
|
837
|
+
if self._input_data.reach is None:
|
|
762
838
|
return None
|
|
763
839
|
return self._get_rf_data(
|
|
764
|
-
self.
|
|
765
|
-
self.
|
|
840
|
+
self._input_data.reach,
|
|
841
|
+
self._input_data.frequency,
|
|
766
842
|
is_organic=False,
|
|
767
843
|
)
|
|
768
844
|
|
|
769
845
|
@property
|
|
770
846
|
def reach_raw_da(self) -> xr.DataArray | None:
|
|
771
|
-
"""
|
|
847
|
+
"""The raw reach data array."""
|
|
772
848
|
if self._rf_data is None:
|
|
773
849
|
return None
|
|
774
|
-
return self._rf_data.reach_raw_da
|
|
850
|
+
return self._rf_data.reach_raw_da # pytype: disable=attribute-error
|
|
775
851
|
|
|
776
852
|
@property
|
|
777
853
|
def reach_scaled_da(self) -> xr.DataArray | None:
|
|
778
|
-
"""
|
|
854
|
+
"""The scaled reach data array."""
|
|
779
855
|
if self._rf_data is None:
|
|
780
856
|
return None
|
|
781
857
|
return self._rf_data.reach_scaled_da # pytype: disable=attribute-error
|
|
782
858
|
|
|
783
859
|
@property
|
|
784
860
|
def national_reach_raw_da(self) -> xr.DataArray | None:
|
|
785
|
-
"""
|
|
861
|
+
"""The national raw reach data array."""
|
|
786
862
|
if self._rf_data is None:
|
|
787
863
|
return None
|
|
788
864
|
return self._rf_data.national_reach_raw_da
|
|
789
865
|
|
|
790
866
|
@property
|
|
791
867
|
def national_reach_scaled_da(self) -> xr.DataArray | None:
|
|
792
|
-
"""
|
|
868
|
+
"""The national scaled reach data array."""
|
|
793
869
|
if self._rf_data is None:
|
|
794
870
|
return None
|
|
795
871
|
return self._rf_data.national_reach_scaled_da # pytype: disable=attribute-error
|
|
796
872
|
|
|
797
873
|
@property
|
|
798
874
|
def frequency_da(self) -> xr.DataArray | None:
|
|
799
|
-
"""
|
|
875
|
+
"""The frequency data array."""
|
|
800
876
|
if self._rf_data is None:
|
|
801
877
|
return None
|
|
802
878
|
return self._rf_data.frequency_da # pytype: disable=attribute-error
|
|
803
879
|
|
|
804
880
|
@property
|
|
805
881
|
def national_frequency_da(self) -> xr.DataArray | None:
|
|
806
|
-
"""
|
|
882
|
+
"""The national frequency data array."""
|
|
807
883
|
if self._rf_data is None:
|
|
808
884
|
return None
|
|
809
885
|
return self._rf_data.national_frequency_da # pytype: disable=attribute-error
|
|
810
886
|
|
|
811
887
|
@property
|
|
812
888
|
def rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
813
|
-
"""
|
|
889
|
+
"""The raw RF impressions data array."""
|
|
814
890
|
if self._rf_data is None:
|
|
815
891
|
return None
|
|
816
892
|
return self._rf_data.rf_impressions_raw_da # pytype: disable=attribute-error
|
|
817
893
|
|
|
818
894
|
@property
|
|
819
895
|
def national_rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
820
|
-
"""
|
|
896
|
+
"""The national raw RF impressions data array."""
|
|
821
897
|
if self._rf_data is None:
|
|
822
898
|
return None
|
|
823
899
|
return self._rf_data.national_rf_impressions_raw_da # pytype: disable=attribute-error
|
|
824
900
|
|
|
825
901
|
@property
|
|
826
902
|
def rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
827
|
-
"""
|
|
903
|
+
"""The scaled RF impressions data array."""
|
|
828
904
|
if self._rf_data is None:
|
|
829
905
|
return None
|
|
830
906
|
return self._rf_data.rf_impressions_scaled_da
|
|
831
907
|
|
|
832
908
|
@property
|
|
833
909
|
def national_rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
834
|
-
"""
|
|
910
|
+
"""The national scaled RF impressions data array."""
|
|
835
911
|
if self._rf_data is None:
|
|
836
912
|
return None
|
|
837
913
|
return self._rf_data.national_rf_impressions_scaled_da
|
|
838
914
|
|
|
839
915
|
@functools.cached_property
|
|
840
916
|
def _organic_rf_data(self) -> ReachFrequencyData | None:
|
|
841
|
-
if self.
|
|
917
|
+
if self._input_data.organic_reach is None:
|
|
842
918
|
return None
|
|
843
919
|
return self._get_rf_data(
|
|
844
|
-
self.
|
|
845
|
-
self.
|
|
920
|
+
self._input_data.organic_reach,
|
|
921
|
+
self._input_data.organic_frequency,
|
|
846
922
|
is_organic=True,
|
|
847
923
|
)
|
|
848
924
|
|
|
849
925
|
@property
|
|
850
926
|
def organic_reach_raw_da(self) -> xr.DataArray | None:
|
|
851
|
-
"""
|
|
927
|
+
"""The raw organic reach data array."""
|
|
852
928
|
if self._organic_rf_data is None:
|
|
853
929
|
return None
|
|
854
|
-
return self._organic_rf_data.reach_raw_da
|
|
930
|
+
return self._organic_rf_data.reach_raw_da # pytype: disable=attribute-error
|
|
855
931
|
|
|
856
932
|
@property
|
|
857
933
|
def organic_reach_scaled_da(self) -> xr.DataArray | None:
|
|
858
|
-
"""
|
|
934
|
+
"""The scaled organic reach data array."""
|
|
859
935
|
if self._organic_rf_data is None:
|
|
860
936
|
return None
|
|
861
937
|
return self._organic_rf_data.reach_scaled_da # pytype: disable=attribute-error
|
|
862
938
|
|
|
863
939
|
@property
|
|
864
940
|
def national_organic_reach_raw_da(self) -> xr.DataArray | None:
|
|
865
|
-
"""
|
|
941
|
+
"""The national raw organic reach data array."""
|
|
866
942
|
if self._organic_rf_data is None:
|
|
867
943
|
return None
|
|
868
944
|
return self._organic_rf_data.national_reach_raw_da
|
|
869
945
|
|
|
870
946
|
@property
|
|
871
947
|
def national_organic_reach_scaled_da(self) -> xr.DataArray | None:
|
|
872
|
-
"""
|
|
948
|
+
"""The national scaled organic reach data array."""
|
|
873
949
|
if self._organic_rf_data is None:
|
|
874
950
|
return None
|
|
875
951
|
return self._organic_rf_data.national_reach_scaled_da # pytype: disable=attribute-error
|
|
876
952
|
|
|
877
953
|
@property
|
|
878
954
|
def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
879
|
-
"""
|
|
955
|
+
"""The scaled organic RF impressions data array."""
|
|
880
956
|
if self._organic_rf_data is None:
|
|
881
957
|
return None
|
|
882
958
|
return self._organic_rf_data.rf_impressions_scaled_da
|
|
883
959
|
|
|
884
960
|
@property
|
|
885
961
|
def national_organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
|
|
886
|
-
"""
|
|
962
|
+
"""The national scaled organic RF impressions data array."""
|
|
887
963
|
if self._organic_rf_data is None:
|
|
888
964
|
return None
|
|
889
965
|
return self._organic_rf_data.national_rf_impressions_scaled_da
|
|
890
966
|
|
|
891
967
|
@property
|
|
892
968
|
def organic_frequency_da(self) -> xr.DataArray | None:
|
|
893
|
-
"""
|
|
969
|
+
"""The organic frequency data array."""
|
|
894
970
|
if self._organic_rf_data is None:
|
|
895
971
|
return None
|
|
896
972
|
return self._organic_rf_data.frequency_da # pytype: disable=attribute-error
|
|
897
973
|
|
|
898
974
|
@property
|
|
899
975
|
def national_organic_frequency_da(self) -> xr.DataArray | None:
|
|
900
|
-
"""
|
|
976
|
+
"""The national organic frequency data array."""
|
|
901
977
|
if self._organic_rf_data is None:
|
|
902
978
|
return None
|
|
903
979
|
return self._organic_rf_data.national_frequency_da # pytype: disable=attribute-error
|
|
904
980
|
|
|
905
981
|
@property
|
|
906
982
|
def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
907
|
-
"""
|
|
983
|
+
"""The raw organic RF impressions data array."""
|
|
908
984
|
if self._organic_rf_data is None:
|
|
909
985
|
return None
|
|
910
986
|
return self._organic_rf_data.rf_impressions_raw_da
|
|
911
987
|
|
|
912
988
|
@property
|
|
913
989
|
def national_organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
|
|
914
|
-
"""
|
|
990
|
+
"""The national raw organic RF impressions data array."""
|
|
915
991
|
if self._organic_rf_data is None:
|
|
916
992
|
return None
|
|
917
993
|
return self._organic_rf_data.national_rf_impressions_raw_da
|
|
918
994
|
|
|
919
995
|
@functools.cached_property
|
|
920
996
|
def geo_population_da(self) -> xr.DataArray | None:
|
|
921
|
-
"""
|
|
997
|
+
"""The geo population data array."""
|
|
922
998
|
if self._is_national_data:
|
|
923
999
|
return None
|
|
924
1000
|
return xr.DataArray(
|
|
925
|
-
self.
|
|
926
|
-
coords={constants.GEO: self.
|
|
1001
|
+
self._model_context.population,
|
|
1002
|
+
coords={constants.GEO: self._input_data.geo.values},
|
|
927
1003
|
dims=[constants.GEO],
|
|
928
1004
|
name=constants.POPULATION,
|
|
929
1005
|
)
|
|
930
1006
|
|
|
931
1007
|
@functools.cached_property
|
|
932
1008
|
def kpi_scaled_da(self) -> xr.DataArray:
|
|
933
|
-
"""
|
|
1009
|
+
"""The scaled KPI data array."""
|
|
934
1010
|
scaled_kpi_da = _data_array_like(
|
|
935
|
-
da=self.
|
|
936
|
-
values=self.
|
|
1011
|
+
da=self._input_data.kpi,
|
|
1012
|
+
values=self._model_context.kpi_scaled,
|
|
937
1013
|
)
|
|
938
1014
|
scaled_kpi_da.name = constants.KPI_SCALED
|
|
939
1015
|
return scaled_kpi_da
|
|
@@ -942,7 +1018,7 @@ class EDAEngine:
|
|
|
942
1018
|
def _overall_scaled_kpi_invariability_artifact(
|
|
943
1019
|
self,
|
|
944
1020
|
) -> eda_outcome.KpiInvariabilityArtifact:
|
|
945
|
-
"""
|
|
1021
|
+
"""An artifact of overall scaled KPI invariability."""
|
|
946
1022
|
return eda_outcome.KpiInvariabilityArtifact(
|
|
947
1023
|
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
948
1024
|
kpi_da=self.kpi_scaled_da,
|
|
@@ -951,14 +1027,14 @@ class EDAEngine:
|
|
|
951
1027
|
|
|
952
1028
|
@functools.cached_property
|
|
953
1029
|
def national_kpi_scaled_da(self) -> xr.DataArray:
|
|
954
|
-
"""
|
|
1030
|
+
"""The national scaled KPI data array."""
|
|
955
1031
|
if self._is_national_data:
|
|
956
1032
|
national_da = self.kpi_scaled_da.squeeze(constants.GEO, drop=True)
|
|
957
1033
|
national_da.name = constants.NATIONAL_KPI_SCALED
|
|
958
1034
|
else:
|
|
959
1035
|
# Note that kpi is summable by assumption.
|
|
960
1036
|
national_da = self._aggregate_and_scale_geo_da(
|
|
961
|
-
self.
|
|
1037
|
+
self._input_data.kpi,
|
|
962
1038
|
constants.NATIONAL_KPI_SCALED,
|
|
963
1039
|
transformers.CenteringAndScalingTransformer,
|
|
964
1040
|
)
|
|
@@ -966,7 +1042,7 @@ class EDAEngine:
|
|
|
966
1042
|
|
|
967
1043
|
@functools.cached_property
|
|
968
1044
|
def treatment_control_scaled_ds(self) -> xr.Dataset:
|
|
969
|
-
"""
|
|
1045
|
+
"""A Dataset containing all scaled treatments and controls.
|
|
970
1046
|
|
|
971
1047
|
This includes media, RF impressions, organic media, organic RF impressions,
|
|
972
1048
|
non-media treatments, and control variables, all at the geo level.
|
|
@@ -987,7 +1063,7 @@ class EDAEngine:
|
|
|
987
1063
|
|
|
988
1064
|
@functools.cached_property
|
|
989
1065
|
def all_spend_ds(self) -> xr.Dataset:
|
|
990
|
-
"""
|
|
1066
|
+
"""A Dataset containing all spend data.
|
|
991
1067
|
|
|
992
1068
|
This includes media spend and rf spend.
|
|
993
1069
|
"""
|
|
@@ -1003,7 +1079,7 @@ class EDAEngine:
|
|
|
1003
1079
|
|
|
1004
1080
|
@functools.cached_property
|
|
1005
1081
|
def national_all_spend_ds(self) -> xr.Dataset:
|
|
1006
|
-
"""
|
|
1082
|
+
"""A Dataset containing all national spend data.
|
|
1007
1083
|
|
|
1008
1084
|
This includes media spend and rf spend.
|
|
1009
1085
|
"""
|
|
@@ -1019,14 +1095,14 @@ class EDAEngine:
|
|
|
1019
1095
|
|
|
1020
1096
|
@functools.cached_property
|
|
1021
1097
|
def _stacked_treatment_control_scaled_da(self) -> xr.DataArray:
|
|
1022
|
-
"""
|
|
1098
|
+
"""A stacked DataArray of treatment_control_scaled_ds."""
|
|
1023
1099
|
da = stack_variables(self.treatment_control_scaled_ds)
|
|
1024
1100
|
da.name = constants.TREATMENT_CONTROL_SCALED
|
|
1025
1101
|
return da
|
|
1026
1102
|
|
|
1027
1103
|
@functools.cached_property
|
|
1028
1104
|
def national_treatment_control_scaled_ds(self) -> xr.Dataset:
|
|
1029
|
-
"""
|
|
1105
|
+
"""A Dataset containing all scaled treatments and controls.
|
|
1030
1106
|
|
|
1031
1107
|
This includes media, RF impressions, organic media, organic RF impressions,
|
|
1032
1108
|
non-media treatments, and control variables, all at the national level.
|
|
@@ -1047,14 +1123,14 @@ class EDAEngine:
|
|
|
1047
1123
|
|
|
1048
1124
|
@functools.cached_property
|
|
1049
1125
|
def _stacked_national_treatment_control_scaled_da(self) -> xr.DataArray:
|
|
1050
|
-
"""
|
|
1126
|
+
"""A stacked DataArray of national_treatment_control_scaled_ds."""
|
|
1051
1127
|
da = stack_variables(self.national_treatment_control_scaled_ds)
|
|
1052
1128
|
da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
|
|
1053
1129
|
return da
|
|
1054
1130
|
|
|
1055
1131
|
@functools.cached_property
|
|
1056
1132
|
def treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
|
|
1057
|
-
"""
|
|
1133
|
+
"""A Dataset of scaled treatments excluding non-media."""
|
|
1058
1134
|
return self.treatment_control_scaled_ds.drop_dims(
|
|
1059
1135
|
[constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
|
|
1060
1136
|
errors='ignore',
|
|
@@ -1062,15 +1138,34 @@ class EDAEngine:
|
|
|
1062
1138
|
|
|
1063
1139
|
@functools.cached_property
|
|
1064
1140
|
def national_treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
|
|
1065
|
-
"""
|
|
1141
|
+
"""A Dataset of national scaled treatments excluding non-media."""
|
|
1066
1142
|
return self.national_treatment_control_scaled_ds.drop_dims(
|
|
1067
1143
|
[constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
|
|
1068
1144
|
errors='ignore',
|
|
1069
1145
|
)
|
|
1070
1146
|
|
|
1147
|
+
@functools.cached_property
|
|
1148
|
+
def controls_and_non_media_scaled_ds(self) -> xr.Dataset | None:
|
|
1149
|
+
"""A Dataset of scaled controls and non-media treatments."""
|
|
1150
|
+
return _get_vars_from_dataset(
|
|
1151
|
+
self.treatment_control_scaled_ds,
|
|
1152
|
+
[constants.CONTROLS_SCALED, constants.NON_MEDIA_TREATMENTS_SCALED],
|
|
1153
|
+
)
|
|
1154
|
+
|
|
1155
|
+
@functools.cached_property
|
|
1156
|
+
def national_controls_and_non_media_scaled_ds(self) -> xr.Dataset | None:
|
|
1157
|
+
"""A Dataset of national scaled controls and non-media treatments."""
|
|
1158
|
+
return _get_vars_from_dataset(
|
|
1159
|
+
self.national_treatment_control_scaled_ds,
|
|
1160
|
+
[
|
|
1161
|
+
constants.NATIONAL_CONTROLS_SCALED,
|
|
1162
|
+
constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
|
|
1163
|
+
],
|
|
1164
|
+
)
|
|
1165
|
+
|
|
1071
1166
|
@functools.cached_property
|
|
1072
1167
|
def all_reach_scaled_da(self) -> xr.DataArray | None:
|
|
1073
|
-
"""
|
|
1168
|
+
"""A DataArray containing all scaled reach data.
|
|
1074
1169
|
|
|
1075
1170
|
This includes both paid and organic reach, concatenated along the RF_CHANNEL
|
|
1076
1171
|
dimension.
|
|
@@ -1096,7 +1191,7 @@ class EDAEngine:
|
|
|
1096
1191
|
|
|
1097
1192
|
@functools.cached_property
|
|
1098
1193
|
def all_freq_da(self) -> xr.DataArray | None:
|
|
1099
|
-
"""
|
|
1194
|
+
"""A DataArray containing all frequency data.
|
|
1100
1195
|
|
|
1101
1196
|
This includes both paid and organic frequency, concatenated along the
|
|
1102
1197
|
RF_CHANNEL dimension.
|
|
@@ -1122,7 +1217,7 @@ class EDAEngine:
|
|
|
1122
1217
|
|
|
1123
1218
|
@functools.cached_property
|
|
1124
1219
|
def national_all_reach_scaled_da(self) -> xr.DataArray | None:
|
|
1125
|
-
"""
|
|
1220
|
+
"""A DataArray containing all national-level scaled reach data.
|
|
1126
1221
|
|
|
1127
1222
|
This includes both paid and organic reach, concatenated along the
|
|
1128
1223
|
RF_CHANNEL dimension.
|
|
@@ -1149,7 +1244,7 @@ class EDAEngine:
|
|
|
1149
1244
|
|
|
1150
1245
|
@functools.cached_property
|
|
1151
1246
|
def national_all_freq_da(self) -> xr.DataArray | None:
|
|
1152
|
-
"""
|
|
1247
|
+
"""A DataArray containing all national-level frequency data.
|
|
1153
1248
|
|
|
1154
1249
|
This includes both paid and organic frequency, concatenated along the
|
|
1155
1250
|
RF_CHANNEL dimension.
|
|
@@ -1202,7 +1297,7 @@ class EDAEngine:
|
|
|
1202
1297
|
def _critical_checks(
|
|
1203
1298
|
self,
|
|
1204
1299
|
) -> list[tuple[_NamedEDACheckCallable, eda_outcome.EDACheckType]]:
|
|
1205
|
-
"""
|
|
1300
|
+
"""A list of critical checks to be performed."""
|
|
1206
1301
|
checks = [
|
|
1207
1302
|
(
|
|
1208
1303
|
self.check_overall_kpi_invariability,
|
|
@@ -1221,19 +1316,21 @@ class EDAEngine:
|
|
|
1221
1316
|
# This should not happen. If it does, it means this function is mis-used.
|
|
1222
1317
|
if constants.MEDIA_TIME not in da.coords:
|
|
1223
1318
|
raise ValueError(
|
|
1224
|
-
f'Variable does not have a media time coordinate: {da.name}.'
|
|
1319
|
+
f'Variable does not have a media time coordinate: {da.name!r}.'
|
|
1225
1320
|
)
|
|
1226
1321
|
|
|
1227
|
-
start = self.
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1322
|
+
start = self._model_context.n_media_times - self._model_context.n_times
|
|
1323
|
+
return (
|
|
1324
|
+
da.copy()
|
|
1325
|
+
.isel({constants.MEDIA_TIME: slice(start, None)})
|
|
1326
|
+
.rename({constants.MEDIA_TIME: constants.TIME})
|
|
1327
|
+
)
|
|
1231
1328
|
|
|
1232
1329
|
def _scale_xarray(
|
|
1233
1330
|
self,
|
|
1234
1331
|
xarray: xr.DataArray,
|
|
1235
|
-
transformer_class:
|
|
1236
|
-
population:
|
|
1332
|
+
transformer_class: type[transformers.TensorTransformer] | None,
|
|
1333
|
+
population: backend.Tensor | None = None,
|
|
1237
1334
|
) -> xr.DataArray:
|
|
1238
1335
|
"""Scales xarray values with a TensorTransformer."""
|
|
1239
1336
|
da = xarray.copy()
|
|
@@ -1285,7 +1382,9 @@ class EDAEngine:
|
|
|
1285
1382
|
agg_results = []
|
|
1286
1383
|
for var_name in geo_da[channel_dim].values:
|
|
1287
1384
|
var_data = geo_da.sel({channel_dim: var_name})
|
|
1288
|
-
agg_func = da_var_agg_map.get(
|
|
1385
|
+
agg_func = da_var_agg_map.get(
|
|
1386
|
+
var_name, eda_constants.DEFAULT_DA_VAR_AGG_FUNCTION
|
|
1387
|
+
)
|
|
1289
1388
|
# Apply the aggregation function over the GEO dimension
|
|
1290
1389
|
aggregated_data = var_data.reduce(
|
|
1291
1390
|
agg_func, dim=constants.GEO, keepdims=keepdims
|
|
@@ -1299,9 +1398,9 @@ class EDAEngine:
|
|
|
1299
1398
|
self,
|
|
1300
1399
|
geo_da: xr.DataArray,
|
|
1301
1400
|
national_da_name: str,
|
|
1302
|
-
transformer_class:
|
|
1303
|
-
channel_dim:
|
|
1304
|
-
da_var_agg_map:
|
|
1401
|
+
transformer_class: type[transformers.TensorTransformer] | None,
|
|
1402
|
+
channel_dim: str | None = None,
|
|
1403
|
+
da_var_agg_map: eda_spec.AggregationMap | None = None,
|
|
1305
1404
|
) -> xr.DataArray:
|
|
1306
1405
|
"""Aggregate geo-level xr.DataArray to national level and then scale values.
|
|
1307
1406
|
|
|
@@ -1351,11 +1450,11 @@ class EDAEngine:
|
|
|
1351
1450
|
"""Get impressions and frequencies data arrays for RF channels."""
|
|
1352
1451
|
if is_organic:
|
|
1353
1452
|
scaled_reach_values = (
|
|
1354
|
-
self.
|
|
1453
|
+
self._model_context.organic_rf_tensors.organic_reach_scaled
|
|
1355
1454
|
)
|
|
1356
1455
|
names = _ORGANIC_RF_NAMES
|
|
1357
1456
|
else:
|
|
1358
|
-
scaled_reach_values = self.
|
|
1457
|
+
scaled_reach_values = self._model_context.rf_tensors.reach_scaled
|
|
1359
1458
|
names = _RF_NAMES
|
|
1360
1459
|
|
|
1361
1460
|
reach_scaled_da = _data_array_like(
|
|
@@ -1433,7 +1532,7 @@ class EDAEngine:
|
|
|
1433
1532
|
impressions_scaled_da = self._scale_xarray(
|
|
1434
1533
|
impressions_raw_da,
|
|
1435
1534
|
transformers.MediaTransformer,
|
|
1436
|
-
population=self.
|
|
1535
|
+
population=self._model_context.population,
|
|
1437
1536
|
)
|
|
1438
1537
|
impressions_scaled_da.name = names.impressions_scaled
|
|
1439
1538
|
|
|
@@ -1491,9 +1590,16 @@ class EDAEngine:
|
|
|
1491
1590
|
overall_corr_mat, overall_extreme_corr_var_pairs_df = (
|
|
1492
1591
|
self._pairwise_corr_for_geo_data(
|
|
1493
1592
|
dims=[constants.GEO, constants.TIME],
|
|
1494
|
-
extreme_corr_threshold=
|
|
1593
|
+
extreme_corr_threshold=eda_constants.OVERALL_PAIRWISE_CORR_THRESHOLD,
|
|
1495
1594
|
)
|
|
1496
1595
|
)
|
|
1596
|
+
overall_artifact = eda_outcome.PairwiseCorrArtifact(
|
|
1597
|
+
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
1598
|
+
corr_matrix=overall_corr_mat,
|
|
1599
|
+
extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
|
|
1600
|
+
extreme_corr_threshold=eda_constants.OVERALL_PAIRWISE_CORR_THRESHOLD,
|
|
1601
|
+
)
|
|
1602
|
+
|
|
1497
1603
|
if not overall_extreme_corr_var_pairs_df.empty:
|
|
1498
1604
|
var_pairs = overall_extreme_corr_var_pairs_df.index.to_list()
|
|
1499
1605
|
findings.append(
|
|
@@ -1505,21 +1611,33 @@ class EDAEngine:
|
|
|
1505
1611
|
' variables, please remove one of the variables from the'
|
|
1506
1612
|
f' model.\nPairs with perfect correlation: {var_pairs}'
|
|
1507
1613
|
),
|
|
1614
|
+
finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
|
|
1615
|
+
associated_artifact=overall_artifact,
|
|
1508
1616
|
)
|
|
1509
1617
|
)
|
|
1510
1618
|
|
|
1511
1619
|
geo_corr_mat, geo_extreme_corr_var_pairs_df = (
|
|
1512
1620
|
self._pairwise_corr_for_geo_data(
|
|
1513
1621
|
dims=constants.TIME,
|
|
1514
|
-
extreme_corr_threshold=
|
|
1622
|
+
extreme_corr_threshold=eda_constants.GEO_PAIRWISE_CORR_THRESHOLD,
|
|
1515
1623
|
)
|
|
1516
1624
|
)
|
|
1517
|
-
#
|
|
1518
|
-
#
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1625
|
+
# Pairs that cause overall level findings are very likely to cause geo
|
|
1626
|
+
# level findings as well, so we exclude them when determining geo-level
|
|
1627
|
+
# findings. This is to avoid over-reporting findings.
|
|
1628
|
+
overall_pairs_index = overall_extreme_corr_var_pairs_df.index
|
|
1629
|
+
is_in_overall = geo_extreme_corr_var_pairs_df.index.droplevel(
|
|
1630
|
+
constants.GEO
|
|
1631
|
+
).isin(overall_pairs_index)
|
|
1632
|
+
geo_df_for_attention = geo_extreme_corr_var_pairs_df[~is_in_overall]
|
|
1633
|
+
geo_artifact = eda_outcome.PairwiseCorrArtifact(
|
|
1634
|
+
level=eda_outcome.AnalysisLevel.GEO,
|
|
1635
|
+
corr_matrix=geo_corr_mat,
|
|
1636
|
+
extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
|
|
1637
|
+
extreme_corr_threshold=eda_constants.GEO_PAIRWISE_CORR_THRESHOLD,
|
|
1638
|
+
)
|
|
1639
|
+
|
|
1640
|
+
if not geo_df_for_attention.empty:
|
|
1523
1641
|
findings.append(
|
|
1524
1642
|
eda_outcome.EDAFinding(
|
|
1525
1643
|
severity=eda_outcome.EDASeverity.ATTENTION,
|
|
@@ -1529,6 +1647,8 @@ class EDAEngine:
|
|
|
1529
1647
|
' variables if they also have high pairwise correlations in'
|
|
1530
1648
|
' other geos.'
|
|
1531
1649
|
),
|
|
1650
|
+
finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
|
|
1651
|
+
associated_artifact=geo_artifact,
|
|
1532
1652
|
)
|
|
1533
1653
|
)
|
|
1534
1654
|
|
|
@@ -1538,34 +1658,15 @@ class EDAEngine:
|
|
|
1538
1658
|
findings.append(
|
|
1539
1659
|
eda_outcome.EDAFinding(
|
|
1540
1660
|
severity=eda_outcome.EDASeverity.INFO,
|
|
1541
|
-
explanation=(
|
|
1542
|
-
|
|
1543
|
-
' high pairwise correlation may cause model identifiability'
|
|
1544
|
-
' and convergence issues. Consider combining the variables if'
|
|
1545
|
-
' high correlation exists.'
|
|
1546
|
-
),
|
|
1661
|
+
explanation=(eda_constants.PAIRWISE_CORRELATION_CHECK_INFO),
|
|
1662
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
1547
1663
|
)
|
|
1548
1664
|
)
|
|
1549
1665
|
|
|
1550
|
-
pairwise_corr_artifacts = [
|
|
1551
|
-
eda_outcome.PairwiseCorrArtifact(
|
|
1552
|
-
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
1553
|
-
corr_matrix=overall_corr_mat,
|
|
1554
|
-
extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
|
|
1555
|
-
extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
|
|
1556
|
-
),
|
|
1557
|
-
eda_outcome.PairwiseCorrArtifact(
|
|
1558
|
-
level=eda_outcome.AnalysisLevel.GEO,
|
|
1559
|
-
corr_matrix=geo_corr_mat,
|
|
1560
|
-
extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
|
|
1561
|
-
extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
|
|
1562
|
-
),
|
|
1563
|
-
]
|
|
1564
|
-
|
|
1565
1666
|
return eda_outcome.EDAOutcome(
|
|
1566
1667
|
check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
|
|
1567
1668
|
findings=findings,
|
|
1568
|
-
analysis_artifacts=
|
|
1669
|
+
analysis_artifacts=[overall_artifact, geo_artifact],
|
|
1569
1670
|
)
|
|
1570
1671
|
|
|
1571
1672
|
def check_national_pairwise_corr(
|
|
@@ -1582,7 +1683,14 @@ class EDAEngine:
|
|
|
1582
1683
|
self._stacked_national_treatment_control_scaled_da, dims=constants.TIME
|
|
1583
1684
|
)
|
|
1584
1685
|
extreme_corr_var_pairs_df = _find_extreme_corr_pairs(
|
|
1585
|
-
corr_mat,
|
|
1686
|
+
corr_mat, eda_constants.NATIONAL_PAIRWISE_CORR_THRESHOLD
|
|
1687
|
+
)
|
|
1688
|
+
|
|
1689
|
+
artifact = eda_outcome.PairwiseCorrArtifact(
|
|
1690
|
+
level=eda_outcome.AnalysisLevel.NATIONAL,
|
|
1691
|
+
corr_matrix=corr_mat,
|
|
1692
|
+
extreme_corr_var_pairs=extreme_corr_var_pairs_df,
|
|
1693
|
+
extreme_corr_threshold=eda_constants.NATIONAL_PAIRWISE_CORR_THRESHOLD,
|
|
1586
1694
|
)
|
|
1587
1695
|
|
|
1588
1696
|
if not extreme_corr_var_pairs_df.empty:
|
|
@@ -1596,33 +1704,23 @@ class EDAEngine:
|
|
|
1596
1704
|
' variables, please remove one of the variables from the'
|
|
1597
1705
|
f' model.\nPairs with perfect correlation: {var_pairs}'
|
|
1598
1706
|
),
|
|
1707
|
+
finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
|
|
1708
|
+
associated_artifact=artifact,
|
|
1599
1709
|
)
|
|
1600
1710
|
)
|
|
1601
1711
|
else:
|
|
1602
1712
|
findings.append(
|
|
1603
1713
|
eda_outcome.EDAFinding(
|
|
1604
1714
|
severity=eda_outcome.EDASeverity.INFO,
|
|
1605
|
-
explanation=(
|
|
1606
|
-
|
|
1607
|
-
' high pairwise correlation may cause model identifiability'
|
|
1608
|
-
' and convergence issues. Consider combining the variables if'
|
|
1609
|
-
' high correlation exists.'
|
|
1610
|
-
),
|
|
1715
|
+
explanation=(eda_constants.PAIRWISE_CORRELATION_CHECK_INFO),
|
|
1716
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
1611
1717
|
)
|
|
1612
1718
|
)
|
|
1613
1719
|
|
|
1614
|
-
pairwise_corr_artifacts = [
|
|
1615
|
-
eda_outcome.PairwiseCorrArtifact(
|
|
1616
|
-
level=eda_outcome.AnalysisLevel.NATIONAL,
|
|
1617
|
-
corr_matrix=corr_mat,
|
|
1618
|
-
extreme_corr_var_pairs=extreme_corr_var_pairs_df,
|
|
1619
|
-
extreme_corr_threshold=_NATIONAL_PAIRWISE_CORR_THRESHOLD,
|
|
1620
|
-
)
|
|
1621
|
-
]
|
|
1622
1720
|
return eda_outcome.EDAOutcome(
|
|
1623
1721
|
check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
|
|
1624
1722
|
findings=findings,
|
|
1625
|
-
analysis_artifacts=
|
|
1723
|
+
analysis_artifacts=[artifact],
|
|
1626
1724
|
)
|
|
1627
1725
|
|
|
1628
1726
|
def check_pairwise_corr(
|
|
@@ -1643,20 +1741,14 @@ class EDAEngine:
|
|
|
1643
1741
|
data: xr.DataArray,
|
|
1644
1742
|
level: eda_outcome.AnalysisLevel,
|
|
1645
1743
|
zero_std_message: str,
|
|
1744
|
+
outlier_message: str,
|
|
1646
1745
|
) -> tuple[
|
|
1647
|
-
eda_outcome.EDAFinding
|
|
1746
|
+
list[eda_outcome.EDAFinding], eda_outcome.StandardDeviationArtifact
|
|
1648
1747
|
]:
|
|
1649
1748
|
"""Helper to check standard deviation."""
|
|
1650
1749
|
std_ds = _calculate_std(data)
|
|
1651
1750
|
outlier_df = _calculate_outliers(data)
|
|
1652
1751
|
|
|
1653
|
-
finding = None
|
|
1654
|
-
if (std_ds[_STD_WITHOUT_OUTLIERS_VAR_NAME] < _STD_THRESHOLD).any():
|
|
1655
|
-
finding = eda_outcome.EDAFinding(
|
|
1656
|
-
severity=eda_outcome.EDASeverity.ATTENTION,
|
|
1657
|
-
explanation=zero_std_message,
|
|
1658
|
-
)
|
|
1659
|
-
|
|
1660
1752
|
artifact = eda_outcome.StandardDeviationArtifact(
|
|
1661
1753
|
variable=str(data.name),
|
|
1662
1754
|
level=level,
|
|
@@ -1664,7 +1756,31 @@ class EDAEngine:
|
|
|
1664
1756
|
outlier_df=outlier_df,
|
|
1665
1757
|
)
|
|
1666
1758
|
|
|
1667
|
-
|
|
1759
|
+
findings = []
|
|
1760
|
+
if (
|
|
1761
|
+
std_ds[eda_constants.STD_WITHOUT_OUTLIERS_VAR_NAME]
|
|
1762
|
+
< eda_constants.STD_THRESHOLD
|
|
1763
|
+
).any():
|
|
1764
|
+
findings.append(
|
|
1765
|
+
eda_outcome.EDAFinding(
|
|
1766
|
+
severity=eda_outcome.EDASeverity.ATTENTION,
|
|
1767
|
+
explanation=zero_std_message,
|
|
1768
|
+
finding_cause=eda_outcome.FindingCause.VARIABILITY,
|
|
1769
|
+
associated_artifact=artifact,
|
|
1770
|
+
)
|
|
1771
|
+
)
|
|
1772
|
+
|
|
1773
|
+
if not outlier_df.empty:
|
|
1774
|
+
findings.append(
|
|
1775
|
+
eda_outcome.EDAFinding(
|
|
1776
|
+
severity=eda_outcome.EDASeverity.ATTENTION,
|
|
1777
|
+
explanation=outlier_message,
|
|
1778
|
+
finding_cause=eda_outcome.FindingCause.OUTLIER,
|
|
1779
|
+
associated_artifact=artifact,
|
|
1780
|
+
)
|
|
1781
|
+
)
|
|
1782
|
+
|
|
1783
|
+
return findings, artifact
|
|
1668
1784
|
|
|
1669
1785
|
def check_geo_std(
|
|
1670
1786
|
self,
|
|
@@ -1685,6 +1801,10 @@ class EDAEngine:
|
|
|
1685
1801
|
' variable for these geos. Please review the input data,'
|
|
1686
1802
|
' and/or consider grouping these geos together.'
|
|
1687
1803
|
),
|
|
1804
|
+
(
|
|
1805
|
+
'There are outliers in the scaled KPI in certain geos.'
|
|
1806
|
+
' Please check for any possible data errors.'
|
|
1807
|
+
),
|
|
1688
1808
|
),
|
|
1689
1809
|
(
|
|
1690
1810
|
self._stacked_treatment_control_scaled_da,
|
|
@@ -1695,6 +1815,11 @@ class EDAEngine:
|
|
|
1695
1815
|
' consider combining them to mitigate potential model'
|
|
1696
1816
|
' identifiability and convergence issues.'
|
|
1697
1817
|
),
|
|
1818
|
+
(
|
|
1819
|
+
'There are outliers in the scaled treatment or control'
|
|
1820
|
+
' variables in certain geos. Please check for any possible data'
|
|
1821
|
+
' errors.'
|
|
1822
|
+
),
|
|
1698
1823
|
),
|
|
1699
1824
|
(
|
|
1700
1825
|
self.all_reach_scaled_da,
|
|
@@ -1705,6 +1830,11 @@ class EDAEngine:
|
|
|
1705
1830
|
' geos, consider modeling them as impression-based channels'
|
|
1706
1831
|
' instead by taking reach * frequency.'
|
|
1707
1832
|
),
|
|
1833
|
+
(
|
|
1834
|
+
'There are outliers in the scaled reach values of the RF or'
|
|
1835
|
+
' Organic RF channels in certain geos. Please check for any'
|
|
1836
|
+
' possible data errors.'
|
|
1837
|
+
),
|
|
1708
1838
|
),
|
|
1709
1839
|
(
|
|
1710
1840
|
self.all_freq_da,
|
|
@@ -1715,30 +1845,34 @@ class EDAEngine:
|
|
|
1715
1845
|
' geos, consider modeling them as impression-based channels'
|
|
1716
1846
|
' instead by taking reach * frequency.'
|
|
1717
1847
|
),
|
|
1848
|
+
(
|
|
1849
|
+
'There are outliers in the scaled frequency values of the RF or'
|
|
1850
|
+
' Organic RF channels in certain geos. Please check for any'
|
|
1851
|
+
' possible data errors.'
|
|
1852
|
+
),
|
|
1718
1853
|
),
|
|
1719
1854
|
]
|
|
1720
1855
|
|
|
1721
|
-
for data_da,
|
|
1856
|
+
for data_da, std_message, outlier_message in checks:
|
|
1722
1857
|
if data_da is None:
|
|
1723
1858
|
continue
|
|
1724
|
-
|
|
1859
|
+
current_findings, artifact = self._check_std(
|
|
1725
1860
|
level=eda_outcome.AnalysisLevel.GEO,
|
|
1726
1861
|
data=data_da,
|
|
1727
|
-
zero_std_message=
|
|
1862
|
+
zero_std_message=std_message,
|
|
1863
|
+
outlier_message=outlier_message,
|
|
1728
1864
|
)
|
|
1729
1865
|
artifacts.append(artifact)
|
|
1730
|
-
if
|
|
1731
|
-
findings.
|
|
1866
|
+
if current_findings:
|
|
1867
|
+
findings.extend(current_findings)
|
|
1732
1868
|
|
|
1733
1869
|
# Add an INFO finding if no findings were added.
|
|
1734
1870
|
if not findings:
|
|
1735
1871
|
findings.append(
|
|
1736
1872
|
eda_outcome.EDAFinding(
|
|
1737
1873
|
severity=eda_outcome.EDASeverity.INFO,
|
|
1738
|
-
explanation=
|
|
1739
|
-
|
|
1740
|
-
' deviation.'
|
|
1741
|
-
),
|
|
1874
|
+
explanation='Please review the computed standard deviations.',
|
|
1875
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
1742
1876
|
)
|
|
1743
1877
|
)
|
|
1744
1878
|
|
|
@@ -1765,6 +1899,10 @@ class EDAEngine:
|
|
|
1765
1899
|
' the input data, and/or reconsider the feasibility of model'
|
|
1766
1900
|
' fitting with this dataset.'
|
|
1767
1901
|
),
|
|
1902
|
+
(
|
|
1903
|
+
'There are outliers in the scaled KPI.'
|
|
1904
|
+
' Please check for any possible data errors.'
|
|
1905
|
+
),
|
|
1768
1906
|
),
|
|
1769
1907
|
(
|
|
1770
1908
|
self._stacked_national_treatment_control_scaled_da,
|
|
@@ -1776,6 +1914,10 @@ class EDAEngine:
|
|
|
1776
1914
|
' Please review the input data, and/or consider combining these'
|
|
1777
1915
|
' variables to mitigate sparsity.'
|
|
1778
1916
|
),
|
|
1917
|
+
(
|
|
1918
|
+
'There are outliers in the scaled treatment or control'
|
|
1919
|
+
' variables. Please check for any possible data errors.'
|
|
1920
|
+
),
|
|
1779
1921
|
),
|
|
1780
1922
|
(
|
|
1781
1923
|
self.national_all_reach_scaled_da,
|
|
@@ -1785,6 +1927,11 @@ class EDAEngine:
|
|
|
1785
1927
|
' Consider modeling these RF channels as impression-based'
|
|
1786
1928
|
' channels instead.'
|
|
1787
1929
|
),
|
|
1930
|
+
(
|
|
1931
|
+
'There are outliers in the scaled reach values of the RF or'
|
|
1932
|
+
' Organic RF channels. Please check for any possible data'
|
|
1933
|
+
' errors.'
|
|
1934
|
+
),
|
|
1788
1935
|
),
|
|
1789
1936
|
(
|
|
1790
1937
|
self.national_all_freq_da,
|
|
@@ -1794,30 +1941,34 @@ class EDAEngine:
|
|
|
1794
1941
|
' Consider modeling these RF channels as impression-based'
|
|
1795
1942
|
' channels instead.'
|
|
1796
1943
|
),
|
|
1944
|
+
(
|
|
1945
|
+
'There are outliers in the scaled frequency values of the RF or'
|
|
1946
|
+
' Organic RF channels. Please check for any possible data'
|
|
1947
|
+
' errors.'
|
|
1948
|
+
),
|
|
1797
1949
|
),
|
|
1798
1950
|
]
|
|
1799
1951
|
|
|
1800
|
-
for data_da,
|
|
1952
|
+
for data_da, std_message, outlier_message in checks:
|
|
1801
1953
|
if data_da is None:
|
|
1802
1954
|
continue
|
|
1803
|
-
|
|
1955
|
+
current_findings, artifact = self._check_std(
|
|
1804
1956
|
data=data_da,
|
|
1805
1957
|
level=eda_outcome.AnalysisLevel.NATIONAL,
|
|
1806
|
-
zero_std_message=
|
|
1958
|
+
zero_std_message=std_message,
|
|
1959
|
+
outlier_message=outlier_message,
|
|
1807
1960
|
)
|
|
1808
1961
|
artifacts.append(artifact)
|
|
1809
|
-
if
|
|
1810
|
-
findings.
|
|
1962
|
+
if current_findings:
|
|
1963
|
+
findings.extend(current_findings)
|
|
1811
1964
|
|
|
1812
1965
|
# Add an INFO finding if no findings were added.
|
|
1813
1966
|
if not findings:
|
|
1814
1967
|
findings.append(
|
|
1815
1968
|
eda_outcome.EDAFinding(
|
|
1816
1969
|
severity=eda_outcome.EDASeverity.INFO,
|
|
1817
|
-
explanation=
|
|
1818
|
-
|
|
1819
|
-
' deviation.'
|
|
1820
|
-
),
|
|
1970
|
+
explanation='Please review the computed standard deviations.',
|
|
1971
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
1821
1972
|
)
|
|
1822
1973
|
)
|
|
1823
1974
|
|
|
@@ -1841,7 +1992,15 @@ class EDAEngine:
|
|
|
1841
1992
|
return self.check_geo_std()
|
|
1842
1993
|
|
|
1843
1994
|
def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
|
|
1844
|
-
"""
|
|
1995
|
+
"""Checks geo variance inflation factor among treatments and controls.
|
|
1996
|
+
|
|
1997
|
+
The VIF calculation only focuses on multicollinearity among non-constant
|
|
1998
|
+
variables. Any variable with constant values will result in a NaN VIF value.
|
|
1999
|
+
|
|
2000
|
+
Returns:
|
|
2001
|
+
An EDAOutcome object with findings and result values.
|
|
2002
|
+
"""
|
|
2003
|
+
|
|
1845
2004
|
if self._is_national_data:
|
|
1846
2005
|
raise ValueError(
|
|
1847
2006
|
'Geo-level VIF checks are not applicable for national models.'
|
|
@@ -1851,12 +2010,12 @@ class EDAEngine:
|
|
|
1851
2010
|
tc_da = self._stacked_treatment_control_scaled_da
|
|
1852
2011
|
overall_threshold = self._spec.vif_spec.overall_threshold
|
|
1853
2012
|
|
|
1854
|
-
overall_vif_da = _calculate_vif(tc_da,
|
|
2013
|
+
overall_vif_da = _calculate_vif(tc_da, eda_constants.VARIABLE)
|
|
1855
2014
|
extreme_overall_vif_da = overall_vif_da.where(
|
|
1856
2015
|
overall_vif_da > overall_threshold
|
|
1857
2016
|
)
|
|
1858
2017
|
extreme_overall_vif_df = extreme_overall_vif_da.to_dataframe(
|
|
1859
|
-
name=
|
|
2018
|
+
name=eda_constants.VIF_COL_NAME
|
|
1860
2019
|
).dropna()
|
|
1861
2020
|
|
|
1862
2021
|
overall_vif_artifact = eda_outcome.VIFArtifact(
|
|
@@ -1868,11 +2027,11 @@ class EDAEngine:
|
|
|
1868
2027
|
# Geo level VIF check.
|
|
1869
2028
|
geo_threshold = self._spec.vif_spec.geo_threshold
|
|
1870
2029
|
geo_vif_da = tc_da.groupby(constants.GEO).map(
|
|
1871
|
-
lambda x: _calculate_vif(x,
|
|
2030
|
+
lambda x: _calculate_vif(x, eda_constants.VARIABLE)
|
|
1872
2031
|
)
|
|
1873
2032
|
extreme_geo_vif_da = geo_vif_da.where(geo_vif_da > geo_threshold)
|
|
1874
2033
|
extreme_geo_vif_df = extreme_geo_vif_da.to_dataframe(
|
|
1875
|
-
name=
|
|
2034
|
+
name=eda_constants.VIF_COL_NAME
|
|
1876
2035
|
).dropna()
|
|
1877
2036
|
|
|
1878
2037
|
geo_vif_artifact = eda_outcome.VIFArtifact(
|
|
@@ -1883,33 +2042,47 @@ class EDAEngine:
|
|
|
1883
2042
|
|
|
1884
2043
|
findings = []
|
|
1885
2044
|
if not extreme_overall_vif_df.empty:
|
|
1886
|
-
|
|
2045
|
+
high_vif_vars_message = (
|
|
2046
|
+
'\nVariables with extreme VIF:'
|
|
2047
|
+
f' {extreme_overall_vif_df.index.to_list()}'
|
|
2048
|
+
)
|
|
1887
2049
|
findings.append(
|
|
1888
2050
|
eda_outcome.EDAFinding(
|
|
1889
2051
|
severity=eda_outcome.EDASeverity.ERROR,
|
|
1890
|
-
explanation=(
|
|
1891
|
-
|
|
1892
|
-
|
|
1893
|
-
|
|
1894
|
-
' is a linear combination of other variables. Otherwise,'
|
|
1895
|
-
' consider combining variables.\n'
|
|
1896
|
-
f'Variables with extreme VIF: {high_vif_vars}'
|
|
2052
|
+
explanation=eda_constants.MULTICOLLINEARITY_ERROR.format(
|
|
2053
|
+
threshold=overall_threshold,
|
|
2054
|
+
aggregation='times and geos',
|
|
2055
|
+
additional_info=high_vif_vars_message,
|
|
1897
2056
|
),
|
|
2057
|
+
finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
|
|
2058
|
+
associated_artifact=overall_vif_artifact,
|
|
1898
2059
|
)
|
|
1899
2060
|
)
|
|
1900
|
-
|
|
2061
|
+
|
|
2062
|
+
# Variables that cause overall level findings are very likely to cause
|
|
2063
|
+
# geo-level findings as well, so we exclude them when determining
|
|
2064
|
+
# geo-level findings. This is to avoid over-reporting findings.
|
|
2065
|
+
overall_vars_index = extreme_overall_vif_df.index
|
|
2066
|
+
is_in_overall = extreme_geo_vif_df.index.get_level_values(
|
|
2067
|
+
eda_constants.VARIABLE
|
|
2068
|
+
).isin(overall_vars_index)
|
|
2069
|
+
geo_df_for_attention = extreme_geo_vif_df[~is_in_overall]
|
|
2070
|
+
|
|
2071
|
+
if not geo_df_for_attention.empty:
|
|
1901
2072
|
findings.append(
|
|
1902
2073
|
eda_outcome.EDAFinding(
|
|
1903
2074
|
severity=eda_outcome.EDASeverity.ATTENTION,
|
|
1904
2075
|
explanation=(
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
|
|
1908
|
-
' high VIF in other geos.'
|
|
2076
|
+
eda_constants.MULTICOLLINEARITY_ATTENTION.format(
|
|
2077
|
+
threshold=geo_threshold, additional_info=''
|
|
2078
|
+
)
|
|
1909
2079
|
),
|
|
2080
|
+
finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
|
|
2081
|
+
associated_artifact=geo_vif_artifact,
|
|
1910
2082
|
)
|
|
1911
2083
|
)
|
|
1912
|
-
|
|
2084
|
+
|
|
2085
|
+
if not findings:
|
|
1913
2086
|
findings.append(
|
|
1914
2087
|
eda_outcome.EDAFinding(
|
|
1915
2088
|
severity=eda_outcome.EDASeverity.INFO,
|
|
@@ -1919,6 +2092,7 @@ class EDAEngine:
|
|
|
1919
2092
|
' jeopardize model identifiability and model convergence.'
|
|
1920
2093
|
' Consider combining the variables if high VIF occurs.'
|
|
1921
2094
|
),
|
|
2095
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
1922
2096
|
)
|
|
1923
2097
|
)
|
|
1924
2098
|
|
|
@@ -1931,14 +2105,21 @@ class EDAEngine:
|
|
|
1931
2105
|
def check_national_vif(
|
|
1932
2106
|
self,
|
|
1933
2107
|
) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
|
|
1934
|
-
"""
|
|
2108
|
+
"""Checks national variance inflation factor among treatments and controls.
|
|
2109
|
+
|
|
2110
|
+
The VIF calculation only focuses on multicollinearity among non-constant
|
|
2111
|
+
variables. Any variable with constant values will result in a NaN VIF value.
|
|
2112
|
+
|
|
2113
|
+
Returns:
|
|
2114
|
+
An EDAOutcome object with findings and result values.
|
|
2115
|
+
"""
|
|
1935
2116
|
national_tc_da = self._stacked_national_treatment_control_scaled_da
|
|
1936
2117
|
national_threshold = self._spec.vif_spec.national_threshold
|
|
1937
|
-
national_vif_da = _calculate_vif(national_tc_da,
|
|
2118
|
+
national_vif_da = _calculate_vif(national_tc_da, eda_constants.VARIABLE)
|
|
1938
2119
|
|
|
1939
2120
|
extreme_national_vif_df = (
|
|
1940
2121
|
national_vif_da.where(national_vif_da > national_threshold)
|
|
1941
|
-
.to_dataframe(name=
|
|
2122
|
+
.to_dataframe(name=eda_constants.VIF_COL_NAME)
|
|
1942
2123
|
.dropna()
|
|
1943
2124
|
)
|
|
1944
2125
|
national_vif_artifact = eda_outcome.VIFArtifact(
|
|
@@ -1949,18 +2130,20 @@ class EDAEngine:
|
|
|
1949
2130
|
|
|
1950
2131
|
findings = []
|
|
1951
2132
|
if not extreme_national_vif_df.empty:
|
|
1952
|
-
|
|
2133
|
+
high_vif_vars_message = (
|
|
2134
|
+
'\nVariables with extreme VIF:'
|
|
2135
|
+
f' {extreme_national_vif_df.index.to_list()}'
|
|
2136
|
+
)
|
|
1953
2137
|
findings.append(
|
|
1954
2138
|
eda_outcome.EDAFinding(
|
|
1955
2139
|
severity=eda_outcome.EDASeverity.ERROR,
|
|
1956
|
-
explanation=(
|
|
1957
|
-
|
|
1958
|
-
|
|
1959
|
-
|
|
1960
|
-
' linear combination of other variables. Otherwise, consider'
|
|
1961
|
-
' combining variables.\n'
|
|
1962
|
-
f'Variables with extreme VIF: {high_vif_vars}'
|
|
2140
|
+
explanation=eda_constants.MULTICOLLINEARITY_ERROR.format(
|
|
2141
|
+
threshold=national_threshold,
|
|
2142
|
+
aggregation='times',
|
|
2143
|
+
additional_info=high_vif_vars_message,
|
|
1963
2144
|
),
|
|
2145
|
+
finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
|
|
2146
|
+
associated_artifact=national_vif_artifact,
|
|
1964
2147
|
)
|
|
1965
2148
|
)
|
|
1966
2149
|
else:
|
|
@@ -1973,6 +2156,7 @@ class EDAEngine:
|
|
|
1973
2156
|
' jeopardize model identifiability and model convergence.'
|
|
1974
2157
|
' Consider combining the variables if high VIF occurs.'
|
|
1975
2158
|
),
|
|
2159
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
1976
2160
|
)
|
|
1977
2161
|
)
|
|
1978
2162
|
return eda_outcome.EDAOutcome(
|
|
@@ -1984,6 +2168,9 @@ class EDAEngine:
|
|
|
1984
2168
|
def check_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
|
|
1985
2169
|
"""Computes variance inflation factor among treatments and controls.
|
|
1986
2170
|
|
|
2171
|
+
The VIF calculation only focuses on multicollinearity among non-constant
|
|
2172
|
+
variables. Any variable with constant values will result in a NaN VIF value.
|
|
2173
|
+
|
|
1987
2174
|
Returns:
|
|
1988
2175
|
An EDAOutcome object with findings and result values.
|
|
1989
2176
|
"""
|
|
@@ -1994,15 +2181,18 @@ class EDAEngine:
|
|
|
1994
2181
|
|
|
1995
2182
|
@property
|
|
1996
2183
|
def kpi_has_variability(self) -> bool:
|
|
1997
|
-
"""
|
|
2184
|
+
"""Whether the KPI has variability across geos and times."""
|
|
1998
2185
|
return (
|
|
1999
2186
|
self._overall_scaled_kpi_invariability_artifact.kpi_stdev.item()
|
|
2000
|
-
>=
|
|
2187
|
+
>= eda_constants.STD_THRESHOLD
|
|
2001
2188
|
)
|
|
2002
2189
|
|
|
2003
|
-
def check_overall_kpi_invariability(
|
|
2190
|
+
def check_overall_kpi_invariability(
|
|
2191
|
+
self,
|
|
2192
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.KpiInvariabilityArtifact]:
|
|
2004
2193
|
"""Checks if the KPI is constant across all geos and times."""
|
|
2005
|
-
|
|
2194
|
+
artifact = self._overall_scaled_kpi_invariability_artifact
|
|
2195
|
+
kpi = artifact.kpi_da.name
|
|
2006
2196
|
geo_text = '' if self._is_national_data else 'geos and '
|
|
2007
2197
|
|
|
2008
2198
|
if not self.kpi_has_variability:
|
|
@@ -2012,6 +2202,8 @@ class EDAEngine:
|
|
|
2012
2202
|
f'`{kpi}` is constant across all {geo_text}times, indicating no'
|
|
2013
2203
|
' signal in the data. Please fix this data error.'
|
|
2014
2204
|
),
|
|
2205
|
+
finding_cause=eda_outcome.FindingCause.VARIABILITY,
|
|
2206
|
+
associated_artifact=artifact,
|
|
2015
2207
|
)
|
|
2016
2208
|
else:
|
|
2017
2209
|
eda_finding = eda_outcome.EDAFinding(
|
|
@@ -2019,12 +2211,13 @@ class EDAEngine:
|
|
|
2019
2211
|
explanation=(
|
|
2020
2212
|
f'The {kpi} has variability across {geo_text}times in the data.'
|
|
2021
2213
|
),
|
|
2214
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
2022
2215
|
)
|
|
2023
2216
|
|
|
2024
2217
|
return eda_outcome.EDAOutcome(
|
|
2025
2218
|
check_type=eda_outcome.EDACheckType.KPI_INVARIABILITY,
|
|
2026
2219
|
findings=[eda_finding],
|
|
2027
|
-
analysis_artifacts=[
|
|
2220
|
+
analysis_artifacts=[artifact],
|
|
2028
2221
|
)
|
|
2029
2222
|
|
|
2030
2223
|
def check_geo_cost_per_media_unit(
|
|
@@ -2081,30 +2274,249 @@ class EDAEngine:
|
|
|
2081
2274
|
|
|
2082
2275
|
return self.check_geo_cost_per_media_unit()
|
|
2083
2276
|
|
|
2084
|
-
def run_all_critical_checks(self) ->
|
|
2277
|
+
def run_all_critical_checks(self) -> eda_outcome.CriticalCheckEDAOutcomes:
|
|
2085
2278
|
"""Runs all critical EDA checks.
|
|
2086
2279
|
|
|
2087
2280
|
Critical checks are those that can result in EDASeverity.ERROR findings.
|
|
2088
2281
|
|
|
2089
2282
|
Returns:
|
|
2090
|
-
A
|
|
2283
|
+
A CriticalCheckEDAOutcomes object containing the results of all critical
|
|
2284
|
+
checks.
|
|
2091
2285
|
"""
|
|
2092
|
-
outcomes =
|
|
2286
|
+
outcomes = {}
|
|
2093
2287
|
for check, check_type in self._critical_checks:
|
|
2094
2288
|
try:
|
|
2095
|
-
outcomes
|
|
2289
|
+
outcomes[check_type] = check()
|
|
2096
2290
|
except Exception as e: # pylint: disable=broad-except
|
|
2097
2291
|
error_finding = eda_outcome.EDAFinding(
|
|
2098
2292
|
severity=eda_outcome.EDASeverity.ERROR,
|
|
2099
2293
|
explanation=(
|
|
2100
2294
|
f'An error occurred during running {check.__name__}: {e!r}'
|
|
2101
2295
|
),
|
|
2296
|
+
finding_cause=eda_outcome.FindingCause.RUNTIME_ERROR,
|
|
2297
|
+
)
|
|
2298
|
+
outcomes[check_type] = eda_outcome.EDAOutcome(
|
|
2299
|
+
check_type=check_type,
|
|
2300
|
+
findings=[error_finding],
|
|
2301
|
+
analysis_artifacts=[],
|
|
2102
2302
|
)
|
|
2103
|
-
|
|
2104
|
-
|
|
2105
|
-
|
|
2106
|
-
|
|
2107
|
-
|
|
2303
|
+
|
|
2304
|
+
return eda_outcome.CriticalCheckEDAOutcomes(
|
|
2305
|
+
kpi_invariability=outcomes[eda_outcome.EDACheckType.KPI_INVARIABILITY],
|
|
2306
|
+
multicollinearity=outcomes[eda_outcome.EDACheckType.MULTICOLLINEARITY],
|
|
2307
|
+
pairwise_correlation=outcomes[
|
|
2308
|
+
eda_outcome.EDACheckType.PAIRWISE_CORRELATION
|
|
2309
|
+
],
|
|
2310
|
+
)
|
|
2311
|
+
|
|
2312
|
+
def check_variable_geo_time_collinearity(
|
|
2313
|
+
self,
|
|
2314
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.VariableGeoTimeCollinearityArtifact]:
|
|
2315
|
+
"""Compute adjusted R-squared for treatments and controls vs geo and time.
|
|
2316
|
+
|
|
2317
|
+
These checks are applied to geo-level dataset only.
|
|
2318
|
+
|
|
2319
|
+
Returns:
|
|
2320
|
+
An EDAOutcome object containing a VariableGeoTimeCollinearityArtifact.
|
|
2321
|
+
The artifact includes a Dataset with 'rsquared_geo' and 'rsquared_time',
|
|
2322
|
+
showing the adjusted R-squared values for each treatment/control variable
|
|
2323
|
+
when regressed against 'geo' and 'time', respectively. If a variable is
|
|
2324
|
+
constant across geos or times, the corresponding 'rsquared_geo' or
|
|
2325
|
+
'rsquared_time' value will be NaN.
|
|
2326
|
+
"""
|
|
2327
|
+
if self._is_national_data:
|
|
2328
|
+
raise ValueError(
|
|
2329
|
+
'check_variable_geo_time_collinearity is not supported for national'
|
|
2330
|
+
' models.'
|
|
2331
|
+
)
|
|
2332
|
+
|
|
2333
|
+
grouped_da = self._stacked_treatment_control_scaled_da.groupby(
|
|
2334
|
+
eda_constants.VARIABLE
|
|
2335
|
+
)
|
|
2336
|
+
rsq_geo = grouped_da.map(_calc_adj_r2, args=(constants.GEO,))
|
|
2337
|
+
rsq_time = grouped_da.map(_calc_adj_r2, args=(constants.TIME,))
|
|
2338
|
+
|
|
2339
|
+
rsquared_ds = xr.Dataset({
|
|
2340
|
+
eda_constants.RSQUARED_GEO: rsq_geo,
|
|
2341
|
+
eda_constants.RSQUARED_TIME: rsq_time,
|
|
2342
|
+
})
|
|
2343
|
+
|
|
2344
|
+
artifact = eda_outcome.VariableGeoTimeCollinearityArtifact(
|
|
2345
|
+
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
2346
|
+
rsquared_ds=rsquared_ds,
|
|
2347
|
+
)
|
|
2348
|
+
findings = [
|
|
2349
|
+
eda_outcome.EDAFinding(
|
|
2350
|
+
severity=eda_outcome.EDASeverity.INFO,
|
|
2351
|
+
explanation=eda_constants.R_SQUARED_TIME_INFO,
|
|
2352
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
2353
|
+
),
|
|
2354
|
+
eda_outcome.EDAFinding(
|
|
2355
|
+
severity=eda_outcome.EDASeverity.INFO,
|
|
2356
|
+
explanation=eda_constants.R_SQUARED_GEO_INFO,
|
|
2357
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
2358
|
+
),
|
|
2359
|
+
]
|
|
2360
|
+
|
|
2361
|
+
return eda_outcome.EDAOutcome(
|
|
2362
|
+
check_type=eda_outcome.EDACheckType.VARIABLE_GEO_TIME_COLLINEARITY,
|
|
2363
|
+
findings=findings,
|
|
2364
|
+
analysis_artifacts=[artifact],
|
|
2365
|
+
)
|
|
2366
|
+
|
|
2367
|
+
def _calculate_population_corr(
|
|
2368
|
+
self, ds: xr.Dataset, *, explanation: str, check_name: str
|
|
2369
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
|
|
2370
|
+
"""Calculates Spearman correlation between population and data variables.
|
|
2371
|
+
|
|
2372
|
+
Args:
|
|
2373
|
+
ds: An xr.Dataset containing the data variables for which to calculate the
|
|
2374
|
+
correlation with population. The Dataset is expected to have a 'geo'
|
|
2375
|
+
dimension.
|
|
2376
|
+
explanation: A string providing an explanation for the EDA finding.
|
|
2377
|
+
check_name: A string representing the name of the calling check function,
|
|
2378
|
+
used in error messages.
|
|
2379
|
+
|
|
2380
|
+
Returns:
|
|
2381
|
+
An EDAOutcome object containing a PopulationCorrelationArtifact. The
|
|
2382
|
+
artifact includes a Dataset with the Spearman correlation coefficients
|
|
2383
|
+
between each variable in `ds` and the geo population.
|
|
2384
|
+
|
|
2385
|
+
Raises:
|
|
2386
|
+
GeoLevelCheckOnNationalModelError: If the model is national or if
|
|
2387
|
+
`self.geo_population_da` is None.
|
|
2388
|
+
"""
|
|
2389
|
+
|
|
2390
|
+
# self.geo_population_da can never be None if the model is geo-level. Adding
|
|
2391
|
+
# this check to make pytype happy.
|
|
2392
|
+
if self._is_national_data or self.geo_population_da is None:
|
|
2393
|
+
raise GeoLevelCheckOnNationalModelError(
|
|
2394
|
+
f'{check_name} is not supported for national models.'
|
|
2395
|
+
)
|
|
2396
|
+
|
|
2397
|
+
corr_ds: xr.Dataset = xr.apply_ufunc(
|
|
2398
|
+
_spearman_coeff,
|
|
2399
|
+
ds.mean(dim=constants.TIME),
|
|
2400
|
+
self.geo_population_da,
|
|
2401
|
+
input_core_dims=[[constants.GEO], [constants.GEO]],
|
|
2402
|
+
vectorize=True,
|
|
2403
|
+
output_dtypes=[float],
|
|
2404
|
+
)
|
|
2405
|
+
|
|
2406
|
+
artifact = eda_outcome.PopulationCorrelationArtifact(
|
|
2407
|
+
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
2408
|
+
correlation_ds=corr_ds,
|
|
2409
|
+
)
|
|
2410
|
+
|
|
2411
|
+
return eda_outcome.EDAOutcome(
|
|
2412
|
+
check_type=eda_outcome.EDACheckType.POPULATION_CORRELATION,
|
|
2413
|
+
findings=[
|
|
2414
|
+
eda_outcome.EDAFinding(
|
|
2415
|
+
severity=eda_outcome.EDASeverity.INFO,
|
|
2416
|
+
explanation=explanation,
|
|
2417
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
2418
|
+
associated_artifact=artifact,
|
|
2108
2419
|
)
|
|
2420
|
+
],
|
|
2421
|
+
analysis_artifacts=[artifact],
|
|
2422
|
+
)
|
|
2423
|
+
|
|
2424
|
+
def check_population_corr_scaled_treatment_control(
|
|
2425
|
+
self,
|
|
2426
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
|
|
2427
|
+
"""Checks Spearman correlation between population and treatments/controls.
|
|
2428
|
+
|
|
2429
|
+
Calculates correlation between population and time-averaged
|
|
2430
|
+
treatments and controls. High correlation for controls or non-media
|
|
2431
|
+
channels may indicate a need for population-scaling. High
|
|
2432
|
+
correlation for other media channels may indicate double-scaling.
|
|
2433
|
+
|
|
2434
|
+
Returns:
|
|
2435
|
+
An EDAOutcome object with findings and result values.
|
|
2436
|
+
|
|
2437
|
+
Raises:
|
|
2438
|
+
GeoLevelCheckOnNationalModelError: If the model is national or geo
|
|
2439
|
+
population data is missing.
|
|
2440
|
+
"""
|
|
2441
|
+
return self._calculate_population_corr(
|
|
2442
|
+
ds=self.treatment_control_scaled_ds,
|
|
2443
|
+
explanation=eda_constants.POPULATION_CORRELATION_SCALED_TREATMENT_CONTROL_INFO,
|
|
2444
|
+
check_name='check_population_corr_scaled_treatment_control',
|
|
2445
|
+
)
|
|
2446
|
+
|
|
2447
|
+
def check_population_corr_raw_media(
|
|
2448
|
+
self,
|
|
2449
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
|
|
2450
|
+
"""Checks Spearman correlation between population and raw media executions.
|
|
2451
|
+
|
|
2452
|
+
Calculates correlation between population and time-averaged raw
|
|
2453
|
+
media executions (paid/organic impressions/reach). These are
|
|
2454
|
+
expected to have reasonably high correlation with population.
|
|
2455
|
+
|
|
2456
|
+
Returns:
|
|
2457
|
+
An EDAOutcome object with findings and result values.
|
|
2458
|
+
|
|
2459
|
+
Raises:
|
|
2460
|
+
GeoLevelCheckOnNationalModelError: If the model is national or geo
|
|
2461
|
+
population data is missing.
|
|
2462
|
+
"""
|
|
2463
|
+
to_merge = (
|
|
2464
|
+
da
|
|
2465
|
+
for da in [
|
|
2466
|
+
self.media_raw_da,
|
|
2467
|
+
self.organic_media_raw_da,
|
|
2468
|
+
self.reach_raw_da,
|
|
2469
|
+
self.organic_reach_raw_da,
|
|
2470
|
+
]
|
|
2471
|
+
if da is not None
|
|
2472
|
+
)
|
|
2473
|
+
# Handle the case where there are no media channels.
|
|
2474
|
+
|
|
2475
|
+
return self._calculate_population_corr(
|
|
2476
|
+
ds=xr.merge(to_merge, join='inner'),
|
|
2477
|
+
explanation=eda_constants.POPULATION_CORRELATION_RAW_MEDIA_INFO,
|
|
2478
|
+
check_name='check_population_corr_raw_media',
|
|
2479
|
+
)
|
|
2480
|
+
|
|
2481
|
+
def _check_prior_probability(
|
|
2482
|
+
self,
|
|
2483
|
+
) -> eda_outcome.EDAOutcome[eda_outcome.PriorProbabilityArtifact]:
|
|
2484
|
+
"""Checks the prior probability of a negative baseline.
|
|
2485
|
+
|
|
2486
|
+
Returns:
|
|
2487
|
+
An EDAOutcome object containing a PriorProbabilityArtifact. The artifact
|
|
2488
|
+
includes a mock prior negative baseline probability and a DataArray of
|
|
2489
|
+
mock mean prior contributions per channel.
|
|
2490
|
+
"""
|
|
2491
|
+
# TODO: b/476128592 - currently, this check is blocked. for the meantime,
|
|
2492
|
+
# we will return mock data for the report.
|
|
2493
|
+
channel_names = self._model_context.input_data.get_all_channels()
|
|
2494
|
+
mean_prior_contribution = np.random.uniform(
|
|
2495
|
+
size=len(channel_names), low=0.0, high=0.05
|
|
2496
|
+
)
|
|
2497
|
+
mean_prior_contribution_da = xr.DataArray(
|
|
2498
|
+
mean_prior_contribution,
|
|
2499
|
+
coords={constants.CHANNEL: channel_names},
|
|
2500
|
+
dims=[constants.CHANNEL],
|
|
2501
|
+
)
|
|
2502
|
+
|
|
2503
|
+
artifact = eda_outcome.PriorProbabilityArtifact(
|
|
2504
|
+
level=eda_outcome.AnalysisLevel.OVERALL,
|
|
2505
|
+
prior_negative_baseline_prob=0.123,
|
|
2506
|
+
mean_prior_contribution_da=mean_prior_contribution_da,
|
|
2507
|
+
)
|
|
2508
|
+
|
|
2509
|
+
findings = [
|
|
2510
|
+
eda_outcome.EDAFinding(
|
|
2511
|
+
severity=eda_outcome.EDASeverity.INFO,
|
|
2512
|
+
explanation=eda_constants.PRIOR_PROBABILITY_REPORT_INFO,
|
|
2513
|
+
finding_cause=eda_outcome.FindingCause.NONE,
|
|
2514
|
+
associated_artifact=artifact,
|
|
2109
2515
|
)
|
|
2110
|
-
|
|
2516
|
+
]
|
|
2517
|
+
|
|
2518
|
+
return eda_outcome.EDAOutcome(
|
|
2519
|
+
check_type=eda_outcome.EDACheckType.PRIOR_PROBABILITY,
|
|
2520
|
+
findings=findings,
|
|
2521
|
+
analysis_artifacts=[artifact],
|
|
2522
|
+
)
|