google-meridian 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +50 -46
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/checks.py +118 -116
  7. meridian/analysis/review/constants.py +3 -3
  8. meridian/analysis/review/results.py +131 -68
  9. meridian/analysis/review/reviewer.py +8 -23
  10. meridian/analysis/summarizer.py +6 -1
  11. meridian/analysis/test_utils.py +2898 -2538
  12. meridian/analysis/visualizer.py +28 -9
  13. meridian/backend/__init__.py +106 -0
  14. meridian/constants.py +1 -0
  15. meridian/data/input_data.py +30 -52
  16. meridian/data/input_data_builder.py +2 -9
  17. meridian/data/test_utils.py +25 -41
  18. meridian/data/validator.py +48 -0
  19. meridian/mlflow/autolog.py +19 -9
  20. meridian/model/adstock_hill.py +3 -5
  21. meridian/model/context.py +134 -0
  22. meridian/model/eda/constants.py +334 -4
  23. meridian/model/eda/eda_engine.py +724 -312
  24. meridian/model/eda/eda_outcome.py +177 -33
  25. meridian/model/model.py +159 -110
  26. meridian/model/model_test_data.py +38 -0
  27. meridian/model/posterior_sampler.py +103 -62
  28. meridian/model/prior_sampler.py +114 -94
  29. meridian/model/spec.py +23 -14
  30. meridian/templates/card.html.jinja +9 -7
  31. meridian/templates/chart.html.jinja +1 -6
  32. meridian/templates/finding.html.jinja +19 -0
  33. meridian/templates/findings.html.jinja +33 -0
  34. meridian/templates/formatter.py +41 -5
  35. meridian/templates/formatter_test.py +127 -0
  36. meridian/templates/style.css +66 -9
  37. meridian/templates/style.scss +85 -4
  38. meridian/templates/table.html.jinja +1 -0
  39. meridian/version.py +1 -1
  40. scenarioplanner/linkingapi/constants.py +1 -1
  41. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  42. schema/processors/marketing_processor.py +11 -10
  43. schema/processors/model_processor.py +4 -1
  44. schema/serde/distribution.py +12 -7
  45. schema/serde/hyperparameters.py +54 -107
  46. schema/serde/meridian_serde.py +12 -3
  47. schema/utils/__init__.py +1 -0
  48. schema/utils/proto_enum_converter.py +127 -0
  49. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
  50. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
@@ -16,19 +16,23 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
+ from collections.abc import Collection, Sequence
19
20
  import dataclasses
20
21
  import functools
21
22
  import typing
22
- from typing import Optional, Protocol, Sequence
23
+ from typing import Protocol
24
+ import warnings
23
25
 
24
26
  from meridian import backend
25
27
  from meridian import constants
28
+ from meridian.model import context
26
29
  from meridian.model import transformers
27
30
  from meridian.model.eda import constants as eda_constants
28
31
  from meridian.model.eda import eda_outcome
29
32
  from meridian.model.eda import eda_spec
30
33
  import numpy as np
31
34
  import pandas as pd
35
+ from scipy import stats
32
36
  import statsmodels.api as sm
33
37
  from statsmodels.stats import outliers_influence
34
38
  import xarray as xr
@@ -39,25 +43,6 @@ if typing.TYPE_CHECKING:
39
43
 
40
44
  __all__ = ['EDAEngine', 'GeoLevelCheckOnNationalModelError']
41
45
 
42
- _DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
43
- _CORRELATION_COL_NAME = eda_constants.CORRELATION
44
- _STACK_VAR_COORD_NAME = eda_constants.VARIABLE
45
- _CORR_VAR1 = eda_constants.VARIABLE_1
46
- _CORR_VAR2 = eda_constants.VARIABLE_2
47
- _CORRELATION_MATRIX_NAME = 'correlation_matrix'
48
- _OVERALL_PAIRWISE_CORR_THRESHOLD = 0.999
49
- _GEO_PAIRWISE_CORR_THRESHOLD = 0.999
50
- _NATIONAL_PAIRWISE_CORR_THRESHOLD = 0.999
51
- _Q1_THRESHOLD = 0.25
52
- _Q3_THRESHOLD = 0.75
53
- _IQR_MULTIPLIER = 1.5
54
- _STD_WITH_OUTLIERS_VAR_NAME = 'std_with_outliers'
55
- _STD_WITHOUT_OUTLIERS_VAR_NAME = 'std_without_outliers'
56
- _STD_THRESHOLD = 1e-4
57
- _OUTLIERS_COL_NAME = 'outliers'
58
- _ABS_OUTLIERS_COL_NAME = 'abs_outliers'
59
- _VIF_COL_NAME = 'VIF'
60
-
61
46
 
62
47
  class _NamedEDACheckCallable(Protocol):
63
48
  """A callable that returns an EDAOutcome and has a __name__ attribute."""
@@ -149,6 +134,15 @@ class ReachFrequencyData:
149
134
  national_rf_impressions_raw_da: xr.DataArray
150
135
 
151
136
 
137
+ def _get_vars_from_dataset(
138
+ base_ds: xr.Dataset,
139
+ variables_to_include: Collection[str],
140
+ ) -> xr.Dataset | None:
141
+ """Helper to get a subset of variables from a Dataset."""
142
+ variables = [v for v in base_ds.data_vars if v in variables_to_include]
143
+ return base_ds[variables].copy() if variables else None
144
+
145
+
152
146
  def _data_array_like(
153
147
  *, da: xr.DataArray, values: np.ndarray | backend.Tensor
154
148
  ) -> xr.DataArray:
@@ -173,7 +167,7 @@ def _data_array_like(
173
167
 
174
168
 
175
169
  def stack_variables(
176
- ds: xr.Dataset, coord_name: str = _STACK_VAR_COORD_NAME
170
+ ds: xr.Dataset, coord_name: str = eda_constants.VARIABLE
177
171
  ) -> xr.DataArray:
178
172
  """Stacks data variables of a Dataset into a single DataArray.
179
173
 
@@ -219,12 +213,12 @@ def _compute_correlation_matrix(
219
213
  An xr.DataArray containing the correlation matrix.
220
214
  """
221
215
  # Create two versions for correlation
222
- da1 = input_da.rename({_STACK_VAR_COORD_NAME: _CORR_VAR1})
223
- da2 = input_da.rename({_STACK_VAR_COORD_NAME: _CORR_VAR2})
216
+ da1 = input_da.rename({eda_constants.VARIABLE: eda_constants.VARIABLE_1})
217
+ da2 = input_da.rename({eda_constants.VARIABLE: eda_constants.VARIABLE_2})
224
218
 
225
219
  # Compute pairwise correlation across dims. Other dims are broadcasted.
226
220
  corr_mat_da = xr.corr(da1, da2, dim=dims)
227
- corr_mat_da.name = _CORRELATION_MATRIX_NAME
221
+ corr_mat_da.name = eda_constants.CORRELATION_MATRIX_NAME
228
222
  return corr_mat_da
229
223
 
230
224
 
@@ -238,14 +232,14 @@ def _get_upper_triangle_corr_mat(corr_mat_da: xr.DataArray) -> xr.DataArray:
238
232
  An xr.DataArray containing only the elements in the upper triangle of the
239
233
  correlation matrix, with other elements masked as NaN.
240
234
  """
241
- n_vars = corr_mat_da.sizes[_CORR_VAR1]
235
+ n_vars = corr_mat_da.sizes[eda_constants.VARIABLE_1]
242
236
  mask_np = np.triu(np.ones((n_vars, n_vars), dtype=bool), k=1)
243
237
  mask = xr.DataArray(
244
238
  mask_np,
245
- dims=[_CORR_VAR1, _CORR_VAR2],
239
+ dims=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2],
246
240
  coords={
247
- _CORR_VAR1: corr_mat_da[_CORR_VAR1],
248
- _CORR_VAR2: corr_mat_da[_CORR_VAR2],
241
+ eda_constants.VARIABLE_1: corr_mat_da[eda_constants.VARIABLE_1],
242
+ eda_constants.VARIABLE_2: corr_mat_da[eda_constants.VARIABLE_2],
249
243
  },
250
244
  )
251
245
  return corr_mat_da.where(mask)
@@ -259,11 +253,11 @@ def _find_extreme_corr_pairs(
259
253
  extreme_corr_da = corr_tri.where(abs(corr_tri) > extreme_corr_threshold)
260
254
 
261
255
  return (
262
- extreme_corr_da.to_dataframe(name=_CORRELATION_COL_NAME)
256
+ extreme_corr_da.to_dataframe(name=eda_constants.CORRELATION)
263
257
  .dropna()
264
258
  .assign(**{
265
259
  eda_constants.ABS_CORRELATION_COL_NAME: (
266
- lambda x: x[_CORRELATION_COL_NAME].abs()
260
+ lambda x: x[eda_constants.CORRELATION].abs()
267
261
  )
268
262
  })
269
263
  .sort_values(
@@ -286,11 +280,11 @@ def _get_outlier_bounds(
286
280
  A tuple containing the lower and upper bounds of outliers as DataArrays.
287
281
  """
288
282
  # TODO: Allow users to specify custom outlier definitions.
289
- q1 = input_da.quantile(_Q1_THRESHOLD, dim=constants.TIME)
290
- q3 = input_da.quantile(_Q3_THRESHOLD, dim=constants.TIME)
283
+ q1 = input_da.quantile(eda_constants.Q1_THRESHOLD, dim=constants.TIME)
284
+ q3 = input_da.quantile(eda_constants.Q3_THRESHOLD, dim=constants.TIME)
291
285
  iqr = q3 - q1
292
- lower_bound = q1 - _IQR_MULTIPLIER * iqr
293
- upper_bound = q3 + _IQR_MULTIPLIER * iqr
286
+ lower_bound = q1 - eda_constants.IQR_MULTIPLIER * iqr
287
+ upper_bound = q3 + eda_constants.IQR_MULTIPLIER * iqr
294
288
  return lower_bound, upper_bound
295
289
 
296
290
 
@@ -314,11 +308,10 @@ def _calculate_std(
314
308
  )
315
309
  std_without_outliers = da_no_outlier.std(dim=constants.TIME, ddof=1)
316
310
 
317
- std_ds = xr.Dataset({
318
- _STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
319
- _STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
311
+ return xr.Dataset({
312
+ eda_constants.STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
313
+ eda_constants.STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
320
314
  })
321
- return std_ds
322
315
 
323
316
 
324
317
  def _calculate_outliers(
@@ -334,23 +327,29 @@ def _calculate_outliers(
334
327
  outlier values.
335
328
  """
336
329
  lower_bound, upper_bound = _get_outlier_bounds(input_da)
337
- outlier_da = input_da.where(
338
- (input_da < lower_bound) | (input_da > upper_bound)
339
- )
340
- outlier_df = (
341
- outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME)
330
+ return (
331
+ input_da.where((input_da < lower_bound) | (input_da > upper_bound))
332
+ .to_dataframe(name=eda_constants.OUTLIERS_COL_NAME)
342
333
  .dropna()
343
- .assign(
344
- **{_ABS_OUTLIERS_COL_NAME: lambda x: np.abs(x[_OUTLIERS_COL_NAME])}
334
+ .assign(**{
335
+ eda_constants.ABS_OUTLIERS_COL_NAME: lambda x: np.abs(
336
+ x[eda_constants.OUTLIERS_COL_NAME]
337
+ )
338
+ })
339
+ .sort_values(
340
+ by=eda_constants.ABS_OUTLIERS_COL_NAME,
341
+ ascending=False,
342
+ inplace=False,
345
343
  )
346
- .sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
347
344
  )
348
- return outlier_df
349
345
 
350
346
 
351
347
  def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
352
348
  """Helper function to compute variance inflation factor.
353
349
 
350
+ The VIF calculation only focuses on multicollinearity among non-constant
351
+ variables. Any variable with constant values will result in a NaN VIF value.
352
+
354
353
  Args:
355
354
  input_da: A DataArray for which to calculate the VIF over sample dimensions
356
355
  (e.g. time and geo if applicable).
@@ -359,25 +358,28 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
359
358
  Returns:
360
359
  A DataArray containing the VIF for each variable in the variable dimension.
361
360
  """
361
+
362
362
  num_vars = input_da.sizes[var_dim]
363
363
  np_data = input_da.values.reshape(-1, num_vars)
364
- np_data_with_const = sm.add_constant(
365
- np_data, prepend=True, has_constant='add'
366
- )
367
364
 
368
- # Compute VIF for each variable excluding const which is the first one in the
369
- # 'variable' dimension.
370
- vifs = [
371
- outliers_influence.variance_inflation_factor(np_data_with_const, i)
372
- for i in range(1, num_vars + 1)
373
- ]
365
+ is_constant = np.std(np_data, axis=0) < eda_constants.STD_THRESHOLD
366
+ vif_values = np.full(num_vars, np.nan)
367
+ (non_constant_vars_indices,) = (~is_constant).nonzero()
368
+
369
+ if non_constant_vars_indices.size > 0:
370
+ design_matrix = sm.add_constant(
371
+ np_data[:, ~is_constant], prepend=True, has_constant='add'
372
+ )
373
+ for i, var_index in enumerate(non_constant_vars_indices):
374
+ vif_values[var_index] = outliers_influence.variance_inflation_factor(
375
+ design_matrix, i + 1
376
+ )
374
377
 
375
- vif_da = xr.DataArray(
376
- vifs,
378
+ return xr.DataArray(
379
+ vif_values,
377
380
  coords={var_dim: input_da[var_dim].values},
378
381
  dims=[var_dim],
379
382
  )
380
- return vif_da
381
383
 
382
384
 
383
385
  def _check_cost_media_unit_inconsistency(
@@ -392,9 +394,9 @@ def _check_cost_media_unit_inconsistency(
392
394
 
393
395
  Returns:
394
396
  A DataFrame of inconsistencies where either cost is zero and media units
395
- are
396
- positive, or cost is positive and media units are zero.
397
+ are positive, or cost is positive and media units are zero.
397
398
  """
399
+
398
400
  cost_media_units_ds = xr.merge([cost_da, media_units_da])
399
401
 
400
402
  # Condition 1: cost == 0 and media unit > 0
@@ -432,39 +434,69 @@ def _check_cost_per_media_unit(
432
434
  cost_da,
433
435
  media_units_da,
434
436
  )
437
+
438
+ # Calculate cost per media unit. Avoid division by zero by setting cost to
439
+ # NaN where media units are 0. Note that both (cost == media unit == 0) and
440
+ # (cost > 0 and media unit == 0) result in NaN, while the latter one is not
441
+ # desired.
442
+ cost_per_media_unit_da = xr.where(
443
+ media_units_da == 0,
444
+ np.nan,
445
+ cost_da / media_units_da,
446
+ )
447
+ cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
448
+ outlier_df = _calculate_outliers(cost_per_media_unit_da)
449
+
450
+ if not outlier_df.empty:
451
+ outlier_df = outlier_df.rename(
452
+ columns={
453
+ eda_constants.OUTLIERS_COL_NAME: eda_constants.COST_PER_MEDIA_UNIT,
454
+ eda_constants.ABS_OUTLIERS_COL_NAME: (
455
+ eda_constants.ABS_COST_PER_MEDIA_UNIT
456
+ ),
457
+ }
458
+ ).assign(**{
459
+ constants.SPEND: cost_da.to_series(),
460
+ constants.MEDIA_UNITS: media_units_da.to_series(),
461
+ })[[
462
+ constants.SPEND,
463
+ constants.MEDIA_UNITS,
464
+ eda_constants.COST_PER_MEDIA_UNIT,
465
+ eda_constants.ABS_COST_PER_MEDIA_UNIT,
466
+ ]]
467
+
468
+ artifact = eda_outcome.CostPerMediaUnitArtifact(
469
+ level=level,
470
+ cost_per_media_unit_da=cost_per_media_unit_da,
471
+ cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
472
+ outlier_df=outlier_df,
473
+ )
474
+
435
475
  if not cost_media_unit_inconsistency_df.empty:
436
476
  findings.append(
437
477
  eda_outcome.EDAFinding(
438
478
  severity=eda_outcome.EDASeverity.ATTENTION,
439
479
  explanation=(
440
- 'There are instances of inconsistent cost and media units.'
441
- ' This occurs when cost is zero but media units are positive,'
442
- ' or when cost is positive but media units are zero. Please'
443
- ' review the outcome artifact for more details.'
480
+ 'There are instances of inconsistent spend and media units.'
481
+ ' This occurs when spend is zero but media units are positive,'
482
+ ' or when spend is positive but media units are zero. Please'
483
+ ' review the data input for media units and spend.'
444
484
  ),
485
+ finding_cause=eda_outcome.FindingCause.INCONSISTENT_DATA,
486
+ associated_artifact=artifact,
445
487
  )
446
488
  )
447
489
 
448
- # Calculate cost per media unit
449
- # Avoid division by zero by setting cost to NaN where media units are 0.
450
- # Note that both (cost == media unit == 0) and (cost > 0 and media unit ==
451
- # 0) result in NaN, while the latter one is not desired.
452
- cost_per_media_unit_da = xr.where(
453
- media_units_da == 0,
454
- np.nan,
455
- cost_da / media_units_da,
456
- )
457
- cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
458
-
459
- outlier_df = _calculate_outliers(cost_per_media_unit_da)
460
490
  if not outlier_df.empty:
461
491
  findings.append(
462
492
  eda_outcome.EDAFinding(
463
493
  severity=eda_outcome.EDASeverity.ATTENTION,
464
494
  explanation=(
465
495
  'There are outliers in cost per media unit across time.'
466
- ' Please review the outcome artifact for more details.'
496
+ ' Please check for any possible data input error.'
467
497
  ),
498
+ finding_cause=eda_outcome.FindingCause.OUTLIER,
499
+ associated_artifact=artifact,
468
500
  )
469
501
  )
470
502
 
@@ -474,16 +506,10 @@ def _check_cost_per_media_unit(
474
506
  eda_outcome.EDAFinding(
475
507
  severity=eda_outcome.EDASeverity.INFO,
476
508
  explanation='Please review the cost per media unit data.',
509
+ finding_cause=eda_outcome.FindingCause.NONE,
477
510
  )
478
511
  )
479
512
 
480
- artifact = eda_outcome.CostPerMediaUnitArtifact(
481
- level=level,
482
- cost_per_media_unit_da=cost_per_media_unit_da,
483
- cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
484
- outlier_df=outlier_df,
485
- )
486
-
487
513
  return eda_outcome.EDAOutcome(
488
514
  check_type=eda_outcome.EDACheckType.COST_PER_MEDIA_UNIT,
489
515
  findings=findings,
@@ -491,42 +517,92 @@ def _check_cost_per_media_unit(
491
517
  )
492
518
 
493
519
 
520
+ def _calc_adj_r2(da: xr.DataArray, regressor: str) -> xr.DataArray:
521
+ """Calculates adjusted R-squared for a DataArray against a regressor.
522
+
523
+ If the input DataArray `da` is constant, it returns NaN.
524
+
525
+ Args:
526
+ da: The input DataArray.
527
+ regressor: The regressor to use in the formula.
528
+
529
+ Returns:
530
+ An xr.DataArray containing the adjusted R-squared value or NaN if `da` is
531
+ constant.
532
+ """
533
+ if da.std(ddof=1) < eda_constants.STD_THRESHOLD:
534
+ return xr.DataArray(np.nan)
535
+ tmp_name = 'dep_var'
536
+ df = da.to_dataframe(name=tmp_name).reset_index()
537
+ formula = f'{tmp_name} ~ C({regressor})'
538
+ ols = sm.OLS.from_formula(formula, df).fit()
539
+ return xr.DataArray(ols.rsquared_adj)
540
+
541
+
542
+ def _spearman_coeff(x, y):
543
+ """Computes spearman correlation coefficient between two ArrayLike objects."""
544
+
545
+ return stats.spearmanr(x, y, nan_policy='omit').statistic
546
+
547
+
494
548
  class EDAEngine:
495
549
  """Meridian EDA Engine."""
496
550
 
497
551
  def __init__(
498
552
  self,
499
- meridian: model.Meridian,
553
+ # TODO: b/476230365 - Remove meridian arg.
554
+ meridian: model.Meridian | None = None,
500
555
  spec: eda_spec.EDASpec = eda_spec.EDASpec(),
556
+ *,
557
+ model_context: context.ModelContext | None = None,
501
558
  ):
502
- self._meridian = meridian
559
+ if meridian is not None and model_context is not None:
560
+ raise ValueError(
561
+ 'Only one of `meridian` or `model_context` can be provided.'
562
+ )
563
+ if meridian is not None:
564
+ warnings.warn(
565
+ 'Initializing EDAEngine with a Meridian object is deprecated'
566
+ ' and will be removed in a future version. Please use'
567
+ ' `model_context` instead.',
568
+ DeprecationWarning,
569
+ stacklevel=2,
570
+ )
571
+ self._model_context = meridian.model_context
572
+ elif model_context is not None:
573
+ self._model_context = model_context
574
+ else:
575
+ raise ValueError('Either `meridian` or `model_context` must be provided.')
576
+
577
+ self._input_data = self._model_context.input_data
503
578
  self._spec = spec
504
579
  self._agg_config = self._spec.aggregation_config
505
580
 
506
581
  @property
507
582
  def spec(self) -> eda_spec.EDASpec:
583
+ """The EDA specification."""
508
584
  return self._spec
509
585
 
510
586
  @property
511
587
  def _is_national_data(self) -> bool:
512
- return self._meridian.is_national
588
+ return self._model_context.is_national
513
589
 
514
590
  @functools.cached_property
515
591
  def controls_scaled_da(self) -> xr.DataArray | None:
516
- """Returns the scaled controls data array."""
517
- if self._meridian.input_data.controls is None:
592
+ """The scaled controls data array."""
593
+ if self._input_data.controls is None:
518
594
  return None
519
595
  controls_scaled_da = _data_array_like(
520
- da=self._meridian.input_data.controls,
521
- values=self._meridian.controls_scaled,
596
+ da=self._input_data.controls,
597
+ values=self._model_context.controls_scaled,
522
598
  )
523
599
  controls_scaled_da.name = constants.CONTROLS_SCALED
524
600
  return controls_scaled_da
525
601
 
526
602
  @functools.cached_property
527
603
  def national_controls_scaled_da(self) -> xr.DataArray | None:
528
- """Returns the national scaled controls data array."""
529
- if self._meridian.input_data.controls is None:
604
+ """The national scaled controls data array."""
605
+ if self._input_data.controls is None:
530
606
  return None
531
607
  if self._is_national_data:
532
608
  if self.controls_scaled_da is None:
@@ -538,7 +614,7 @@ class EDAEngine:
538
614
  national_da.name = constants.NATIONAL_CONTROLS_SCALED
539
615
  else:
540
616
  national_da = self._aggregate_and_scale_geo_da(
541
- self._meridian.input_data.controls,
617
+ self._input_data.controls,
542
618
  constants.NATIONAL_CONTROLS_SCALED,
543
619
  transformers.CenteringAndScalingTransformer,
544
620
  constants.CONTROL_VARIABLE,
@@ -548,43 +624,43 @@ class EDAEngine:
548
624
 
549
625
  @functools.cached_property
550
626
  def media_raw_da(self) -> xr.DataArray | None:
551
- """Returns the raw media data array."""
552
- if self._meridian.input_data.media is None:
627
+ """The raw media data array."""
628
+ if self._input_data.media is None:
553
629
  return None
554
- raw_media_da = self._truncate_media_time(self._meridian.input_data.media)
630
+ raw_media_da = self._truncate_media_time(self._input_data.media)
555
631
  raw_media_da.name = constants.MEDIA
556
632
  return raw_media_da
557
633
 
558
634
  @functools.cached_property
559
635
  def media_scaled_da(self) -> xr.DataArray | None:
560
- """Returns the scaled media data array."""
561
- if self._meridian.input_data.media is None:
636
+ """The scaled media data array."""
637
+ if self._input_data.media is None:
562
638
  return None
563
639
  media_scaled_da = _data_array_like(
564
- da=self._meridian.input_data.media,
565
- values=self._meridian.media_tensors.media_scaled,
640
+ da=self._input_data.media,
641
+ values=self._model_context.media_tensors.media_scaled,
566
642
  )
567
643
  media_scaled_da.name = constants.MEDIA_SCALED
568
644
  return self._truncate_media_time(media_scaled_da)
569
645
 
570
646
  @functools.cached_property
571
647
  def media_spend_da(self) -> xr.DataArray | None:
572
- """Returns media spend.
648
+ """The media spend data.
573
649
 
574
650
  If the input spend is aggregated, it is allocated across geo and time
575
651
  proportionally to media units.
576
652
  """
577
653
  # No need to truncate the media time for media spend.
578
- da = self._meridian.input_data.allocated_media_spend
579
- if da is None:
654
+ allocated_media_spend = self._input_data.allocated_media_spend
655
+ if allocated_media_spend is None:
580
656
  return None
581
- da = da.copy()
657
+ da = allocated_media_spend.copy()
582
658
  da.name = constants.MEDIA_SPEND
583
659
  return da
584
660
 
585
661
  @functools.cached_property
586
662
  def national_media_spend_da(self) -> xr.DataArray | None:
587
- """Returns the national media spend data array."""
663
+ """The national media spend data array."""
588
664
  media_spend = self.media_spend_da
589
665
  if media_spend is None:
590
666
  return None
@@ -593,7 +669,7 @@ class EDAEngine:
593
669
  national_da.name = constants.NATIONAL_MEDIA_SPEND
594
670
  else:
595
671
  national_da = self._aggregate_and_scale_geo_da(
596
- self._meridian.input_data.allocated_media_spend,
672
+ self._input_data.allocated_media_spend,
597
673
  constants.NATIONAL_MEDIA_SPEND,
598
674
  None,
599
675
  )
@@ -601,7 +677,7 @@ class EDAEngine:
601
677
 
602
678
  @functools.cached_property
603
679
  def national_media_raw_da(self) -> xr.DataArray | None:
604
- """Returns the national raw media data array."""
680
+ """The national raw media data array."""
605
681
  if self.media_raw_da is None:
606
682
  return None
607
683
  if self._is_national_data:
@@ -618,7 +694,7 @@ class EDAEngine:
618
694
 
619
695
  @functools.cached_property
620
696
  def national_media_scaled_da(self) -> xr.DataArray | None:
621
- """Returns the national scaled media data array."""
697
+ """The national scaled media data array."""
622
698
  if self.media_scaled_da is None:
623
699
  return None
624
700
  if self._is_national_data:
@@ -635,30 +711,30 @@ class EDAEngine:
635
711
 
636
712
  @functools.cached_property
637
713
  def organic_media_raw_da(self) -> xr.DataArray | None:
638
- """Returns the raw organic media data array."""
639
- if self._meridian.input_data.organic_media is None:
714
+ """The raw organic media data array."""
715
+ if self._input_data.organic_media is None:
640
716
  return None
641
717
  raw_organic_media_da = self._truncate_media_time(
642
- self._meridian.input_data.organic_media
718
+ self._input_data.organic_media
643
719
  )
644
720
  raw_organic_media_da.name = constants.ORGANIC_MEDIA
645
721
  return raw_organic_media_da
646
722
 
647
723
  @functools.cached_property
648
724
  def organic_media_scaled_da(self) -> xr.DataArray | None:
649
- """Returns the scaled organic media data array."""
650
- if self._meridian.input_data.organic_media is None:
725
+ """The scaled organic media data array."""
726
+ if self._input_data.organic_media is None:
651
727
  return None
652
728
  organic_media_scaled_da = _data_array_like(
653
- da=self._meridian.input_data.organic_media,
654
- values=self._meridian.organic_media_tensors.organic_media_scaled,
729
+ da=self._input_data.organic_media,
730
+ values=self._model_context.organic_media_tensors.organic_media_scaled,
655
731
  )
656
732
  organic_media_scaled_da.name = constants.ORGANIC_MEDIA_SCALED
657
733
  return self._truncate_media_time(organic_media_scaled_da)
658
734
 
659
735
  @functools.cached_property
660
736
  def national_organic_media_raw_da(self) -> xr.DataArray | None:
661
- """Returns the national raw organic media data array."""
737
+ """The national raw organic media data array."""
662
738
  if self.organic_media_raw_da is None:
663
739
  return None
664
740
  if self._is_national_data:
@@ -673,7 +749,7 @@ class EDAEngine:
673
749
 
674
750
  @functools.cached_property
675
751
  def national_organic_media_scaled_da(self) -> xr.DataArray | None:
676
- """Returns the national scaled organic media data array."""
752
+ """The national scaled organic media data array."""
677
753
  if self.organic_media_scaled_da is None:
678
754
  return None
679
755
  if self._is_national_data:
@@ -692,20 +768,20 @@ class EDAEngine:
692
768
 
693
769
  @functools.cached_property
694
770
  def non_media_scaled_da(self) -> xr.DataArray | None:
695
- """Returns the scaled non-media treatments data array."""
696
- if self._meridian.input_data.non_media_treatments is None:
771
+ """The scaled non-media treatments data array."""
772
+ if self._input_data.non_media_treatments is None:
697
773
  return None
698
774
  non_media_scaled_da = _data_array_like(
699
- da=self._meridian.input_data.non_media_treatments,
700
- values=self._meridian.non_media_treatments_normalized,
775
+ da=self._input_data.non_media_treatments,
776
+ values=self._model_context.non_media_treatments_normalized,
701
777
  )
702
778
  non_media_scaled_da.name = constants.NON_MEDIA_TREATMENTS_SCALED
703
779
  return non_media_scaled_da
704
780
 
705
781
  @functools.cached_property
706
782
  def national_non_media_scaled_da(self) -> xr.DataArray | None:
707
- """Returns the national scaled non-media treatment data array."""
708
- if self._meridian.input_data.non_media_treatments is None:
783
+ """The national scaled non-media treatment data array."""
784
+ if self._input_data.non_media_treatments is None:
709
785
  return None
710
786
  if self._is_national_data:
711
787
  if self.non_media_scaled_da is None:
@@ -717,7 +793,7 @@ class EDAEngine:
717
793
  national_da.name = constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED
718
794
  else:
719
795
  national_da = self._aggregate_and_scale_geo_da(
720
- self._meridian.input_data.non_media_treatments,
796
+ self._input_data.non_media_treatments,
721
797
  constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
722
798
  transformers.CenteringAndScalingTransformer,
723
799
  constants.NON_MEDIA_CHANNEL,
@@ -727,12 +803,12 @@ class EDAEngine:
727
803
 
728
804
  @functools.cached_property
729
805
  def rf_spend_da(self) -> xr.DataArray | None:
730
- """Returns RF spend.
806
+ """The RF spend data.
731
807
 
732
808
  If the input spend is aggregated, it is allocated across geo and time
733
809
  proportionally to RF impressions (reach * frequency).
734
810
  """
735
- da = self._meridian.input_data.allocated_rf_spend
811
+ da = self._input_data.allocated_rf_spend
736
812
  if da is None:
737
813
  return None
738
814
  da = da.copy()
@@ -741,7 +817,7 @@ class EDAEngine:
741
817
 
742
818
  @functools.cached_property
743
819
  def national_rf_spend_da(self) -> xr.DataArray | None:
744
- """Returns the national RF spend data array."""
820
+ """The national RF spend data array."""
745
821
  rf_spend = self.rf_spend_da
746
822
  if rf_spend is None:
747
823
  return None
@@ -750,7 +826,7 @@ class EDAEngine:
750
826
  national_da.name = constants.NATIONAL_RF_SPEND
751
827
  else:
752
828
  national_da = self._aggregate_and_scale_geo_da(
753
- self._meridian.input_data.allocated_rf_spend,
829
+ self._input_data.allocated_rf_spend,
754
830
  constants.NATIONAL_RF_SPEND,
755
831
  None,
756
832
  )
@@ -758,182 +834,182 @@ class EDAEngine:
758
834
 
759
835
  @functools.cached_property
760
836
  def _rf_data(self) -> ReachFrequencyData | None:
761
- if self._meridian.input_data.reach is None:
837
+ if self._input_data.reach is None:
762
838
  return None
763
839
  return self._get_rf_data(
764
- self._meridian.input_data.reach,
765
- self._meridian.input_data.frequency,
840
+ self._input_data.reach,
841
+ self._input_data.frequency,
766
842
  is_organic=False,
767
843
  )
768
844
 
769
845
  @property
770
846
  def reach_raw_da(self) -> xr.DataArray | None:
771
- """Returns the raw reach data array."""
847
+ """The raw reach data array."""
772
848
  if self._rf_data is None:
773
849
  return None
774
- return self._rf_data.reach_raw_da
850
+ return self._rf_data.reach_raw_da # pytype: disable=attribute-error
775
851
 
776
852
  @property
777
853
  def reach_scaled_da(self) -> xr.DataArray | None:
778
- """Returns the scaled reach data array."""
854
+ """The scaled reach data array."""
779
855
  if self._rf_data is None:
780
856
  return None
781
857
  return self._rf_data.reach_scaled_da # pytype: disable=attribute-error
782
858
 
783
859
  @property
784
860
  def national_reach_raw_da(self) -> xr.DataArray | None:
785
- """Returns the national raw reach data array."""
861
+ """The national raw reach data array."""
786
862
  if self._rf_data is None:
787
863
  return None
788
864
  return self._rf_data.national_reach_raw_da
789
865
 
790
866
  @property
791
867
  def national_reach_scaled_da(self) -> xr.DataArray | None:
792
- """Returns the national scaled reach data array."""
868
+ """The national scaled reach data array."""
793
869
  if self._rf_data is None:
794
870
  return None
795
871
  return self._rf_data.national_reach_scaled_da # pytype: disable=attribute-error
796
872
 
797
873
  @property
798
874
  def frequency_da(self) -> xr.DataArray | None:
799
- """Returns the frequency data array."""
875
+ """The frequency data array."""
800
876
  if self._rf_data is None:
801
877
  return None
802
878
  return self._rf_data.frequency_da # pytype: disable=attribute-error
803
879
 
804
880
  @property
805
881
  def national_frequency_da(self) -> xr.DataArray | None:
806
- """Returns the national frequency data array."""
882
+ """The national frequency data array."""
807
883
  if self._rf_data is None:
808
884
  return None
809
885
  return self._rf_data.national_frequency_da # pytype: disable=attribute-error
810
886
 
811
887
  @property
812
888
  def rf_impressions_raw_da(self) -> xr.DataArray | None:
813
- """Returns the raw RF impressions data array."""
889
+ """The raw RF impressions data array."""
814
890
  if self._rf_data is None:
815
891
  return None
816
892
  return self._rf_data.rf_impressions_raw_da # pytype: disable=attribute-error
817
893
 
818
894
  @property
819
895
  def national_rf_impressions_raw_da(self) -> xr.DataArray | None:
820
- """Returns the national raw RF impressions data array."""
896
+ """The national raw RF impressions data array."""
821
897
  if self._rf_data is None:
822
898
  return None
823
899
  return self._rf_data.national_rf_impressions_raw_da # pytype: disable=attribute-error
824
900
 
825
901
  @property
826
902
  def rf_impressions_scaled_da(self) -> xr.DataArray | None:
827
- """Returns the scaled RF impressions data array."""
903
+ """The scaled RF impressions data array."""
828
904
  if self._rf_data is None:
829
905
  return None
830
906
  return self._rf_data.rf_impressions_scaled_da
831
907
 
832
908
  @property
833
909
  def national_rf_impressions_scaled_da(self) -> xr.DataArray | None:
834
- """Returns the national scaled RF impressions data array."""
910
+ """The national scaled RF impressions data array."""
835
911
  if self._rf_data is None:
836
912
  return None
837
913
  return self._rf_data.national_rf_impressions_scaled_da
838
914
 
839
915
  @functools.cached_property
840
916
  def _organic_rf_data(self) -> ReachFrequencyData | None:
841
- if self._meridian.input_data.organic_reach is None:
917
+ if self._input_data.organic_reach is None:
842
918
  return None
843
919
  return self._get_rf_data(
844
- self._meridian.input_data.organic_reach,
845
- self._meridian.input_data.organic_frequency,
920
+ self._input_data.organic_reach,
921
+ self._input_data.organic_frequency,
846
922
  is_organic=True,
847
923
  )
848
924
 
849
925
  @property
850
926
  def organic_reach_raw_da(self) -> xr.DataArray | None:
851
- """Returns the raw organic reach data array."""
927
+ """The raw organic reach data array."""
852
928
  if self._organic_rf_data is None:
853
929
  return None
854
- return self._organic_rf_data.reach_raw_da
930
+ return self._organic_rf_data.reach_raw_da # pytype: disable=attribute-error
855
931
 
856
932
  @property
857
933
  def organic_reach_scaled_da(self) -> xr.DataArray | None:
858
- """Returns the scaled organic reach data array."""
934
+ """The scaled organic reach data array."""
859
935
  if self._organic_rf_data is None:
860
936
  return None
861
937
  return self._organic_rf_data.reach_scaled_da # pytype: disable=attribute-error
862
938
 
863
939
  @property
864
940
  def national_organic_reach_raw_da(self) -> xr.DataArray | None:
865
- """Returns the national raw organic reach data array."""
941
+ """The national raw organic reach data array."""
866
942
  if self._organic_rf_data is None:
867
943
  return None
868
944
  return self._organic_rf_data.national_reach_raw_da
869
945
 
870
946
  @property
871
947
  def national_organic_reach_scaled_da(self) -> xr.DataArray | None:
872
- """Returns the national scaled organic reach data array."""
948
+ """The national scaled organic reach data array."""
873
949
  if self._organic_rf_data is None:
874
950
  return None
875
951
  return self._organic_rf_data.national_reach_scaled_da # pytype: disable=attribute-error
876
952
 
877
953
  @property
878
954
  def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
879
- """Returns the scaled organic RF impressions data array."""
955
+ """The scaled organic RF impressions data array."""
880
956
  if self._organic_rf_data is None:
881
957
  return None
882
958
  return self._organic_rf_data.rf_impressions_scaled_da
883
959
 
884
960
  @property
885
961
  def national_organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
886
- """Returns the national scaled organic RF impressions data array."""
962
+ """The national scaled organic RF impressions data array."""
887
963
  if self._organic_rf_data is None:
888
964
  return None
889
965
  return self._organic_rf_data.national_rf_impressions_scaled_da
890
966
 
891
967
  @property
892
968
  def organic_frequency_da(self) -> xr.DataArray | None:
893
- """Returns the organic frequency data array."""
969
+ """The organic frequency data array."""
894
970
  if self._organic_rf_data is None:
895
971
  return None
896
972
  return self._organic_rf_data.frequency_da # pytype: disable=attribute-error
897
973
 
898
974
  @property
899
975
  def national_organic_frequency_da(self) -> xr.DataArray | None:
900
- """Returns the national organic frequency data array."""
976
+ """The national organic frequency data array."""
901
977
  if self._organic_rf_data is None:
902
978
  return None
903
979
  return self._organic_rf_data.national_frequency_da # pytype: disable=attribute-error
904
980
 
905
981
  @property
906
982
  def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
907
- """Returns the raw organic RF impressions data array."""
983
+ """The raw organic RF impressions data array."""
908
984
  if self._organic_rf_data is None:
909
985
  return None
910
986
  return self._organic_rf_data.rf_impressions_raw_da
911
987
 
912
988
  @property
913
989
  def national_organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
914
- """Returns the national raw organic RF impressions data array."""
990
+ """The national raw organic RF impressions data array."""
915
991
  if self._organic_rf_data is None:
916
992
  return None
917
993
  return self._organic_rf_data.national_rf_impressions_raw_da
918
994
 
919
995
  @functools.cached_property
920
996
  def geo_population_da(self) -> xr.DataArray | None:
921
- """Returns the geo population data array."""
997
+ """The geo population data array."""
922
998
  if self._is_national_data:
923
999
  return None
924
1000
  return xr.DataArray(
925
- self._meridian.population,
926
- coords={constants.GEO: self._meridian.input_data.geo.values},
1001
+ self._model_context.population,
1002
+ coords={constants.GEO: self._input_data.geo.values},
927
1003
  dims=[constants.GEO],
928
1004
  name=constants.POPULATION,
929
1005
  )
930
1006
 
931
1007
  @functools.cached_property
932
1008
  def kpi_scaled_da(self) -> xr.DataArray:
933
- """Returns the scaled KPI data array."""
1009
+ """The scaled KPI data array."""
934
1010
  scaled_kpi_da = _data_array_like(
935
- da=self._meridian.input_data.kpi,
936
- values=self._meridian.kpi_scaled,
1011
+ da=self._input_data.kpi,
1012
+ values=self._model_context.kpi_scaled,
937
1013
  )
938
1014
  scaled_kpi_da.name = constants.KPI_SCALED
939
1015
  return scaled_kpi_da
@@ -942,7 +1018,7 @@ class EDAEngine:
942
1018
  def _overall_scaled_kpi_invariability_artifact(
943
1019
  self,
944
1020
  ) -> eda_outcome.KpiInvariabilityArtifact:
945
- """Returns an artifact of overall scaled KPI invariability."""
1021
+ """An artifact of overall scaled KPI invariability."""
946
1022
  return eda_outcome.KpiInvariabilityArtifact(
947
1023
  level=eda_outcome.AnalysisLevel.OVERALL,
948
1024
  kpi_da=self.kpi_scaled_da,
@@ -951,14 +1027,14 @@ class EDAEngine:
951
1027
 
952
1028
  @functools.cached_property
953
1029
  def national_kpi_scaled_da(self) -> xr.DataArray:
954
- """Returns the national scaled KPI data array."""
1030
+ """The national scaled KPI data array."""
955
1031
  if self._is_national_data:
956
1032
  national_da = self.kpi_scaled_da.squeeze(constants.GEO, drop=True)
957
1033
  national_da.name = constants.NATIONAL_KPI_SCALED
958
1034
  else:
959
1035
  # Note that kpi is summable by assumption.
960
1036
  national_da = self._aggregate_and_scale_geo_da(
961
- self._meridian.input_data.kpi,
1037
+ self._input_data.kpi,
962
1038
  constants.NATIONAL_KPI_SCALED,
963
1039
  transformers.CenteringAndScalingTransformer,
964
1040
  )
@@ -966,7 +1042,7 @@ class EDAEngine:
966
1042
 
967
1043
  @functools.cached_property
968
1044
  def treatment_control_scaled_ds(self) -> xr.Dataset:
969
- """Returns a Dataset containing all scaled treatments and controls.
1045
+ """A Dataset containing all scaled treatments and controls.
970
1046
 
971
1047
  This includes media, RF impressions, organic media, organic RF impressions,
972
1048
  non-media treatments, and control variables, all at the geo level.
@@ -987,7 +1063,7 @@ class EDAEngine:
987
1063
 
988
1064
  @functools.cached_property
989
1065
  def all_spend_ds(self) -> xr.Dataset:
990
- """Returns a Dataset containing all spend data.
1066
+ """A Dataset containing all spend data.
991
1067
 
992
1068
  This includes media spend and rf spend.
993
1069
  """
@@ -1003,7 +1079,7 @@ class EDAEngine:
1003
1079
 
1004
1080
  @functools.cached_property
1005
1081
  def national_all_spend_ds(self) -> xr.Dataset:
1006
- """Returns a Dataset containing all national spend data.
1082
+ """A Dataset containing all national spend data.
1007
1083
 
1008
1084
  This includes media spend and rf spend.
1009
1085
  """
@@ -1019,14 +1095,14 @@ class EDAEngine:
1019
1095
 
1020
1096
  @functools.cached_property
1021
1097
  def _stacked_treatment_control_scaled_da(self) -> xr.DataArray:
1022
- """Returns a stacked DataArray of treatment_control_scaled_ds."""
1098
+ """A stacked DataArray of treatment_control_scaled_ds."""
1023
1099
  da = stack_variables(self.treatment_control_scaled_ds)
1024
1100
  da.name = constants.TREATMENT_CONTROL_SCALED
1025
1101
  return da
1026
1102
 
1027
1103
  @functools.cached_property
1028
1104
  def national_treatment_control_scaled_ds(self) -> xr.Dataset:
1029
- """Returns a Dataset containing all scaled treatments and controls.
1105
+ """A Dataset containing all scaled treatments and controls.
1030
1106
 
1031
1107
  This includes media, RF impressions, organic media, organic RF impressions,
1032
1108
  non-media treatments, and control variables, all at the national level.
@@ -1047,14 +1123,14 @@ class EDAEngine:
1047
1123
 
1048
1124
  @functools.cached_property
1049
1125
  def _stacked_national_treatment_control_scaled_da(self) -> xr.DataArray:
1050
- """Returns a stacked DataArray of national_treatment_control_scaled_ds."""
1126
+ """A stacked DataArray of national_treatment_control_scaled_ds."""
1051
1127
  da = stack_variables(self.national_treatment_control_scaled_ds)
1052
1128
  da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
1053
1129
  return da
1054
1130
 
1055
1131
  @functools.cached_property
1056
1132
  def treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
1057
- """Returns a Dataset of scaled treatments excluding non-media."""
1133
+ """A Dataset of scaled treatments excluding non-media."""
1058
1134
  return self.treatment_control_scaled_ds.drop_dims(
1059
1135
  [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
1060
1136
  errors='ignore',
@@ -1062,15 +1138,34 @@ class EDAEngine:
1062
1138
 
1063
1139
  @functools.cached_property
1064
1140
  def national_treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
1065
- """Returns a Dataset of national scaled treatments excluding non-media."""
1141
+ """A Dataset of national scaled treatments excluding non-media."""
1066
1142
  return self.national_treatment_control_scaled_ds.drop_dims(
1067
1143
  [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
1068
1144
  errors='ignore',
1069
1145
  )
1070
1146
 
1147
+ @functools.cached_property
1148
+ def controls_and_non_media_scaled_ds(self) -> xr.Dataset | None:
1149
+ """A Dataset of scaled controls and non-media treatments."""
1150
+ return _get_vars_from_dataset(
1151
+ self.treatment_control_scaled_ds,
1152
+ [constants.CONTROLS_SCALED, constants.NON_MEDIA_TREATMENTS_SCALED],
1153
+ )
1154
+
1155
+ @functools.cached_property
1156
+ def national_controls_and_non_media_scaled_ds(self) -> xr.Dataset | None:
1157
+ """A Dataset of national scaled controls and non-media treatments."""
1158
+ return _get_vars_from_dataset(
1159
+ self.national_treatment_control_scaled_ds,
1160
+ [
1161
+ constants.NATIONAL_CONTROLS_SCALED,
1162
+ constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
1163
+ ],
1164
+ )
1165
+
1071
1166
  @functools.cached_property
1072
1167
  def all_reach_scaled_da(self) -> xr.DataArray | None:
1073
- """Returns a DataArray containing all scaled reach data.
1168
+ """A DataArray containing all scaled reach data.
1074
1169
 
1075
1170
  This includes both paid and organic reach, concatenated along the RF_CHANNEL
1076
1171
  dimension.
@@ -1096,7 +1191,7 @@ class EDAEngine:
1096
1191
 
1097
1192
  @functools.cached_property
1098
1193
  def all_freq_da(self) -> xr.DataArray | None:
1099
- """Returns a DataArray containing all frequency data.
1194
+ """A DataArray containing all frequency data.
1100
1195
 
1101
1196
  This includes both paid and organic frequency, concatenated along the
1102
1197
  RF_CHANNEL dimension.
@@ -1122,7 +1217,7 @@ class EDAEngine:
1122
1217
 
1123
1218
  @functools.cached_property
1124
1219
  def national_all_reach_scaled_da(self) -> xr.DataArray | None:
1125
- """Returns a DataArray containing all national-level scaled reach data.
1220
+ """A DataArray containing all national-level scaled reach data.
1126
1221
 
1127
1222
  This includes both paid and organic reach, concatenated along the
1128
1223
  RF_CHANNEL dimension.
@@ -1149,7 +1244,7 @@ class EDAEngine:
1149
1244
 
1150
1245
  @functools.cached_property
1151
1246
  def national_all_freq_da(self) -> xr.DataArray | None:
1152
- """Returns a DataArray containing all national-level frequency data.
1247
+ """A DataArray containing all national-level frequency data.
1153
1248
 
1154
1249
  This includes both paid and organic frequency, concatenated along the
1155
1250
  RF_CHANNEL dimension.
@@ -1202,7 +1297,7 @@ class EDAEngine:
1202
1297
  def _critical_checks(
1203
1298
  self,
1204
1299
  ) -> list[tuple[_NamedEDACheckCallable, eda_outcome.EDACheckType]]:
1205
- """Returns a list of critical checks to be performed."""
1300
+ """A list of critical checks to be performed."""
1206
1301
  checks = [
1207
1302
  (
1208
1303
  self.check_overall_kpi_invariability,
@@ -1221,19 +1316,21 @@ class EDAEngine:
1221
1316
  # This should not happen. If it does, it means this function is mis-used.
1222
1317
  if constants.MEDIA_TIME not in da.coords:
1223
1318
  raise ValueError(
1224
- f'Variable does not have a media time coordinate: {da.name}.'
1319
+ f'Variable does not have a media time coordinate: {da.name!r}.'
1225
1320
  )
1226
1321
 
1227
- start = self._meridian.n_media_times - self._meridian.n_times
1228
- da = da.copy().isel({constants.MEDIA_TIME: slice(start, None)})
1229
- da = da.rename({constants.MEDIA_TIME: constants.TIME})
1230
- return da
1322
+ start = self._model_context.n_media_times - self._model_context.n_times
1323
+ return (
1324
+ da.copy()
1325
+ .isel({constants.MEDIA_TIME: slice(start, None)})
1326
+ .rename({constants.MEDIA_TIME: constants.TIME})
1327
+ )
1231
1328
 
1232
1329
  def _scale_xarray(
1233
1330
  self,
1234
1331
  xarray: xr.DataArray,
1235
- transformer_class: Optional[type[transformers.TensorTransformer]],
1236
- population: Optional[backend.Tensor] = None,
1332
+ transformer_class: type[transformers.TensorTransformer] | None,
1333
+ population: backend.Tensor | None = None,
1237
1334
  ) -> xr.DataArray:
1238
1335
  """Scales xarray values with a TensorTransformer."""
1239
1336
  da = xarray.copy()
@@ -1285,7 +1382,9 @@ class EDAEngine:
1285
1382
  agg_results = []
1286
1383
  for var_name in geo_da[channel_dim].values:
1287
1384
  var_data = geo_da.sel({channel_dim: var_name})
1288
- agg_func = da_var_agg_map.get(var_name, _DEFAULT_DA_VAR_AGG_FUNCTION)
1385
+ agg_func = da_var_agg_map.get(
1386
+ var_name, eda_constants.DEFAULT_DA_VAR_AGG_FUNCTION
1387
+ )
1289
1388
  # Apply the aggregation function over the GEO dimension
1290
1389
  aggregated_data = var_data.reduce(
1291
1390
  agg_func, dim=constants.GEO, keepdims=keepdims
@@ -1299,9 +1398,9 @@ class EDAEngine:
1299
1398
  self,
1300
1399
  geo_da: xr.DataArray,
1301
1400
  national_da_name: str,
1302
- transformer_class: Optional[type[transformers.TensorTransformer]],
1303
- channel_dim: Optional[str] = None,
1304
- da_var_agg_map: Optional[eda_spec.AggregationMap] = None,
1401
+ transformer_class: type[transformers.TensorTransformer] | None,
1402
+ channel_dim: str | None = None,
1403
+ da_var_agg_map: eda_spec.AggregationMap | None = None,
1305
1404
  ) -> xr.DataArray:
1306
1405
  """Aggregate geo-level xr.DataArray to national level and then scale values.
1307
1406
 
@@ -1351,11 +1450,11 @@ class EDAEngine:
1351
1450
  """Get impressions and frequencies data arrays for RF channels."""
1352
1451
  if is_organic:
1353
1452
  scaled_reach_values = (
1354
- self._meridian.organic_rf_tensors.organic_reach_scaled
1453
+ self._model_context.organic_rf_tensors.organic_reach_scaled
1355
1454
  )
1356
1455
  names = _ORGANIC_RF_NAMES
1357
1456
  else:
1358
- scaled_reach_values = self._meridian.rf_tensors.reach_scaled
1457
+ scaled_reach_values = self._model_context.rf_tensors.reach_scaled
1359
1458
  names = _RF_NAMES
1360
1459
 
1361
1460
  reach_scaled_da = _data_array_like(
@@ -1433,7 +1532,7 @@ class EDAEngine:
1433
1532
  impressions_scaled_da = self._scale_xarray(
1434
1533
  impressions_raw_da,
1435
1534
  transformers.MediaTransformer,
1436
- population=self._meridian.population,
1535
+ population=self._model_context.population,
1437
1536
  )
1438
1537
  impressions_scaled_da.name = names.impressions_scaled
1439
1538
 
@@ -1491,9 +1590,16 @@ class EDAEngine:
1491
1590
  overall_corr_mat, overall_extreme_corr_var_pairs_df = (
1492
1591
  self._pairwise_corr_for_geo_data(
1493
1592
  dims=[constants.GEO, constants.TIME],
1494
- extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
1593
+ extreme_corr_threshold=eda_constants.OVERALL_PAIRWISE_CORR_THRESHOLD,
1495
1594
  )
1496
1595
  )
1596
+ overall_artifact = eda_outcome.PairwiseCorrArtifact(
1597
+ level=eda_outcome.AnalysisLevel.OVERALL,
1598
+ corr_matrix=overall_corr_mat,
1599
+ extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
1600
+ extreme_corr_threshold=eda_constants.OVERALL_PAIRWISE_CORR_THRESHOLD,
1601
+ )
1602
+
1497
1603
  if not overall_extreme_corr_var_pairs_df.empty:
1498
1604
  var_pairs = overall_extreme_corr_var_pairs_df.index.to_list()
1499
1605
  findings.append(
@@ -1505,21 +1611,33 @@ class EDAEngine:
1505
1611
  ' variables, please remove one of the variables from the'
1506
1612
  f' model.\nPairs with perfect correlation: {var_pairs}'
1507
1613
  ),
1614
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
1615
+ associated_artifact=overall_artifact,
1508
1616
  )
1509
1617
  )
1510
1618
 
1511
1619
  geo_corr_mat, geo_extreme_corr_var_pairs_df = (
1512
1620
  self._pairwise_corr_for_geo_data(
1513
1621
  dims=constants.TIME,
1514
- extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
1622
+ extreme_corr_threshold=eda_constants.GEO_PAIRWISE_CORR_THRESHOLD,
1515
1623
  )
1516
1624
  )
1517
- # Overall correlation and per-geo correlation findings are mutually
1518
- # exclusive, and overall correlation finding takes precedence.
1519
- if (
1520
- overall_extreme_corr_var_pairs_df.empty
1521
- and not geo_extreme_corr_var_pairs_df.empty
1522
- ):
1625
+ # Pairs that cause overall level findings are very likely to cause geo
1626
+ # level findings as well, so we exclude them when determining geo-level
1627
+ # findings. This is to avoid over-reporting findings.
1628
+ overall_pairs_index = overall_extreme_corr_var_pairs_df.index
1629
+ is_in_overall = geo_extreme_corr_var_pairs_df.index.droplevel(
1630
+ constants.GEO
1631
+ ).isin(overall_pairs_index)
1632
+ geo_df_for_attention = geo_extreme_corr_var_pairs_df[~is_in_overall]
1633
+ geo_artifact = eda_outcome.PairwiseCorrArtifact(
1634
+ level=eda_outcome.AnalysisLevel.GEO,
1635
+ corr_matrix=geo_corr_mat,
1636
+ extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
1637
+ extreme_corr_threshold=eda_constants.GEO_PAIRWISE_CORR_THRESHOLD,
1638
+ )
1639
+
1640
+ if not geo_df_for_attention.empty:
1523
1641
  findings.append(
1524
1642
  eda_outcome.EDAFinding(
1525
1643
  severity=eda_outcome.EDASeverity.ATTENTION,
@@ -1529,6 +1647,8 @@ class EDAEngine:
1529
1647
  ' variables if they also have high pairwise correlations in'
1530
1648
  ' other geos.'
1531
1649
  ),
1650
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
1651
+ associated_artifact=geo_artifact,
1532
1652
  )
1533
1653
  )
1534
1654
 
@@ -1538,34 +1658,15 @@ class EDAEngine:
1538
1658
  findings.append(
1539
1659
  eda_outcome.EDAFinding(
1540
1660
  severity=eda_outcome.EDASeverity.INFO,
1541
- explanation=(
1542
- 'Please review the computed pairwise correlations. Note that'
1543
- ' high pairwise correlation may cause model identifiability'
1544
- ' and convergence issues. Consider combining the variables if'
1545
- ' high correlation exists.'
1546
- ),
1661
+ explanation=(eda_constants.PAIRWISE_CORRELATION_CHECK_INFO),
1662
+ finding_cause=eda_outcome.FindingCause.NONE,
1547
1663
  )
1548
1664
  )
1549
1665
 
1550
- pairwise_corr_artifacts = [
1551
- eda_outcome.PairwiseCorrArtifact(
1552
- level=eda_outcome.AnalysisLevel.OVERALL,
1553
- corr_matrix=overall_corr_mat,
1554
- extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
1555
- extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
1556
- ),
1557
- eda_outcome.PairwiseCorrArtifact(
1558
- level=eda_outcome.AnalysisLevel.GEO,
1559
- corr_matrix=geo_corr_mat,
1560
- extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
1561
- extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
1562
- ),
1563
- ]
1564
-
1565
1666
  return eda_outcome.EDAOutcome(
1566
1667
  check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
1567
1668
  findings=findings,
1568
- analysis_artifacts=pairwise_corr_artifacts,
1669
+ analysis_artifacts=[overall_artifact, geo_artifact],
1569
1670
  )
1570
1671
 
1571
1672
  def check_national_pairwise_corr(
@@ -1582,7 +1683,14 @@ class EDAEngine:
1582
1683
  self._stacked_national_treatment_control_scaled_da, dims=constants.TIME
1583
1684
  )
1584
1685
  extreme_corr_var_pairs_df = _find_extreme_corr_pairs(
1585
- corr_mat, _NATIONAL_PAIRWISE_CORR_THRESHOLD
1686
+ corr_mat, eda_constants.NATIONAL_PAIRWISE_CORR_THRESHOLD
1687
+ )
1688
+
1689
+ artifact = eda_outcome.PairwiseCorrArtifact(
1690
+ level=eda_outcome.AnalysisLevel.NATIONAL,
1691
+ corr_matrix=corr_mat,
1692
+ extreme_corr_var_pairs=extreme_corr_var_pairs_df,
1693
+ extreme_corr_threshold=eda_constants.NATIONAL_PAIRWISE_CORR_THRESHOLD,
1586
1694
  )
1587
1695
 
1588
1696
  if not extreme_corr_var_pairs_df.empty:
@@ -1596,33 +1704,23 @@ class EDAEngine:
1596
1704
  ' variables, please remove one of the variables from the'
1597
1705
  f' model.\nPairs with perfect correlation: {var_pairs}'
1598
1706
  ),
1707
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
1708
+ associated_artifact=artifact,
1599
1709
  )
1600
1710
  )
1601
1711
  else:
1602
1712
  findings.append(
1603
1713
  eda_outcome.EDAFinding(
1604
1714
  severity=eda_outcome.EDASeverity.INFO,
1605
- explanation=(
1606
- 'Please review the computed pairwise correlations. Note that'
1607
- ' high pairwise correlation may cause model identifiability'
1608
- ' and convergence issues. Consider combining the variables if'
1609
- ' high correlation exists.'
1610
- ),
1715
+ explanation=(eda_constants.PAIRWISE_CORRELATION_CHECK_INFO),
1716
+ finding_cause=eda_outcome.FindingCause.NONE,
1611
1717
  )
1612
1718
  )
1613
1719
 
1614
- pairwise_corr_artifacts = [
1615
- eda_outcome.PairwiseCorrArtifact(
1616
- level=eda_outcome.AnalysisLevel.NATIONAL,
1617
- corr_matrix=corr_mat,
1618
- extreme_corr_var_pairs=extreme_corr_var_pairs_df,
1619
- extreme_corr_threshold=_NATIONAL_PAIRWISE_CORR_THRESHOLD,
1620
- )
1621
- ]
1622
1720
  return eda_outcome.EDAOutcome(
1623
1721
  check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
1624
1722
  findings=findings,
1625
- analysis_artifacts=pairwise_corr_artifacts,
1723
+ analysis_artifacts=[artifact],
1626
1724
  )
1627
1725
 
1628
1726
  def check_pairwise_corr(
@@ -1643,20 +1741,14 @@ class EDAEngine:
1643
1741
  data: xr.DataArray,
1644
1742
  level: eda_outcome.AnalysisLevel,
1645
1743
  zero_std_message: str,
1744
+ outlier_message: str,
1646
1745
  ) -> tuple[
1647
- eda_outcome.EDAFinding | None, eda_outcome.StandardDeviationArtifact
1746
+ list[eda_outcome.EDAFinding], eda_outcome.StandardDeviationArtifact
1648
1747
  ]:
1649
1748
  """Helper to check standard deviation."""
1650
1749
  std_ds = _calculate_std(data)
1651
1750
  outlier_df = _calculate_outliers(data)
1652
1751
 
1653
- finding = None
1654
- if (std_ds[_STD_WITHOUT_OUTLIERS_VAR_NAME] < _STD_THRESHOLD).any():
1655
- finding = eda_outcome.EDAFinding(
1656
- severity=eda_outcome.EDASeverity.ATTENTION,
1657
- explanation=zero_std_message,
1658
- )
1659
-
1660
1752
  artifact = eda_outcome.StandardDeviationArtifact(
1661
1753
  variable=str(data.name),
1662
1754
  level=level,
@@ -1664,7 +1756,31 @@ class EDAEngine:
1664
1756
  outlier_df=outlier_df,
1665
1757
  )
1666
1758
 
1667
- return finding, artifact
1759
+ findings = []
1760
+ if (
1761
+ std_ds[eda_constants.STD_WITHOUT_OUTLIERS_VAR_NAME]
1762
+ < eda_constants.STD_THRESHOLD
1763
+ ).any():
1764
+ findings.append(
1765
+ eda_outcome.EDAFinding(
1766
+ severity=eda_outcome.EDASeverity.ATTENTION,
1767
+ explanation=zero_std_message,
1768
+ finding_cause=eda_outcome.FindingCause.VARIABILITY,
1769
+ associated_artifact=artifact,
1770
+ )
1771
+ )
1772
+
1773
+ if not outlier_df.empty:
1774
+ findings.append(
1775
+ eda_outcome.EDAFinding(
1776
+ severity=eda_outcome.EDASeverity.ATTENTION,
1777
+ explanation=outlier_message,
1778
+ finding_cause=eda_outcome.FindingCause.OUTLIER,
1779
+ associated_artifact=artifact,
1780
+ )
1781
+ )
1782
+
1783
+ return findings, artifact
1668
1784
 
1669
1785
  def check_geo_std(
1670
1786
  self,
@@ -1685,6 +1801,10 @@ class EDAEngine:
1685
1801
  ' variable for these geos. Please review the input data,'
1686
1802
  ' and/or consider grouping these geos together.'
1687
1803
  ),
1804
+ (
1805
+ 'There are outliers in the scaled KPI in certain geos.'
1806
+ ' Please check for any possible data errors.'
1807
+ ),
1688
1808
  ),
1689
1809
  (
1690
1810
  self._stacked_treatment_control_scaled_da,
@@ -1695,6 +1815,11 @@ class EDAEngine:
1695
1815
  ' consider combining them to mitigate potential model'
1696
1816
  ' identifiability and convergence issues.'
1697
1817
  ),
1818
+ (
1819
+ 'There are outliers in the scaled treatment or control'
1820
+ ' variables in certain geos. Please check for any possible data'
1821
+ ' errors.'
1822
+ ),
1698
1823
  ),
1699
1824
  (
1700
1825
  self.all_reach_scaled_da,
@@ -1705,6 +1830,11 @@ class EDAEngine:
1705
1830
  ' geos, consider modeling them as impression-based channels'
1706
1831
  ' instead by taking reach * frequency.'
1707
1832
  ),
1833
+ (
1834
+ 'There are outliers in the scaled reach values of the RF or'
1835
+ ' Organic RF channels in certain geos. Please check for any'
1836
+ ' possible data errors.'
1837
+ ),
1708
1838
  ),
1709
1839
  (
1710
1840
  self.all_freq_da,
@@ -1715,30 +1845,34 @@ class EDAEngine:
1715
1845
  ' geos, consider modeling them as impression-based channels'
1716
1846
  ' instead by taking reach * frequency.'
1717
1847
  ),
1848
+ (
1849
+ 'There are outliers in the scaled frequency values of the RF or'
1850
+ ' Organic RF channels in certain geos. Please check for any'
1851
+ ' possible data errors.'
1852
+ ),
1718
1853
  ),
1719
1854
  ]
1720
1855
 
1721
- for data_da, message in checks:
1856
+ for data_da, std_message, outlier_message in checks:
1722
1857
  if data_da is None:
1723
1858
  continue
1724
- finding, artifact = self._check_std(
1859
+ current_findings, artifact = self._check_std(
1725
1860
  level=eda_outcome.AnalysisLevel.GEO,
1726
1861
  data=data_da,
1727
- zero_std_message=message,
1862
+ zero_std_message=std_message,
1863
+ outlier_message=outlier_message,
1728
1864
  )
1729
1865
  artifacts.append(artifact)
1730
- if finding:
1731
- findings.append(finding)
1866
+ if current_findings:
1867
+ findings.extend(current_findings)
1732
1868
 
1733
1869
  # Add an INFO finding if no findings were added.
1734
1870
  if not findings:
1735
1871
  findings.append(
1736
1872
  eda_outcome.EDAFinding(
1737
1873
  severity=eda_outcome.EDASeverity.INFO,
1738
- explanation=(
1739
- 'Please review any identified outliers and the standard'
1740
- ' deviation.'
1741
- ),
1874
+ explanation='Please review the computed standard deviations.',
1875
+ finding_cause=eda_outcome.FindingCause.NONE,
1742
1876
  )
1743
1877
  )
1744
1878
 
@@ -1765,6 +1899,10 @@ class EDAEngine:
1765
1899
  ' the input data, and/or reconsider the feasibility of model'
1766
1900
  ' fitting with this dataset.'
1767
1901
  ),
1902
+ (
1903
+ 'There are outliers in the scaled KPI.'
1904
+ ' Please check for any possible data errors.'
1905
+ ),
1768
1906
  ),
1769
1907
  (
1770
1908
  self._stacked_national_treatment_control_scaled_da,
@@ -1776,6 +1914,10 @@ class EDAEngine:
1776
1914
  ' Please review the input data, and/or consider combining these'
1777
1915
  ' variables to mitigate sparsity.'
1778
1916
  ),
1917
+ (
1918
+ 'There are outliers in the scaled treatment or control'
1919
+ ' variables. Please check for any possible data errors.'
1920
+ ),
1779
1921
  ),
1780
1922
  (
1781
1923
  self.national_all_reach_scaled_da,
@@ -1785,6 +1927,11 @@ class EDAEngine:
1785
1927
  ' Consider modeling these RF channels as impression-based'
1786
1928
  ' channels instead.'
1787
1929
  ),
1930
+ (
1931
+ 'There are outliers in the scaled reach values of the RF or'
1932
+ ' Organic RF channels. Please check for any possible data'
1933
+ ' errors.'
1934
+ ),
1788
1935
  ),
1789
1936
  (
1790
1937
  self.national_all_freq_da,
@@ -1794,30 +1941,34 @@ class EDAEngine:
1794
1941
  ' Consider modeling these RF channels as impression-based'
1795
1942
  ' channels instead.'
1796
1943
  ),
1944
+ (
1945
+ 'There are outliers in the scaled frequency values of the RF or'
1946
+ ' Organic RF channels. Please check for any possible data'
1947
+ ' errors.'
1948
+ ),
1797
1949
  ),
1798
1950
  ]
1799
1951
 
1800
- for data_da, message in checks:
1952
+ for data_da, std_message, outlier_message in checks:
1801
1953
  if data_da is None:
1802
1954
  continue
1803
- finding, artifact = self._check_std(
1955
+ current_findings, artifact = self._check_std(
1804
1956
  data=data_da,
1805
1957
  level=eda_outcome.AnalysisLevel.NATIONAL,
1806
- zero_std_message=message,
1958
+ zero_std_message=std_message,
1959
+ outlier_message=outlier_message,
1807
1960
  )
1808
1961
  artifacts.append(artifact)
1809
- if finding:
1810
- findings.append(finding)
1962
+ if current_findings:
1963
+ findings.extend(current_findings)
1811
1964
 
1812
1965
  # Add an INFO finding if no findings were added.
1813
1966
  if not findings:
1814
1967
  findings.append(
1815
1968
  eda_outcome.EDAFinding(
1816
1969
  severity=eda_outcome.EDASeverity.INFO,
1817
- explanation=(
1818
- 'Please review any identified outliers and the standard'
1819
- ' deviation.'
1820
- ),
1970
+ explanation='Please review the computed standard deviations.',
1971
+ finding_cause=eda_outcome.FindingCause.NONE,
1821
1972
  )
1822
1973
  )
1823
1974
 
@@ -1841,7 +1992,15 @@ class EDAEngine:
1841
1992
  return self.check_geo_std()
1842
1993
 
1843
1994
  def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
1844
- """Computes geo-level variance inflation factor among treatments and controls."""
1995
+ """Checks geo variance inflation factor among treatments and controls.
1996
+
1997
+ The VIF calculation only focuses on multicollinearity among non-constant
1998
+ variables. Any variable with constant values will result in a NaN VIF value.
1999
+
2000
+ Returns:
2001
+ An EDAOutcome object with findings and result values.
2002
+ """
2003
+
1845
2004
  if self._is_national_data:
1846
2005
  raise ValueError(
1847
2006
  'Geo-level VIF checks are not applicable for national models.'
@@ -1851,12 +2010,12 @@ class EDAEngine:
1851
2010
  tc_da = self._stacked_treatment_control_scaled_da
1852
2011
  overall_threshold = self._spec.vif_spec.overall_threshold
1853
2012
 
1854
- overall_vif_da = _calculate_vif(tc_da, _STACK_VAR_COORD_NAME)
2013
+ overall_vif_da = _calculate_vif(tc_da, eda_constants.VARIABLE)
1855
2014
  extreme_overall_vif_da = overall_vif_da.where(
1856
2015
  overall_vif_da > overall_threshold
1857
2016
  )
1858
2017
  extreme_overall_vif_df = extreme_overall_vif_da.to_dataframe(
1859
- name=_VIF_COL_NAME
2018
+ name=eda_constants.VIF_COL_NAME
1860
2019
  ).dropna()
1861
2020
 
1862
2021
  overall_vif_artifact = eda_outcome.VIFArtifact(
@@ -1868,11 +2027,11 @@ class EDAEngine:
1868
2027
  # Geo level VIF check.
1869
2028
  geo_threshold = self._spec.vif_spec.geo_threshold
1870
2029
  geo_vif_da = tc_da.groupby(constants.GEO).map(
1871
- lambda x: _calculate_vif(x, _STACK_VAR_COORD_NAME)
2030
+ lambda x: _calculate_vif(x, eda_constants.VARIABLE)
1872
2031
  )
1873
2032
  extreme_geo_vif_da = geo_vif_da.where(geo_vif_da > geo_threshold)
1874
2033
  extreme_geo_vif_df = extreme_geo_vif_da.to_dataframe(
1875
- name=_VIF_COL_NAME
2034
+ name=eda_constants.VIF_COL_NAME
1876
2035
  ).dropna()
1877
2036
 
1878
2037
  geo_vif_artifact = eda_outcome.VIFArtifact(
@@ -1883,33 +2042,47 @@ class EDAEngine:
1883
2042
 
1884
2043
  findings = []
1885
2044
  if not extreme_overall_vif_df.empty:
1886
- high_vif_vars = extreme_overall_vif_df.index.to_list()
2045
+ high_vif_vars_message = (
2046
+ '\nVariables with extreme VIF:'
2047
+ f' {extreme_overall_vif_df.index.to_list()}'
2048
+ )
1887
2049
  findings.append(
1888
2050
  eda_outcome.EDAFinding(
1889
2051
  severity=eda_outcome.EDASeverity.ERROR,
1890
- explanation=(
1891
- 'Some variables have extreme multicollinearity (VIF'
1892
- f' >{overall_threshold}) across all times and geos. To'
1893
- ' address multicollinearity, please drop any variable that'
1894
- ' is a linear combination of other variables. Otherwise,'
1895
- ' consider combining variables.\n'
1896
- f'Variables with extreme VIF: {high_vif_vars}'
2052
+ explanation=eda_constants.MULTICOLLINEARITY_ERROR.format(
2053
+ threshold=overall_threshold,
2054
+ aggregation='times and geos',
2055
+ additional_info=high_vif_vars_message,
1897
2056
  ),
2057
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
2058
+ associated_artifact=overall_vif_artifact,
1898
2059
  )
1899
2060
  )
1900
- elif not extreme_geo_vif_df.empty:
2061
+
2062
+ # Variables that cause overall level findings are very likely to cause
2063
+ # geo-level findings as well, so we exclude them when determining
2064
+ # geo-level findings. This is to avoid over-reporting findings.
2065
+ overall_vars_index = extreme_overall_vif_df.index
2066
+ is_in_overall = extreme_geo_vif_df.index.get_level_values(
2067
+ eda_constants.VARIABLE
2068
+ ).isin(overall_vars_index)
2069
+ geo_df_for_attention = extreme_geo_vif_df[~is_in_overall]
2070
+
2071
+ if not geo_df_for_attention.empty:
1901
2072
  findings.append(
1902
2073
  eda_outcome.EDAFinding(
1903
2074
  severity=eda_outcome.EDASeverity.ATTENTION,
1904
2075
  explanation=(
1905
- 'Some variables have extreme multicollinearity (with VIF >'
1906
- f' {geo_threshold}) in certain geo(s). Consider checking your'
1907
- ' data, and/or combining these variables if they also have'
1908
- ' high VIF in other geos.'
2076
+ eda_constants.MULTICOLLINEARITY_ATTENTION.format(
2077
+ threshold=geo_threshold, additional_info=''
2078
+ )
1909
2079
  ),
2080
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
2081
+ associated_artifact=geo_vif_artifact,
1910
2082
  )
1911
2083
  )
1912
- else:
2084
+
2085
+ if not findings:
1913
2086
  findings.append(
1914
2087
  eda_outcome.EDAFinding(
1915
2088
  severity=eda_outcome.EDASeverity.INFO,
@@ -1919,6 +2092,7 @@ class EDAEngine:
1919
2092
  ' jeopardize model identifiability and model convergence.'
1920
2093
  ' Consider combining the variables if high VIF occurs.'
1921
2094
  ),
2095
+ finding_cause=eda_outcome.FindingCause.NONE,
1922
2096
  )
1923
2097
  )
1924
2098
 
@@ -1931,14 +2105,21 @@ class EDAEngine:
1931
2105
  def check_national_vif(
1932
2106
  self,
1933
2107
  ) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
1934
- """Computes national-level variance inflation factor among treatments and controls."""
2108
+ """Checks national variance inflation factor among treatments and controls.
2109
+
2110
+ The VIF calculation only focuses on multicollinearity among non-constant
2111
+ variables. Any variable with constant values will result in a NaN VIF value.
2112
+
2113
+ Returns:
2114
+ An EDAOutcome object with findings and result values.
2115
+ """
1935
2116
  national_tc_da = self._stacked_national_treatment_control_scaled_da
1936
2117
  national_threshold = self._spec.vif_spec.national_threshold
1937
- national_vif_da = _calculate_vif(national_tc_da, _STACK_VAR_COORD_NAME)
2118
+ national_vif_da = _calculate_vif(national_tc_da, eda_constants.VARIABLE)
1938
2119
 
1939
2120
  extreme_national_vif_df = (
1940
2121
  national_vif_da.where(national_vif_da > national_threshold)
1941
- .to_dataframe(name=_VIF_COL_NAME)
2122
+ .to_dataframe(name=eda_constants.VIF_COL_NAME)
1942
2123
  .dropna()
1943
2124
  )
1944
2125
  national_vif_artifact = eda_outcome.VIFArtifact(
@@ -1949,18 +2130,20 @@ class EDAEngine:
1949
2130
 
1950
2131
  findings = []
1951
2132
  if not extreme_national_vif_df.empty:
1952
- high_vif_vars = extreme_national_vif_df.index.to_list()
2133
+ high_vif_vars_message = (
2134
+ '\nVariables with extreme VIF:'
2135
+ f' {extreme_national_vif_df.index.to_list()}'
2136
+ )
1953
2137
  findings.append(
1954
2138
  eda_outcome.EDAFinding(
1955
2139
  severity=eda_outcome.EDASeverity.ERROR,
1956
- explanation=(
1957
- 'Some variables have extreme multicollinearity (with VIF >'
1958
- f' {national_threshold}) across all times. To address'
1959
- ' multicollinearity, please drop any variable that is a'
1960
- ' linear combination of other variables. Otherwise, consider'
1961
- ' combining variables.\n'
1962
- f'Variables with extreme VIF: {high_vif_vars}'
2140
+ explanation=eda_constants.MULTICOLLINEARITY_ERROR.format(
2141
+ threshold=national_threshold,
2142
+ aggregation='times',
2143
+ additional_info=high_vif_vars_message,
1963
2144
  ),
2145
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
2146
+ associated_artifact=national_vif_artifact,
1964
2147
  )
1965
2148
  )
1966
2149
  else:
@@ -1973,6 +2156,7 @@ class EDAEngine:
1973
2156
  ' jeopardize model identifiability and model convergence.'
1974
2157
  ' Consider combining the variables if high VIF occurs.'
1975
2158
  ),
2159
+ finding_cause=eda_outcome.FindingCause.NONE,
1976
2160
  )
1977
2161
  )
1978
2162
  return eda_outcome.EDAOutcome(
@@ -1984,6 +2168,9 @@ class EDAEngine:
1984
2168
  def check_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
1985
2169
  """Computes variance inflation factor among treatments and controls.
1986
2170
 
2171
+ The VIF calculation only focuses on multicollinearity among non-constant
2172
+ variables. Any variable with constant values will result in a NaN VIF value.
2173
+
1987
2174
  Returns:
1988
2175
  An EDAOutcome object with findings and result values.
1989
2176
  """
@@ -1994,15 +2181,18 @@ class EDAEngine:
1994
2181
 
1995
2182
  @property
1996
2183
  def kpi_has_variability(self) -> bool:
1997
- """Returns True if the KPI has variability across geos and times."""
2184
+ """Whether the KPI has variability across geos and times."""
1998
2185
  return (
1999
2186
  self._overall_scaled_kpi_invariability_artifact.kpi_stdev.item()
2000
- >= _STD_THRESHOLD
2187
+ >= eda_constants.STD_THRESHOLD
2001
2188
  )
2002
2189
 
2003
- def check_overall_kpi_invariability(self) -> eda_outcome.EDAOutcome:
2190
+ def check_overall_kpi_invariability(
2191
+ self,
2192
+ ) -> eda_outcome.EDAOutcome[eda_outcome.KpiInvariabilityArtifact]:
2004
2193
  """Checks if the KPI is constant across all geos and times."""
2005
- kpi = self._overall_scaled_kpi_invariability_artifact.kpi_da.name
2194
+ artifact = self._overall_scaled_kpi_invariability_artifact
2195
+ kpi = artifact.kpi_da.name
2006
2196
  geo_text = '' if self._is_national_data else 'geos and '
2007
2197
 
2008
2198
  if not self.kpi_has_variability:
@@ -2012,6 +2202,8 @@ class EDAEngine:
2012
2202
  f'`{kpi}` is constant across all {geo_text}times, indicating no'
2013
2203
  ' signal in the data. Please fix this data error.'
2014
2204
  ),
2205
+ finding_cause=eda_outcome.FindingCause.VARIABILITY,
2206
+ associated_artifact=artifact,
2015
2207
  )
2016
2208
  else:
2017
2209
  eda_finding = eda_outcome.EDAFinding(
@@ -2019,12 +2211,13 @@ class EDAEngine:
2019
2211
  explanation=(
2020
2212
  f'The {kpi} has variability across {geo_text}times in the data.'
2021
2213
  ),
2214
+ finding_cause=eda_outcome.FindingCause.NONE,
2022
2215
  )
2023
2216
 
2024
2217
  return eda_outcome.EDAOutcome(
2025
2218
  check_type=eda_outcome.EDACheckType.KPI_INVARIABILITY,
2026
2219
  findings=[eda_finding],
2027
- analysis_artifacts=[self._overall_scaled_kpi_invariability_artifact],
2220
+ analysis_artifacts=[artifact],
2028
2221
  )
2029
2222
 
2030
2223
  def check_geo_cost_per_media_unit(
@@ -2081,30 +2274,249 @@ class EDAEngine:
2081
2274
 
2082
2275
  return self.check_geo_cost_per_media_unit()
2083
2276
 
2084
- def run_all_critical_checks(self) -> list[eda_outcome.EDAOutcome]:
2277
+ def run_all_critical_checks(self) -> eda_outcome.CriticalCheckEDAOutcomes:
2085
2278
  """Runs all critical EDA checks.
2086
2279
 
2087
2280
  Critical checks are those that can result in EDASeverity.ERROR findings.
2088
2281
 
2089
2282
  Returns:
2090
- A list of EDA outcomes, one for each check.
2283
+ A CriticalCheckEDAOutcomes object containing the results of all critical
2284
+ checks.
2091
2285
  """
2092
- outcomes = []
2286
+ outcomes = {}
2093
2287
  for check, check_type in self._critical_checks:
2094
2288
  try:
2095
- outcomes.append(check())
2289
+ outcomes[check_type] = check()
2096
2290
  except Exception as e: # pylint: disable=broad-except
2097
2291
  error_finding = eda_outcome.EDAFinding(
2098
2292
  severity=eda_outcome.EDASeverity.ERROR,
2099
2293
  explanation=(
2100
2294
  f'An error occurred during running {check.__name__}: {e!r}'
2101
2295
  ),
2296
+ finding_cause=eda_outcome.FindingCause.RUNTIME_ERROR,
2297
+ )
2298
+ outcomes[check_type] = eda_outcome.EDAOutcome(
2299
+ check_type=check_type,
2300
+ findings=[error_finding],
2301
+ analysis_artifacts=[],
2102
2302
  )
2103
- outcomes.append(
2104
- eda_outcome.EDAOutcome(
2105
- check_type=check_type,
2106
- findings=[error_finding],
2107
- analysis_artifacts=[],
2303
+
2304
+ return eda_outcome.CriticalCheckEDAOutcomes(
2305
+ kpi_invariability=outcomes[eda_outcome.EDACheckType.KPI_INVARIABILITY],
2306
+ multicollinearity=outcomes[eda_outcome.EDACheckType.MULTICOLLINEARITY],
2307
+ pairwise_correlation=outcomes[
2308
+ eda_outcome.EDACheckType.PAIRWISE_CORRELATION
2309
+ ],
2310
+ )
2311
+
2312
+ def check_variable_geo_time_collinearity(
2313
+ self,
2314
+ ) -> eda_outcome.EDAOutcome[eda_outcome.VariableGeoTimeCollinearityArtifact]:
2315
+ """Compute adjusted R-squared for treatments and controls vs geo and time.
2316
+
2317
+ These checks are applied to geo-level dataset only.
2318
+
2319
+ Returns:
2320
+ An EDAOutcome object containing a VariableGeoTimeCollinearityArtifact.
2321
+ The artifact includes a Dataset with 'rsquared_geo' and 'rsquared_time',
2322
+ showing the adjusted R-squared values for each treatment/control variable
2323
+ when regressed against 'geo' and 'time', respectively. If a variable is
2324
+ constant across geos or times, the corresponding 'rsquared_geo' or
2325
+ 'rsquared_time' value will be NaN.
2326
+ """
2327
+ if self._is_national_data:
2328
+ raise ValueError(
2329
+ 'check_variable_geo_time_collinearity is not supported for national'
2330
+ ' models.'
2331
+ )
2332
+
2333
+ grouped_da = self._stacked_treatment_control_scaled_da.groupby(
2334
+ eda_constants.VARIABLE
2335
+ )
2336
+ rsq_geo = grouped_da.map(_calc_adj_r2, args=(constants.GEO,))
2337
+ rsq_time = grouped_da.map(_calc_adj_r2, args=(constants.TIME,))
2338
+
2339
+ rsquared_ds = xr.Dataset({
2340
+ eda_constants.RSQUARED_GEO: rsq_geo,
2341
+ eda_constants.RSQUARED_TIME: rsq_time,
2342
+ })
2343
+
2344
+ artifact = eda_outcome.VariableGeoTimeCollinearityArtifact(
2345
+ level=eda_outcome.AnalysisLevel.OVERALL,
2346
+ rsquared_ds=rsquared_ds,
2347
+ )
2348
+ findings = [
2349
+ eda_outcome.EDAFinding(
2350
+ severity=eda_outcome.EDASeverity.INFO,
2351
+ explanation=eda_constants.R_SQUARED_TIME_INFO,
2352
+ finding_cause=eda_outcome.FindingCause.NONE,
2353
+ ),
2354
+ eda_outcome.EDAFinding(
2355
+ severity=eda_outcome.EDASeverity.INFO,
2356
+ explanation=eda_constants.R_SQUARED_GEO_INFO,
2357
+ finding_cause=eda_outcome.FindingCause.NONE,
2358
+ ),
2359
+ ]
2360
+
2361
+ return eda_outcome.EDAOutcome(
2362
+ check_type=eda_outcome.EDACheckType.VARIABLE_GEO_TIME_COLLINEARITY,
2363
+ findings=findings,
2364
+ analysis_artifacts=[artifact],
2365
+ )
2366
+
2367
+ def _calculate_population_corr(
2368
+ self, ds: xr.Dataset, *, explanation: str, check_name: str
2369
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
2370
+ """Calculates Spearman correlation between population and data variables.
2371
+
2372
+ Args:
2373
+ ds: An xr.Dataset containing the data variables for which to calculate the
2374
+ correlation with population. The Dataset is expected to have a 'geo'
2375
+ dimension.
2376
+ explanation: A string providing an explanation for the EDA finding.
2377
+ check_name: A string representing the name of the calling check function,
2378
+ used in error messages.
2379
+
2380
+ Returns:
2381
+ An EDAOutcome object containing a PopulationCorrelationArtifact. The
2382
+ artifact includes a Dataset with the Spearman correlation coefficients
2383
+ between each variable in `ds` and the geo population.
2384
+
2385
+ Raises:
2386
+ GeoLevelCheckOnNationalModelError: If the model is national or if
2387
+ `self.geo_population_da` is None.
2388
+ """
2389
+
2390
+ # self.geo_population_da can never be None if the model is geo-level. Adding
2391
+ # this check to make pytype happy.
2392
+ if self._is_national_data or self.geo_population_da is None:
2393
+ raise GeoLevelCheckOnNationalModelError(
2394
+ f'{check_name} is not supported for national models.'
2395
+ )
2396
+
2397
+ corr_ds: xr.Dataset = xr.apply_ufunc(
2398
+ _spearman_coeff,
2399
+ ds.mean(dim=constants.TIME),
2400
+ self.geo_population_da,
2401
+ input_core_dims=[[constants.GEO], [constants.GEO]],
2402
+ vectorize=True,
2403
+ output_dtypes=[float],
2404
+ )
2405
+
2406
+ artifact = eda_outcome.PopulationCorrelationArtifact(
2407
+ level=eda_outcome.AnalysisLevel.OVERALL,
2408
+ correlation_ds=corr_ds,
2409
+ )
2410
+
2411
+ return eda_outcome.EDAOutcome(
2412
+ check_type=eda_outcome.EDACheckType.POPULATION_CORRELATION,
2413
+ findings=[
2414
+ eda_outcome.EDAFinding(
2415
+ severity=eda_outcome.EDASeverity.INFO,
2416
+ explanation=explanation,
2417
+ finding_cause=eda_outcome.FindingCause.NONE,
2418
+ associated_artifact=artifact,
2108
2419
  )
2420
+ ],
2421
+ analysis_artifacts=[artifact],
2422
+ )
2423
+
2424
+ def check_population_corr_scaled_treatment_control(
2425
+ self,
2426
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
2427
+ """Checks Spearman correlation between population and treatments/controls.
2428
+
2429
+ Calculates correlation between population and time-averaged
2430
+ treatments and controls. High correlation for controls or non-media
2431
+ channels may indicate a need for population-scaling. High
2432
+ correlation for other media channels may indicate double-scaling.
2433
+
2434
+ Returns:
2435
+ An EDAOutcome object with findings and result values.
2436
+
2437
+ Raises:
2438
+ GeoLevelCheckOnNationalModelError: If the model is national or geo
2439
+ population data is missing.
2440
+ """
2441
+ return self._calculate_population_corr(
2442
+ ds=self.treatment_control_scaled_ds,
2443
+ explanation=eda_constants.POPULATION_CORRELATION_SCALED_TREATMENT_CONTROL_INFO,
2444
+ check_name='check_population_corr_scaled_treatment_control',
2445
+ )
2446
+
2447
+ def check_population_corr_raw_media(
2448
+ self,
2449
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
2450
+ """Checks Spearman correlation between population and raw media executions.
2451
+
2452
+ Calculates correlation between population and time-averaged raw
2453
+ media executions (paid/organic impressions/reach). These are
2454
+ expected to have reasonably high correlation with population.
2455
+
2456
+ Returns:
2457
+ An EDAOutcome object with findings and result values.
2458
+
2459
+ Raises:
2460
+ GeoLevelCheckOnNationalModelError: If the model is national or geo
2461
+ population data is missing.
2462
+ """
2463
+ to_merge = (
2464
+ da
2465
+ for da in [
2466
+ self.media_raw_da,
2467
+ self.organic_media_raw_da,
2468
+ self.reach_raw_da,
2469
+ self.organic_reach_raw_da,
2470
+ ]
2471
+ if da is not None
2472
+ )
2473
+ # Handle the case where there are no media channels.
2474
+
2475
+ return self._calculate_population_corr(
2476
+ ds=xr.merge(to_merge, join='inner'),
2477
+ explanation=eda_constants.POPULATION_CORRELATION_RAW_MEDIA_INFO,
2478
+ check_name='check_population_corr_raw_media',
2479
+ )
2480
+
2481
+ def _check_prior_probability(
2482
+ self,
2483
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PriorProbabilityArtifact]:
2484
+ """Checks the prior probability of a negative baseline.
2485
+
2486
+ Returns:
2487
+ An EDAOutcome object containing a PriorProbabilityArtifact. The artifact
2488
+ includes a mock prior negative baseline probability and a DataArray of
2489
+ mock mean prior contributions per channel.
2490
+ """
2491
+ # TODO: b/476128592 - currently, this check is blocked. for the meantime,
2492
+ # we will return mock data for the report.
2493
+ channel_names = self._model_context.input_data.get_all_channels()
2494
+ mean_prior_contribution = np.random.uniform(
2495
+ size=len(channel_names), low=0.0, high=0.05
2496
+ )
2497
+ mean_prior_contribution_da = xr.DataArray(
2498
+ mean_prior_contribution,
2499
+ coords={constants.CHANNEL: channel_names},
2500
+ dims=[constants.CHANNEL],
2501
+ )
2502
+
2503
+ artifact = eda_outcome.PriorProbabilityArtifact(
2504
+ level=eda_outcome.AnalysisLevel.OVERALL,
2505
+ prior_negative_baseline_prob=0.123,
2506
+ mean_prior_contribution_da=mean_prior_contribution_da,
2507
+ )
2508
+
2509
+ findings = [
2510
+ eda_outcome.EDAFinding(
2511
+ severity=eda_outcome.EDASeverity.INFO,
2512
+ explanation=eda_constants.PRIOR_PROBABILITY_REPORT_INFO,
2513
+ finding_cause=eda_outcome.FindingCause.NONE,
2514
+ associated_artifact=artifact,
2109
2515
  )
2110
- return outcomes
2516
+ ]
2517
+
2518
+ return eda_outcome.EDAOutcome(
2519
+ check_type=eda_outcome.EDACheckType.PRIOR_PROBABILITY,
2520
+ findings=findings,
2521
+ analysis_artifacts=[artifact],
2522
+ )