google-meridian 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/RECORD +47 -43
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/reviewer.py +4 -1
  7. meridian/analysis/summarizer.py +6 -1
  8. meridian/analysis/test_utils.py +2898 -2538
  9. meridian/analysis/visualizer.py +28 -9
  10. meridian/backend/__init__.py +106 -0
  11. meridian/constants.py +1 -0
  12. meridian/data/input_data.py +30 -52
  13. meridian/data/input_data_builder.py +2 -9
  14. meridian/data/test_utils.py +25 -41
  15. meridian/data/validator.py +48 -0
  16. meridian/mlflow/autolog.py +19 -9
  17. meridian/model/adstock_hill.py +3 -5
  18. meridian/model/context.py +134 -0
  19. meridian/model/eda/constants.py +334 -4
  20. meridian/model/eda/eda_engine.py +723 -312
  21. meridian/model/eda/eda_outcome.py +177 -33
  22. meridian/model/model.py +159 -110
  23. meridian/model/model_test_data.py +38 -0
  24. meridian/model/posterior_sampler.py +103 -62
  25. meridian/model/prior_sampler.py +114 -94
  26. meridian/model/spec.py +23 -14
  27. meridian/templates/card.html.jinja +9 -7
  28. meridian/templates/chart.html.jinja +1 -6
  29. meridian/templates/finding.html.jinja +19 -0
  30. meridian/templates/findings.html.jinja +33 -0
  31. meridian/templates/formatter.py +41 -5
  32. meridian/templates/formatter_test.py +127 -0
  33. meridian/templates/style.css +66 -9
  34. meridian/templates/style.scss +85 -4
  35. meridian/templates/table.html.jinja +1 -0
  36. meridian/version.py +1 -1
  37. scenarioplanner/linkingapi/constants.py +1 -1
  38. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  39. schema/processors/marketing_processor.py +11 -10
  40. schema/processors/model_processor.py +4 -1
  41. schema/serde/distribution.py +12 -7
  42. schema/serde/hyperparameters.py +54 -107
  43. schema/serde/meridian_serde.py +6 -1
  44. schema/utils/__init__.py +1 -0
  45. schema/utils/proto_enum_converter.py +127 -0
  46. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
  47. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +0 -0
@@ -16,19 +16,23 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
+ from collections.abc import Collection, Sequence
19
20
  import dataclasses
20
21
  import functools
21
22
  import typing
22
- from typing import Optional, Protocol, Sequence
23
+ from typing import Protocol
24
+ import warnings
23
25
 
24
26
  from meridian import backend
25
27
  from meridian import constants
28
+ from meridian.model import context
26
29
  from meridian.model import transformers
27
30
  from meridian.model.eda import constants as eda_constants
28
31
  from meridian.model.eda import eda_outcome
29
32
  from meridian.model.eda import eda_spec
30
33
  import numpy as np
31
34
  import pandas as pd
35
+ from scipy import stats
32
36
  import statsmodels.api as sm
33
37
  from statsmodels.stats import outliers_influence
34
38
  import xarray as xr
@@ -39,25 +43,6 @@ if typing.TYPE_CHECKING:
39
43
 
40
44
  __all__ = ['EDAEngine', 'GeoLevelCheckOnNationalModelError']
41
45
 
42
- _DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
43
- _CORRELATION_COL_NAME = eda_constants.CORRELATION
44
- _STACK_VAR_COORD_NAME = eda_constants.VARIABLE
45
- _CORR_VAR1 = eda_constants.VARIABLE_1
46
- _CORR_VAR2 = eda_constants.VARIABLE_2
47
- _CORRELATION_MATRIX_NAME = 'correlation_matrix'
48
- _OVERALL_PAIRWISE_CORR_THRESHOLD = 0.999
49
- _GEO_PAIRWISE_CORR_THRESHOLD = 0.999
50
- _NATIONAL_PAIRWISE_CORR_THRESHOLD = 0.999
51
- _Q1_THRESHOLD = 0.25
52
- _Q3_THRESHOLD = 0.75
53
- _IQR_MULTIPLIER = 1.5
54
- _STD_WITH_OUTLIERS_VAR_NAME = 'std_with_outliers'
55
- _STD_WITHOUT_OUTLIERS_VAR_NAME = 'std_without_outliers'
56
- _STD_THRESHOLD = 1e-4
57
- _OUTLIERS_COL_NAME = 'outliers'
58
- _ABS_OUTLIERS_COL_NAME = 'abs_outliers'
59
- _VIF_COL_NAME = 'VIF'
60
-
61
46
 
62
47
  class _NamedEDACheckCallable(Protocol):
63
48
  """A callable that returns an EDAOutcome and has a __name__ attribute."""
@@ -149,6 +134,15 @@ class ReachFrequencyData:
149
134
  national_rf_impressions_raw_da: xr.DataArray
150
135
 
151
136
 
137
+ def _get_vars_from_dataset(
138
+ base_ds: xr.Dataset,
139
+ variables_to_include: Collection[str],
140
+ ) -> xr.Dataset | None:
141
+ """Helper to get a subset of variables from a Dataset."""
142
+ variables = [v for v in base_ds.data_vars if v in variables_to_include]
143
+ return base_ds[variables].copy() if variables else None
144
+
145
+
152
146
  def _data_array_like(
153
147
  *, da: xr.DataArray, values: np.ndarray | backend.Tensor
154
148
  ) -> xr.DataArray:
@@ -173,7 +167,7 @@ def _data_array_like(
173
167
 
174
168
 
175
169
  def stack_variables(
176
- ds: xr.Dataset, coord_name: str = _STACK_VAR_COORD_NAME
170
+ ds: xr.Dataset, coord_name: str = eda_constants.VARIABLE
177
171
  ) -> xr.DataArray:
178
172
  """Stacks data variables of a Dataset into a single DataArray.
179
173
 
@@ -219,12 +213,12 @@ def _compute_correlation_matrix(
219
213
  An xr.DataArray containing the correlation matrix.
220
214
  """
221
215
  # Create two versions for correlation
222
- da1 = input_da.rename({_STACK_VAR_COORD_NAME: _CORR_VAR1})
223
- da2 = input_da.rename({_STACK_VAR_COORD_NAME: _CORR_VAR2})
216
+ da1 = input_da.rename({eda_constants.VARIABLE: eda_constants.VARIABLE_1})
217
+ da2 = input_da.rename({eda_constants.VARIABLE: eda_constants.VARIABLE_2})
224
218
 
225
219
  # Compute pairwise correlation across dims. Other dims are broadcasted.
226
220
  corr_mat_da = xr.corr(da1, da2, dim=dims)
227
- corr_mat_da.name = _CORRELATION_MATRIX_NAME
221
+ corr_mat_da.name = eda_constants.CORRELATION_MATRIX_NAME
228
222
  return corr_mat_da
229
223
 
230
224
 
@@ -238,14 +232,14 @@ def _get_upper_triangle_corr_mat(corr_mat_da: xr.DataArray) -> xr.DataArray:
238
232
  An xr.DataArray containing only the elements in the upper triangle of the
239
233
  correlation matrix, with other elements masked as NaN.
240
234
  """
241
- n_vars = corr_mat_da.sizes[_CORR_VAR1]
235
+ n_vars = corr_mat_da.sizes[eda_constants.VARIABLE_1]
242
236
  mask_np = np.triu(np.ones((n_vars, n_vars), dtype=bool), k=1)
243
237
  mask = xr.DataArray(
244
238
  mask_np,
245
- dims=[_CORR_VAR1, _CORR_VAR2],
239
+ dims=[eda_constants.VARIABLE_1, eda_constants.VARIABLE_2],
246
240
  coords={
247
- _CORR_VAR1: corr_mat_da[_CORR_VAR1],
248
- _CORR_VAR2: corr_mat_da[_CORR_VAR2],
241
+ eda_constants.VARIABLE_1: corr_mat_da[eda_constants.VARIABLE_1],
242
+ eda_constants.VARIABLE_2: corr_mat_da[eda_constants.VARIABLE_2],
249
243
  },
250
244
  )
251
245
  return corr_mat_da.where(mask)
@@ -259,11 +253,11 @@ def _find_extreme_corr_pairs(
259
253
  extreme_corr_da = corr_tri.where(abs(corr_tri) > extreme_corr_threshold)
260
254
 
261
255
  return (
262
- extreme_corr_da.to_dataframe(name=_CORRELATION_COL_NAME)
256
+ extreme_corr_da.to_dataframe(name=eda_constants.CORRELATION)
263
257
  .dropna()
264
258
  .assign(**{
265
259
  eda_constants.ABS_CORRELATION_COL_NAME: (
266
- lambda x: x[_CORRELATION_COL_NAME].abs()
260
+ lambda x: x[eda_constants.CORRELATION].abs()
267
261
  )
268
262
  })
269
263
  .sort_values(
@@ -286,11 +280,11 @@ def _get_outlier_bounds(
286
280
  A tuple containing the lower and upper bounds of outliers as DataArrays.
287
281
  """
288
282
  # TODO: Allow users to specify custom outlier definitions.
289
- q1 = input_da.quantile(_Q1_THRESHOLD, dim=constants.TIME)
290
- q3 = input_da.quantile(_Q3_THRESHOLD, dim=constants.TIME)
283
+ q1 = input_da.quantile(eda_constants.Q1_THRESHOLD, dim=constants.TIME)
284
+ q3 = input_da.quantile(eda_constants.Q3_THRESHOLD, dim=constants.TIME)
291
285
  iqr = q3 - q1
292
- lower_bound = q1 - _IQR_MULTIPLIER * iqr
293
- upper_bound = q3 + _IQR_MULTIPLIER * iqr
286
+ lower_bound = q1 - eda_constants.IQR_MULTIPLIER * iqr
287
+ upper_bound = q3 + eda_constants.IQR_MULTIPLIER * iqr
294
288
  return lower_bound, upper_bound
295
289
 
296
290
 
@@ -314,11 +308,10 @@ def _calculate_std(
314
308
  )
315
309
  std_without_outliers = da_no_outlier.std(dim=constants.TIME, ddof=1)
316
310
 
317
- std_ds = xr.Dataset({
318
- _STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
319
- _STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
311
+ return xr.Dataset({
312
+ eda_constants.STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
313
+ eda_constants.STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
320
314
  })
321
- return std_ds
322
315
 
323
316
 
324
317
  def _calculate_outliers(
@@ -334,23 +327,29 @@ def _calculate_outliers(
334
327
  outlier values.
335
328
  """
336
329
  lower_bound, upper_bound = _get_outlier_bounds(input_da)
337
- outlier_da = input_da.where(
338
- (input_da < lower_bound) | (input_da > upper_bound)
339
- )
340
- outlier_df = (
341
- outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME)
330
+ return (
331
+ input_da.where((input_da < lower_bound) | (input_da > upper_bound))
332
+ .to_dataframe(name=eda_constants.OUTLIERS_COL_NAME)
342
333
  .dropna()
343
- .assign(
344
- **{_ABS_OUTLIERS_COL_NAME: lambda x: np.abs(x[_OUTLIERS_COL_NAME])}
334
+ .assign(**{
335
+ eda_constants.ABS_OUTLIERS_COL_NAME: lambda x: np.abs(
336
+ x[eda_constants.OUTLIERS_COL_NAME]
337
+ )
338
+ })
339
+ .sort_values(
340
+ by=eda_constants.ABS_OUTLIERS_COL_NAME,
341
+ ascending=False,
342
+ inplace=False,
345
343
  )
346
- .sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
347
344
  )
348
- return outlier_df
349
345
 
350
346
 
351
347
  def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
352
348
  """Helper function to compute variance inflation factor.
353
349
 
350
+ The VIF calculation only focuses on multicollinearity among non-constant
351
+ variables. Any variable with constant values will result in a NaN VIF value.
352
+
354
353
  Args:
355
354
  input_da: A DataArray for which to calculate the VIF over sample dimensions
356
355
  (e.g. time and geo if applicable).
@@ -359,25 +358,28 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
359
358
  Returns:
360
359
  A DataArray containing the VIF for each variable in the variable dimension.
361
360
  """
361
+
362
362
  num_vars = input_da.sizes[var_dim]
363
363
  np_data = input_da.values.reshape(-1, num_vars)
364
- np_data_with_const = sm.add_constant(
365
- np_data, prepend=True, has_constant='add'
366
- )
367
364
 
368
- # Compute VIF for each variable excluding const which is the first one in the
369
- # 'variable' dimension.
370
- vifs = [
371
- outliers_influence.variance_inflation_factor(np_data_with_const, i)
372
- for i in range(1, num_vars + 1)
373
- ]
365
+ is_constant = np.std(np_data, axis=0) < eda_constants.STD_THRESHOLD
366
+ vif_values = np.full(num_vars, np.nan)
367
+ (non_constant_vars_indices,) = (~is_constant).nonzero()
368
+
369
+ if non_constant_vars_indices.size > 0:
370
+ design_matrix = sm.add_constant(
371
+ np_data[:, ~is_constant], prepend=True, has_constant='add'
372
+ )
373
+ for i, var_index in enumerate(non_constant_vars_indices):
374
+ vif_values[var_index] = outliers_influence.variance_inflation_factor(
375
+ design_matrix, i + 1
376
+ )
374
377
 
375
- vif_da = xr.DataArray(
376
- vifs,
378
+ return xr.DataArray(
379
+ vif_values,
377
380
  coords={var_dim: input_da[var_dim].values},
378
381
  dims=[var_dim],
379
382
  )
380
- return vif_da
381
383
 
382
384
 
383
385
  def _check_cost_media_unit_inconsistency(
@@ -392,9 +394,9 @@ def _check_cost_media_unit_inconsistency(
392
394
 
393
395
  Returns:
394
396
  A DataFrame of inconsistencies where either cost is zero and media units
395
- are
396
- positive, or cost is positive and media units are zero.
397
+ are positive, or cost is positive and media units are zero.
397
398
  """
399
+
398
400
  cost_media_units_ds = xr.merge([cost_da, media_units_da])
399
401
 
400
402
  # Condition 1: cost == 0 and media unit > 0
@@ -432,39 +434,69 @@ def _check_cost_per_media_unit(
432
434
  cost_da,
433
435
  media_units_da,
434
436
  )
437
+
438
+ # Calculate cost per media unit. Avoid division by zero by setting cost to
439
+ # NaN where media units are 0. Note that both (cost == media unit == 0) and
440
+ # (cost > 0 and media unit == 0) result in NaN, while the latter one is not
441
+ # desired.
442
+ cost_per_media_unit_da = xr.where(
443
+ media_units_da == 0,
444
+ np.nan,
445
+ cost_da / media_units_da,
446
+ )
447
+ cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
448
+ outlier_df = _calculate_outliers(cost_per_media_unit_da)
449
+
450
+ if not outlier_df.empty:
451
+ outlier_df = outlier_df.rename(
452
+ columns={
453
+ eda_constants.OUTLIERS_COL_NAME: eda_constants.COST_PER_MEDIA_UNIT,
454
+ eda_constants.ABS_OUTLIERS_COL_NAME: (
455
+ eda_constants.ABS_COST_PER_MEDIA_UNIT
456
+ ),
457
+ }
458
+ ).assign(**{
459
+ constants.SPEND: cost_da.to_series(),
460
+ constants.MEDIA_UNITS: media_units_da.to_series(),
461
+ })[[
462
+ constants.SPEND,
463
+ constants.MEDIA_UNITS,
464
+ eda_constants.COST_PER_MEDIA_UNIT,
465
+ eda_constants.ABS_COST_PER_MEDIA_UNIT,
466
+ ]]
467
+
468
+ artifact = eda_outcome.CostPerMediaUnitArtifact(
469
+ level=level,
470
+ cost_per_media_unit_da=cost_per_media_unit_da,
471
+ cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
472
+ outlier_df=outlier_df,
473
+ )
474
+
435
475
  if not cost_media_unit_inconsistency_df.empty:
436
476
  findings.append(
437
477
  eda_outcome.EDAFinding(
438
478
  severity=eda_outcome.EDASeverity.ATTENTION,
439
479
  explanation=(
440
- 'There are instances of inconsistent cost and media units.'
441
- ' This occurs when cost is zero but media units are positive,'
442
- ' or when cost is positive but media units are zero. Please'
443
- ' review the outcome artifact for more details.'
480
+ 'There are instances of inconsistent spend and media units.'
481
+ ' This occurs when spend is zero but media units are positive,'
482
+ ' or when spend is positive but media units are zero. Please'
483
+ ' review the data input for media units and spend.'
444
484
  ),
485
+ finding_cause=eda_outcome.FindingCause.INCONSISTENT_DATA,
486
+ associated_artifact=artifact,
445
487
  )
446
488
  )
447
489
 
448
- # Calculate cost per media unit
449
- # Avoid division by zero by setting cost to NaN where media units are 0.
450
- # Note that both (cost == media unit == 0) and (cost > 0 and media unit ==
451
- # 0) result in NaN, while the latter one is not desired.
452
- cost_per_media_unit_da = xr.where(
453
- media_units_da == 0,
454
- np.nan,
455
- cost_da / media_units_da,
456
- )
457
- cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
458
-
459
- outlier_df = _calculate_outliers(cost_per_media_unit_da)
460
490
  if not outlier_df.empty:
461
491
  findings.append(
462
492
  eda_outcome.EDAFinding(
463
493
  severity=eda_outcome.EDASeverity.ATTENTION,
464
494
  explanation=(
465
495
  'There are outliers in cost per media unit across time.'
466
- ' Please review the outcome artifact for more details.'
496
+ ' Please check for any possible data input error.'
467
497
  ),
498
+ finding_cause=eda_outcome.FindingCause.OUTLIER,
499
+ associated_artifact=artifact,
468
500
  )
469
501
  )
470
502
 
@@ -474,16 +506,10 @@ def _check_cost_per_media_unit(
474
506
  eda_outcome.EDAFinding(
475
507
  severity=eda_outcome.EDASeverity.INFO,
476
508
  explanation='Please review the cost per media unit data.',
509
+ finding_cause=eda_outcome.FindingCause.NONE,
477
510
  )
478
511
  )
479
512
 
480
- artifact = eda_outcome.CostPerMediaUnitArtifact(
481
- level=level,
482
- cost_per_media_unit_da=cost_per_media_unit_da,
483
- cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
484
- outlier_df=outlier_df,
485
- )
486
-
487
513
  return eda_outcome.EDAOutcome(
488
514
  check_type=eda_outcome.EDACheckType.COST_PER_MEDIA_UNIT,
489
515
  findings=findings,
@@ -491,42 +517,91 @@ def _check_cost_per_media_unit(
491
517
  )
492
518
 
493
519
 
520
+ def _calc_adj_r2(da: xr.DataArray, regressor: str) -> xr.DataArray:
521
+ """Calculates adjusted R-squared for a DataArray against a regressor.
522
+
523
+ If the input DataArray `da` is constant, it returns NaN.
524
+
525
+ Args:
526
+ da: The input DataArray.
527
+ regressor: The regressor to use in the formula.
528
+
529
+ Returns:
530
+ An xr.DataArray containing the adjusted R-squared value or NaN if `da` is
531
+ constant.
532
+ """
533
+ if da.std(ddof=1) < eda_constants.STD_THRESHOLD:
534
+ return xr.DataArray(np.nan)
535
+ tmp_name = 'dep_var'
536
+ df = da.to_dataframe(name=tmp_name).reset_index()
537
+ formula = f'{tmp_name} ~ C({regressor})'
538
+ ols = sm.OLS.from_formula(formula, df).fit()
539
+ return xr.DataArray(ols.rsquared_adj)
540
+
541
+
542
+ def _spearman_coeff(x, y):
543
+ """Computes spearman correlation coefficient between two ArrayLike objects."""
544
+
545
+ return stats.spearmanr(x, y, nan_policy='omit').statistic
546
+
547
+
494
548
  class EDAEngine:
495
549
  """Meridian EDA Engine."""
496
550
 
497
551
  def __init__(
498
552
  self,
499
- meridian: model.Meridian,
553
+ meridian: model.Meridian | None = None,
500
554
  spec: eda_spec.EDASpec = eda_spec.EDASpec(),
555
+ *,
556
+ model_context: context.ModelContext | None = None,
501
557
  ):
502
- self._meridian = meridian
558
+ if meridian is not None and model_context is not None:
559
+ raise ValueError(
560
+ 'Only one of `meridian` or `model_context` can be provided.'
561
+ )
562
+ if meridian is not None:
563
+ warnings.warn(
564
+ 'Initializing EDAEngine with a Meridian object is deprecated'
565
+ ' and will be removed in a future version. Please use'
566
+ ' `model_context` instead.',
567
+ DeprecationWarning,
568
+ stacklevel=2,
569
+ )
570
+ self._model_context = meridian.model_context
571
+ elif model_context is not None:
572
+ self._model_context = model_context
573
+ else:
574
+ raise ValueError('Either `meridian` or `model_context` must be provided.')
575
+
576
+ self._input_data = self._model_context.input_data
503
577
  self._spec = spec
504
578
  self._agg_config = self._spec.aggregation_config
505
579
 
506
580
  @property
507
581
  def spec(self) -> eda_spec.EDASpec:
582
+ """The EDA specification."""
508
583
  return self._spec
509
584
 
510
585
  @property
511
586
  def _is_national_data(self) -> bool:
512
- return self._meridian.is_national
587
+ return self._model_context.is_national
513
588
 
514
589
  @functools.cached_property
515
590
  def controls_scaled_da(self) -> xr.DataArray | None:
516
- """Returns the scaled controls data array."""
517
- if self._meridian.input_data.controls is None:
591
+ """The scaled controls data array."""
592
+ if self._input_data.controls is None:
518
593
  return None
519
594
  controls_scaled_da = _data_array_like(
520
- da=self._meridian.input_data.controls,
521
- values=self._meridian.controls_scaled,
595
+ da=self._input_data.controls,
596
+ values=self._model_context.controls_scaled,
522
597
  )
523
598
  controls_scaled_da.name = constants.CONTROLS_SCALED
524
599
  return controls_scaled_da
525
600
 
526
601
  @functools.cached_property
527
602
  def national_controls_scaled_da(self) -> xr.DataArray | None:
528
- """Returns the national scaled controls data array."""
529
- if self._meridian.input_data.controls is None:
603
+ """The national scaled controls data array."""
604
+ if self._input_data.controls is None:
530
605
  return None
531
606
  if self._is_national_data:
532
607
  if self.controls_scaled_da is None:
@@ -538,7 +613,7 @@ class EDAEngine:
538
613
  national_da.name = constants.NATIONAL_CONTROLS_SCALED
539
614
  else:
540
615
  national_da = self._aggregate_and_scale_geo_da(
541
- self._meridian.input_data.controls,
616
+ self._input_data.controls,
542
617
  constants.NATIONAL_CONTROLS_SCALED,
543
618
  transformers.CenteringAndScalingTransformer,
544
619
  constants.CONTROL_VARIABLE,
@@ -548,43 +623,43 @@ class EDAEngine:
548
623
 
549
624
  @functools.cached_property
550
625
  def media_raw_da(self) -> xr.DataArray | None:
551
- """Returns the raw media data array."""
552
- if self._meridian.input_data.media is None:
626
+ """The raw media data array."""
627
+ if self._input_data.media is None:
553
628
  return None
554
- raw_media_da = self._truncate_media_time(self._meridian.input_data.media)
629
+ raw_media_da = self._truncate_media_time(self._input_data.media)
555
630
  raw_media_da.name = constants.MEDIA
556
631
  return raw_media_da
557
632
 
558
633
  @functools.cached_property
559
634
  def media_scaled_da(self) -> xr.DataArray | None:
560
- """Returns the scaled media data array."""
561
- if self._meridian.input_data.media is None:
635
+ """The scaled media data array."""
636
+ if self._input_data.media is None:
562
637
  return None
563
638
  media_scaled_da = _data_array_like(
564
- da=self._meridian.input_data.media,
565
- values=self._meridian.media_tensors.media_scaled,
639
+ da=self._input_data.media,
640
+ values=self._model_context.media_tensors.media_scaled,
566
641
  )
567
642
  media_scaled_da.name = constants.MEDIA_SCALED
568
643
  return self._truncate_media_time(media_scaled_da)
569
644
 
570
645
  @functools.cached_property
571
646
  def media_spend_da(self) -> xr.DataArray | None:
572
- """Returns media spend.
647
+ """The media spend data.
573
648
 
574
649
  If the input spend is aggregated, it is allocated across geo and time
575
650
  proportionally to media units.
576
651
  """
577
652
  # No need to truncate the media time for media spend.
578
- da = self._meridian.input_data.allocated_media_spend
579
- if da is None:
653
+ allocated_media_spend = self._input_data.allocated_media_spend
654
+ if allocated_media_spend is None:
580
655
  return None
581
- da = da.copy()
656
+ da = allocated_media_spend.copy()
582
657
  da.name = constants.MEDIA_SPEND
583
658
  return da
584
659
 
585
660
  @functools.cached_property
586
661
  def national_media_spend_da(self) -> xr.DataArray | None:
587
- """Returns the national media spend data array."""
662
+ """The national media spend data array."""
588
663
  media_spend = self.media_spend_da
589
664
  if media_spend is None:
590
665
  return None
@@ -593,7 +668,7 @@ class EDAEngine:
593
668
  national_da.name = constants.NATIONAL_MEDIA_SPEND
594
669
  else:
595
670
  national_da = self._aggregate_and_scale_geo_da(
596
- self._meridian.input_data.allocated_media_spend,
671
+ self._input_data.allocated_media_spend,
597
672
  constants.NATIONAL_MEDIA_SPEND,
598
673
  None,
599
674
  )
@@ -601,7 +676,7 @@ class EDAEngine:
601
676
 
602
677
  @functools.cached_property
603
678
  def national_media_raw_da(self) -> xr.DataArray | None:
604
- """Returns the national raw media data array."""
679
+ """The national raw media data array."""
605
680
  if self.media_raw_da is None:
606
681
  return None
607
682
  if self._is_national_data:
@@ -618,7 +693,7 @@ class EDAEngine:
618
693
 
619
694
  @functools.cached_property
620
695
  def national_media_scaled_da(self) -> xr.DataArray | None:
621
- """Returns the national scaled media data array."""
696
+ """The national scaled media data array."""
622
697
  if self.media_scaled_da is None:
623
698
  return None
624
699
  if self._is_national_data:
@@ -635,30 +710,30 @@ class EDAEngine:
635
710
 
636
711
  @functools.cached_property
637
712
  def organic_media_raw_da(self) -> xr.DataArray | None:
638
- """Returns the raw organic media data array."""
639
- if self._meridian.input_data.organic_media is None:
713
+ """The raw organic media data array."""
714
+ if self._input_data.organic_media is None:
640
715
  return None
641
716
  raw_organic_media_da = self._truncate_media_time(
642
- self._meridian.input_data.organic_media
717
+ self._input_data.organic_media
643
718
  )
644
719
  raw_organic_media_da.name = constants.ORGANIC_MEDIA
645
720
  return raw_organic_media_da
646
721
 
647
722
  @functools.cached_property
648
723
  def organic_media_scaled_da(self) -> xr.DataArray | None:
649
- """Returns the scaled organic media data array."""
650
- if self._meridian.input_data.organic_media is None:
724
+ """The scaled organic media data array."""
725
+ if self._input_data.organic_media is None:
651
726
  return None
652
727
  organic_media_scaled_da = _data_array_like(
653
- da=self._meridian.input_data.organic_media,
654
- values=self._meridian.organic_media_tensors.organic_media_scaled,
728
+ da=self._input_data.organic_media,
729
+ values=self._model_context.organic_media_tensors.organic_media_scaled,
655
730
  )
656
731
  organic_media_scaled_da.name = constants.ORGANIC_MEDIA_SCALED
657
732
  return self._truncate_media_time(organic_media_scaled_da)
658
733
 
659
734
  @functools.cached_property
660
735
  def national_organic_media_raw_da(self) -> xr.DataArray | None:
661
- """Returns the national raw organic media data array."""
736
+ """The national raw organic media data array."""
662
737
  if self.organic_media_raw_da is None:
663
738
  return None
664
739
  if self._is_national_data:
@@ -673,7 +748,7 @@ class EDAEngine:
673
748
 
674
749
  @functools.cached_property
675
750
  def national_organic_media_scaled_da(self) -> xr.DataArray | None:
676
- """Returns the national scaled organic media data array."""
751
+ """The national scaled organic media data array."""
677
752
  if self.organic_media_scaled_da is None:
678
753
  return None
679
754
  if self._is_national_data:
@@ -692,20 +767,20 @@ class EDAEngine:
692
767
 
693
768
  @functools.cached_property
694
769
  def non_media_scaled_da(self) -> xr.DataArray | None:
695
- """Returns the scaled non-media treatments data array."""
696
- if self._meridian.input_data.non_media_treatments is None:
770
+ """The scaled non-media treatments data array."""
771
+ if self._input_data.non_media_treatments is None:
697
772
  return None
698
773
  non_media_scaled_da = _data_array_like(
699
- da=self._meridian.input_data.non_media_treatments,
700
- values=self._meridian.non_media_treatments_normalized,
774
+ da=self._input_data.non_media_treatments,
775
+ values=self._model_context.non_media_treatments_normalized,
701
776
  )
702
777
  non_media_scaled_da.name = constants.NON_MEDIA_TREATMENTS_SCALED
703
778
  return non_media_scaled_da
704
779
 
705
780
  @functools.cached_property
706
781
  def national_non_media_scaled_da(self) -> xr.DataArray | None:
707
- """Returns the national scaled non-media treatment data array."""
708
- if self._meridian.input_data.non_media_treatments is None:
782
+ """The national scaled non-media treatment data array."""
783
+ if self._input_data.non_media_treatments is None:
709
784
  return None
710
785
  if self._is_national_data:
711
786
  if self.non_media_scaled_da is None:
@@ -717,7 +792,7 @@ class EDAEngine:
717
792
  national_da.name = constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED
718
793
  else:
719
794
  national_da = self._aggregate_and_scale_geo_da(
720
- self._meridian.input_data.non_media_treatments,
795
+ self._input_data.non_media_treatments,
721
796
  constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
722
797
  transformers.CenteringAndScalingTransformer,
723
798
  constants.NON_MEDIA_CHANNEL,
@@ -727,12 +802,12 @@ class EDAEngine:
727
802
 
728
803
  @functools.cached_property
729
804
  def rf_spend_da(self) -> xr.DataArray | None:
730
- """Returns RF spend.
805
+ """The RF spend data.
731
806
 
732
807
  If the input spend is aggregated, it is allocated across geo and time
733
808
  proportionally to RF impressions (reach * frequency).
734
809
  """
735
- da = self._meridian.input_data.allocated_rf_spend
810
+ da = self._input_data.allocated_rf_spend
736
811
  if da is None:
737
812
  return None
738
813
  da = da.copy()
@@ -741,7 +816,7 @@ class EDAEngine:
741
816
 
742
817
  @functools.cached_property
743
818
  def national_rf_spend_da(self) -> xr.DataArray | None:
744
- """Returns the national RF spend data array."""
819
+ """The national RF spend data array."""
745
820
  rf_spend = self.rf_spend_da
746
821
  if rf_spend is None:
747
822
  return None
@@ -750,7 +825,7 @@ class EDAEngine:
750
825
  national_da.name = constants.NATIONAL_RF_SPEND
751
826
  else:
752
827
  national_da = self._aggregate_and_scale_geo_da(
753
- self._meridian.input_data.allocated_rf_spend,
828
+ self._input_data.allocated_rf_spend,
754
829
  constants.NATIONAL_RF_SPEND,
755
830
  None,
756
831
  )
@@ -758,182 +833,182 @@ class EDAEngine:
758
833
 
759
834
  @functools.cached_property
760
835
  def _rf_data(self) -> ReachFrequencyData | None:
761
- if self._meridian.input_data.reach is None:
836
+ if self._input_data.reach is None:
762
837
  return None
763
838
  return self._get_rf_data(
764
- self._meridian.input_data.reach,
765
- self._meridian.input_data.frequency,
839
+ self._input_data.reach,
840
+ self._input_data.frequency,
766
841
  is_organic=False,
767
842
  )
768
843
 
769
844
  @property
770
845
  def reach_raw_da(self) -> xr.DataArray | None:
771
- """Returns the raw reach data array."""
846
+ """The raw reach data array."""
772
847
  if self._rf_data is None:
773
848
  return None
774
- return self._rf_data.reach_raw_da
849
+ return self._rf_data.reach_raw_da # pytype: disable=attribute-error
775
850
 
776
851
  @property
777
852
  def reach_scaled_da(self) -> xr.DataArray | None:
778
- """Returns the scaled reach data array."""
853
+ """The scaled reach data array."""
779
854
  if self._rf_data is None:
780
855
  return None
781
856
  return self._rf_data.reach_scaled_da # pytype: disable=attribute-error
782
857
 
783
858
  @property
784
859
  def national_reach_raw_da(self) -> xr.DataArray | None:
785
- """Returns the national raw reach data array."""
860
+ """The national raw reach data array."""
786
861
  if self._rf_data is None:
787
862
  return None
788
863
  return self._rf_data.national_reach_raw_da
789
864
 
790
865
  @property
791
866
  def national_reach_scaled_da(self) -> xr.DataArray | None:
792
- """Returns the national scaled reach data array."""
867
+ """The national scaled reach data array."""
793
868
  if self._rf_data is None:
794
869
  return None
795
870
  return self._rf_data.national_reach_scaled_da # pytype: disable=attribute-error
796
871
 
797
872
  @property
798
873
  def frequency_da(self) -> xr.DataArray | None:
799
- """Returns the frequency data array."""
874
+ """The frequency data array."""
800
875
  if self._rf_data is None:
801
876
  return None
802
877
  return self._rf_data.frequency_da # pytype: disable=attribute-error
803
878
 
804
879
  @property
805
880
  def national_frequency_da(self) -> xr.DataArray | None:
806
- """Returns the national frequency data array."""
881
+ """The national frequency data array."""
807
882
  if self._rf_data is None:
808
883
  return None
809
884
  return self._rf_data.national_frequency_da # pytype: disable=attribute-error
810
885
 
811
886
  @property
812
887
  def rf_impressions_raw_da(self) -> xr.DataArray | None:
813
- """Returns the raw RF impressions data array."""
888
+ """The raw RF impressions data array."""
814
889
  if self._rf_data is None:
815
890
  return None
816
891
  return self._rf_data.rf_impressions_raw_da # pytype: disable=attribute-error
817
892
 
818
893
  @property
819
894
  def national_rf_impressions_raw_da(self) -> xr.DataArray | None:
820
- """Returns the national raw RF impressions data array."""
895
+ """The national raw RF impressions data array."""
821
896
  if self._rf_data is None:
822
897
  return None
823
898
  return self._rf_data.national_rf_impressions_raw_da # pytype: disable=attribute-error
824
899
 
825
900
  @property
826
901
  def rf_impressions_scaled_da(self) -> xr.DataArray | None:
827
- """Returns the scaled RF impressions data array."""
902
+ """The scaled RF impressions data array."""
828
903
  if self._rf_data is None:
829
904
  return None
830
905
  return self._rf_data.rf_impressions_scaled_da
831
906
 
832
907
  @property
833
908
  def national_rf_impressions_scaled_da(self) -> xr.DataArray | None:
834
- """Returns the national scaled RF impressions data array."""
909
+ """The national scaled RF impressions data array."""
835
910
  if self._rf_data is None:
836
911
  return None
837
912
  return self._rf_data.national_rf_impressions_scaled_da
838
913
 
839
914
  @functools.cached_property
840
915
  def _organic_rf_data(self) -> ReachFrequencyData | None:
841
- if self._meridian.input_data.organic_reach is None:
916
+ if self._input_data.organic_reach is None:
842
917
  return None
843
918
  return self._get_rf_data(
844
- self._meridian.input_data.organic_reach,
845
- self._meridian.input_data.organic_frequency,
919
+ self._input_data.organic_reach,
920
+ self._input_data.organic_frequency,
846
921
  is_organic=True,
847
922
  )
848
923
 
849
924
  @property
850
925
  def organic_reach_raw_da(self) -> xr.DataArray | None:
851
- """Returns the raw organic reach data array."""
926
+ """The raw organic reach data array."""
852
927
  if self._organic_rf_data is None:
853
928
  return None
854
- return self._organic_rf_data.reach_raw_da
929
+ return self._organic_rf_data.reach_raw_da # pytype: disable=attribute-error
855
930
 
856
931
  @property
857
932
  def organic_reach_scaled_da(self) -> xr.DataArray | None:
858
- """Returns the scaled organic reach data array."""
933
+ """The scaled organic reach data array."""
859
934
  if self._organic_rf_data is None:
860
935
  return None
861
936
  return self._organic_rf_data.reach_scaled_da # pytype: disable=attribute-error
862
937
 
863
938
  @property
864
939
  def national_organic_reach_raw_da(self) -> xr.DataArray | None:
865
- """Returns the national raw organic reach data array."""
940
+ """The national raw organic reach data array."""
866
941
  if self._organic_rf_data is None:
867
942
  return None
868
943
  return self._organic_rf_data.national_reach_raw_da
869
944
 
870
945
  @property
871
946
  def national_organic_reach_scaled_da(self) -> xr.DataArray | None:
872
- """Returns the national scaled organic reach data array."""
947
+ """The national scaled organic reach data array."""
873
948
  if self._organic_rf_data is None:
874
949
  return None
875
950
  return self._organic_rf_data.national_reach_scaled_da # pytype: disable=attribute-error
876
951
 
877
952
  @property
878
953
  def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
879
- """Returns the scaled organic RF impressions data array."""
954
+ """The scaled organic RF impressions data array."""
880
955
  if self._organic_rf_data is None:
881
956
  return None
882
957
  return self._organic_rf_data.rf_impressions_scaled_da
883
958
 
884
959
  @property
885
960
  def national_organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
886
- """Returns the national scaled organic RF impressions data array."""
961
+ """The national scaled organic RF impressions data array."""
887
962
  if self._organic_rf_data is None:
888
963
  return None
889
964
  return self._organic_rf_data.national_rf_impressions_scaled_da
890
965
 
891
966
  @property
892
967
  def organic_frequency_da(self) -> xr.DataArray | None:
893
- """Returns the organic frequency data array."""
968
+ """The organic frequency data array."""
894
969
  if self._organic_rf_data is None:
895
970
  return None
896
971
  return self._organic_rf_data.frequency_da # pytype: disable=attribute-error
897
972
 
898
973
  @property
899
974
  def national_organic_frequency_da(self) -> xr.DataArray | None:
900
- """Returns the national organic frequency data array."""
975
+ """The national organic frequency data array."""
901
976
  if self._organic_rf_data is None:
902
977
  return None
903
978
  return self._organic_rf_data.national_frequency_da # pytype: disable=attribute-error
904
979
 
905
980
  @property
906
981
  def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
907
- """Returns the raw organic RF impressions data array."""
982
+ """The raw organic RF impressions data array."""
908
983
  if self._organic_rf_data is None:
909
984
  return None
910
985
  return self._organic_rf_data.rf_impressions_raw_da
911
986
 
912
987
  @property
913
988
  def national_organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
914
- """Returns the national raw organic RF impressions data array."""
989
+ """The national raw organic RF impressions data array."""
915
990
  if self._organic_rf_data is None:
916
991
  return None
917
992
  return self._organic_rf_data.national_rf_impressions_raw_da
918
993
 
919
994
  @functools.cached_property
920
995
  def geo_population_da(self) -> xr.DataArray | None:
921
- """Returns the geo population data array."""
996
+ """The geo population data array."""
922
997
  if self._is_national_data:
923
998
  return None
924
999
  return xr.DataArray(
925
- self._meridian.population,
926
- coords={constants.GEO: self._meridian.input_data.geo.values},
1000
+ self._model_context.population,
1001
+ coords={constants.GEO: self._input_data.geo.values},
927
1002
  dims=[constants.GEO],
928
1003
  name=constants.POPULATION,
929
1004
  )
930
1005
 
931
1006
  @functools.cached_property
932
1007
  def kpi_scaled_da(self) -> xr.DataArray:
933
- """Returns the scaled KPI data array."""
1008
+ """The scaled KPI data array."""
934
1009
  scaled_kpi_da = _data_array_like(
935
- da=self._meridian.input_data.kpi,
936
- values=self._meridian.kpi_scaled,
1010
+ da=self._input_data.kpi,
1011
+ values=self._model_context.kpi_scaled,
937
1012
  )
938
1013
  scaled_kpi_da.name = constants.KPI_SCALED
939
1014
  return scaled_kpi_da
@@ -942,7 +1017,7 @@ class EDAEngine:
942
1017
  def _overall_scaled_kpi_invariability_artifact(
943
1018
  self,
944
1019
  ) -> eda_outcome.KpiInvariabilityArtifact:
945
- """Returns an artifact of overall scaled KPI invariability."""
1020
+ """An artifact of overall scaled KPI invariability."""
946
1021
  return eda_outcome.KpiInvariabilityArtifact(
947
1022
  level=eda_outcome.AnalysisLevel.OVERALL,
948
1023
  kpi_da=self.kpi_scaled_da,
@@ -951,14 +1026,14 @@ class EDAEngine:
951
1026
 
952
1027
  @functools.cached_property
953
1028
  def national_kpi_scaled_da(self) -> xr.DataArray:
954
- """Returns the national scaled KPI data array."""
1029
+ """The national scaled KPI data array."""
955
1030
  if self._is_national_data:
956
1031
  national_da = self.kpi_scaled_da.squeeze(constants.GEO, drop=True)
957
1032
  national_da.name = constants.NATIONAL_KPI_SCALED
958
1033
  else:
959
1034
  # Note that kpi is summable by assumption.
960
1035
  national_da = self._aggregate_and_scale_geo_da(
961
- self._meridian.input_data.kpi,
1036
+ self._input_data.kpi,
962
1037
  constants.NATIONAL_KPI_SCALED,
963
1038
  transformers.CenteringAndScalingTransformer,
964
1039
  )
@@ -966,7 +1041,7 @@ class EDAEngine:
966
1041
 
967
1042
  @functools.cached_property
968
1043
  def treatment_control_scaled_ds(self) -> xr.Dataset:
969
- """Returns a Dataset containing all scaled treatments and controls.
1044
+ """A Dataset containing all scaled treatments and controls.
970
1045
 
971
1046
  This includes media, RF impressions, organic media, organic RF impressions,
972
1047
  non-media treatments, and control variables, all at the geo level.
@@ -987,7 +1062,7 @@ class EDAEngine:
987
1062
 
988
1063
  @functools.cached_property
989
1064
  def all_spend_ds(self) -> xr.Dataset:
990
- """Returns a Dataset containing all spend data.
1065
+ """A Dataset containing all spend data.
991
1066
 
992
1067
  This includes media spend and rf spend.
993
1068
  """
@@ -1003,7 +1078,7 @@ class EDAEngine:
1003
1078
 
1004
1079
  @functools.cached_property
1005
1080
  def national_all_spend_ds(self) -> xr.Dataset:
1006
- """Returns a Dataset containing all national spend data.
1081
+ """A Dataset containing all national spend data.
1007
1082
 
1008
1083
  This includes media spend and rf spend.
1009
1084
  """
@@ -1019,14 +1094,14 @@ class EDAEngine:
1019
1094
 
1020
1095
  @functools.cached_property
1021
1096
  def _stacked_treatment_control_scaled_da(self) -> xr.DataArray:
1022
- """Returns a stacked DataArray of treatment_control_scaled_ds."""
1097
+ """A stacked DataArray of treatment_control_scaled_ds."""
1023
1098
  da = stack_variables(self.treatment_control_scaled_ds)
1024
1099
  da.name = constants.TREATMENT_CONTROL_SCALED
1025
1100
  return da
1026
1101
 
1027
1102
  @functools.cached_property
1028
1103
  def national_treatment_control_scaled_ds(self) -> xr.Dataset:
1029
- """Returns a Dataset containing all scaled treatments and controls.
1104
+ """A Dataset containing all scaled treatments and controls.
1030
1105
 
1031
1106
  This includes media, RF impressions, organic media, organic RF impressions,
1032
1107
  non-media treatments, and control variables, all at the national level.
@@ -1047,14 +1122,14 @@ class EDAEngine:
1047
1122
 
1048
1123
  @functools.cached_property
1049
1124
  def _stacked_national_treatment_control_scaled_da(self) -> xr.DataArray:
1050
- """Returns a stacked DataArray of national_treatment_control_scaled_ds."""
1125
+ """A stacked DataArray of national_treatment_control_scaled_ds."""
1051
1126
  da = stack_variables(self.national_treatment_control_scaled_ds)
1052
1127
  da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
1053
1128
  return da
1054
1129
 
1055
1130
  @functools.cached_property
1056
1131
  def treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
1057
- """Returns a Dataset of scaled treatments excluding non-media."""
1132
+ """A Dataset of scaled treatments excluding non-media."""
1058
1133
  return self.treatment_control_scaled_ds.drop_dims(
1059
1134
  [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
1060
1135
  errors='ignore',
@@ -1062,15 +1137,34 @@ class EDAEngine:
1062
1137
 
1063
1138
  @functools.cached_property
1064
1139
  def national_treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
1065
- """Returns a Dataset of national scaled treatments excluding non-media."""
1140
+ """A Dataset of national scaled treatments excluding non-media."""
1066
1141
  return self.national_treatment_control_scaled_ds.drop_dims(
1067
1142
  [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
1068
1143
  errors='ignore',
1069
1144
  )
1070
1145
 
1146
+ @functools.cached_property
1147
+ def controls_and_non_media_scaled_ds(self) -> xr.Dataset | None:
1148
+ """A Dataset of scaled controls and non-media treatments."""
1149
+ return _get_vars_from_dataset(
1150
+ self.treatment_control_scaled_ds,
1151
+ [constants.CONTROLS_SCALED, constants.NON_MEDIA_TREATMENTS_SCALED],
1152
+ )
1153
+
1154
+ @functools.cached_property
1155
+ def national_controls_and_non_media_scaled_ds(self) -> xr.Dataset | None:
1156
+ """A Dataset of national scaled controls and non-media treatments."""
1157
+ return _get_vars_from_dataset(
1158
+ self.national_treatment_control_scaled_ds,
1159
+ [
1160
+ constants.NATIONAL_CONTROLS_SCALED,
1161
+ constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
1162
+ ],
1163
+ )
1164
+
1071
1165
  @functools.cached_property
1072
1166
  def all_reach_scaled_da(self) -> xr.DataArray | None:
1073
- """Returns a DataArray containing all scaled reach data.
1167
+ """A DataArray containing all scaled reach data.
1074
1168
 
1075
1169
  This includes both paid and organic reach, concatenated along the RF_CHANNEL
1076
1170
  dimension.
@@ -1096,7 +1190,7 @@ class EDAEngine:
1096
1190
 
1097
1191
  @functools.cached_property
1098
1192
  def all_freq_da(self) -> xr.DataArray | None:
1099
- """Returns a DataArray containing all frequency data.
1193
+ """A DataArray containing all frequency data.
1100
1194
 
1101
1195
  This includes both paid and organic frequency, concatenated along the
1102
1196
  RF_CHANNEL dimension.
@@ -1122,7 +1216,7 @@ class EDAEngine:
1122
1216
 
1123
1217
  @functools.cached_property
1124
1218
  def national_all_reach_scaled_da(self) -> xr.DataArray | None:
1125
- """Returns a DataArray containing all national-level scaled reach data.
1219
+ """A DataArray containing all national-level scaled reach data.
1126
1220
 
1127
1221
  This includes both paid and organic reach, concatenated along the
1128
1222
  RF_CHANNEL dimension.
@@ -1149,7 +1243,7 @@ class EDAEngine:
1149
1243
 
1150
1244
  @functools.cached_property
1151
1245
  def national_all_freq_da(self) -> xr.DataArray | None:
1152
- """Returns a DataArray containing all national-level frequency data.
1246
+ """A DataArray containing all national-level frequency data.
1153
1247
 
1154
1248
  This includes both paid and organic frequency, concatenated along the
1155
1249
  RF_CHANNEL dimension.
@@ -1202,7 +1296,7 @@ class EDAEngine:
1202
1296
  def _critical_checks(
1203
1297
  self,
1204
1298
  ) -> list[tuple[_NamedEDACheckCallable, eda_outcome.EDACheckType]]:
1205
- """Returns a list of critical checks to be performed."""
1299
+ """A list of critical checks to be performed."""
1206
1300
  checks = [
1207
1301
  (
1208
1302
  self.check_overall_kpi_invariability,
@@ -1221,19 +1315,21 @@ class EDAEngine:
1221
1315
  # This should not happen. If it does, it means this function is mis-used.
1222
1316
  if constants.MEDIA_TIME not in da.coords:
1223
1317
  raise ValueError(
1224
- f'Variable does not have a media time coordinate: {da.name}.'
1318
+ f'Variable does not have a media time coordinate: {da.name!r}.'
1225
1319
  )
1226
1320
 
1227
- start = self._meridian.n_media_times - self._meridian.n_times
1228
- da = da.copy().isel({constants.MEDIA_TIME: slice(start, None)})
1229
- da = da.rename({constants.MEDIA_TIME: constants.TIME})
1230
- return da
1321
+ start = self._model_context.n_media_times - self._model_context.n_times
1322
+ return (
1323
+ da.copy()
1324
+ .isel({constants.MEDIA_TIME: slice(start, None)})
1325
+ .rename({constants.MEDIA_TIME: constants.TIME})
1326
+ )
1231
1327
 
1232
1328
  def _scale_xarray(
1233
1329
  self,
1234
1330
  xarray: xr.DataArray,
1235
- transformer_class: Optional[type[transformers.TensorTransformer]],
1236
- population: Optional[backend.Tensor] = None,
1331
+ transformer_class: type[transformers.TensorTransformer] | None,
1332
+ population: backend.Tensor | None = None,
1237
1333
  ) -> xr.DataArray:
1238
1334
  """Scales xarray values with a TensorTransformer."""
1239
1335
  da = xarray.copy()
@@ -1285,7 +1381,9 @@ class EDAEngine:
1285
1381
  agg_results = []
1286
1382
  for var_name in geo_da[channel_dim].values:
1287
1383
  var_data = geo_da.sel({channel_dim: var_name})
1288
- agg_func = da_var_agg_map.get(var_name, _DEFAULT_DA_VAR_AGG_FUNCTION)
1384
+ agg_func = da_var_agg_map.get(
1385
+ var_name, eda_constants.DEFAULT_DA_VAR_AGG_FUNCTION
1386
+ )
1289
1387
  # Apply the aggregation function over the GEO dimension
1290
1388
  aggregated_data = var_data.reduce(
1291
1389
  agg_func, dim=constants.GEO, keepdims=keepdims
@@ -1299,9 +1397,9 @@ class EDAEngine:
1299
1397
  self,
1300
1398
  geo_da: xr.DataArray,
1301
1399
  national_da_name: str,
1302
- transformer_class: Optional[type[transformers.TensorTransformer]],
1303
- channel_dim: Optional[str] = None,
1304
- da_var_agg_map: Optional[eda_spec.AggregationMap] = None,
1400
+ transformer_class: type[transformers.TensorTransformer] | None,
1401
+ channel_dim: str | None = None,
1402
+ da_var_agg_map: eda_spec.AggregationMap | None = None,
1305
1403
  ) -> xr.DataArray:
1306
1404
  """Aggregate geo-level xr.DataArray to national level and then scale values.
1307
1405
 
@@ -1351,11 +1449,11 @@ class EDAEngine:
1351
1449
  """Get impressions and frequencies data arrays for RF channels."""
1352
1450
  if is_organic:
1353
1451
  scaled_reach_values = (
1354
- self._meridian.organic_rf_tensors.organic_reach_scaled
1452
+ self._model_context.organic_rf_tensors.organic_reach_scaled
1355
1453
  )
1356
1454
  names = _ORGANIC_RF_NAMES
1357
1455
  else:
1358
- scaled_reach_values = self._meridian.rf_tensors.reach_scaled
1456
+ scaled_reach_values = self._model_context.rf_tensors.reach_scaled
1359
1457
  names = _RF_NAMES
1360
1458
 
1361
1459
  reach_scaled_da = _data_array_like(
@@ -1433,7 +1531,7 @@ class EDAEngine:
1433
1531
  impressions_scaled_da = self._scale_xarray(
1434
1532
  impressions_raw_da,
1435
1533
  transformers.MediaTransformer,
1436
- population=self._meridian.population,
1534
+ population=self._model_context.population,
1437
1535
  )
1438
1536
  impressions_scaled_da.name = names.impressions_scaled
1439
1537
 
@@ -1491,9 +1589,16 @@ class EDAEngine:
1491
1589
  overall_corr_mat, overall_extreme_corr_var_pairs_df = (
1492
1590
  self._pairwise_corr_for_geo_data(
1493
1591
  dims=[constants.GEO, constants.TIME],
1494
- extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
1592
+ extreme_corr_threshold=eda_constants.OVERALL_PAIRWISE_CORR_THRESHOLD,
1495
1593
  )
1496
1594
  )
1595
+ overall_artifact = eda_outcome.PairwiseCorrArtifact(
1596
+ level=eda_outcome.AnalysisLevel.OVERALL,
1597
+ corr_matrix=overall_corr_mat,
1598
+ extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
1599
+ extreme_corr_threshold=eda_constants.OVERALL_PAIRWISE_CORR_THRESHOLD,
1600
+ )
1601
+
1497
1602
  if not overall_extreme_corr_var_pairs_df.empty:
1498
1603
  var_pairs = overall_extreme_corr_var_pairs_df.index.to_list()
1499
1604
  findings.append(
@@ -1505,21 +1610,33 @@ class EDAEngine:
1505
1610
  ' variables, please remove one of the variables from the'
1506
1611
  f' model.\nPairs with perfect correlation: {var_pairs}'
1507
1612
  ),
1613
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
1614
+ associated_artifact=overall_artifact,
1508
1615
  )
1509
1616
  )
1510
1617
 
1511
1618
  geo_corr_mat, geo_extreme_corr_var_pairs_df = (
1512
1619
  self._pairwise_corr_for_geo_data(
1513
1620
  dims=constants.TIME,
1514
- extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
1621
+ extreme_corr_threshold=eda_constants.GEO_PAIRWISE_CORR_THRESHOLD,
1515
1622
  )
1516
1623
  )
1517
- # Overall correlation and per-geo correlation findings are mutually
1518
- # exclusive, and overall correlation finding takes precedence.
1519
- if (
1520
- overall_extreme_corr_var_pairs_df.empty
1521
- and not geo_extreme_corr_var_pairs_df.empty
1522
- ):
1624
+ # Pairs that cause overall level findings are very likely to cause geo
1625
+ # level findings as well, so we exclude them when determining geo-level
1626
+ # findings. This is to avoid over-reporting findings.
1627
+ overall_pairs_index = overall_extreme_corr_var_pairs_df.index
1628
+ is_in_overall = geo_extreme_corr_var_pairs_df.index.droplevel(
1629
+ constants.GEO
1630
+ ).isin(overall_pairs_index)
1631
+ geo_df_for_attention = geo_extreme_corr_var_pairs_df[~is_in_overall]
1632
+ geo_artifact = eda_outcome.PairwiseCorrArtifact(
1633
+ level=eda_outcome.AnalysisLevel.GEO,
1634
+ corr_matrix=geo_corr_mat,
1635
+ extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
1636
+ extreme_corr_threshold=eda_constants.GEO_PAIRWISE_CORR_THRESHOLD,
1637
+ )
1638
+
1639
+ if not geo_df_for_attention.empty:
1523
1640
  findings.append(
1524
1641
  eda_outcome.EDAFinding(
1525
1642
  severity=eda_outcome.EDASeverity.ATTENTION,
@@ -1529,6 +1646,8 @@ class EDAEngine:
1529
1646
  ' variables if they also have high pairwise correlations in'
1530
1647
  ' other geos.'
1531
1648
  ),
1649
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
1650
+ associated_artifact=geo_artifact,
1532
1651
  )
1533
1652
  )
1534
1653
 
@@ -1538,34 +1657,15 @@ class EDAEngine:
1538
1657
  findings.append(
1539
1658
  eda_outcome.EDAFinding(
1540
1659
  severity=eda_outcome.EDASeverity.INFO,
1541
- explanation=(
1542
- 'Please review the computed pairwise correlations. Note that'
1543
- ' high pairwise correlation may cause model identifiability'
1544
- ' and convergence issues. Consider combining the variables if'
1545
- ' high correlation exists.'
1546
- ),
1660
+ explanation=(eda_constants.PAIRWISE_CORRELATION_CHECK_INFO),
1661
+ finding_cause=eda_outcome.FindingCause.NONE,
1547
1662
  )
1548
1663
  )
1549
1664
 
1550
- pairwise_corr_artifacts = [
1551
- eda_outcome.PairwiseCorrArtifact(
1552
- level=eda_outcome.AnalysisLevel.OVERALL,
1553
- corr_matrix=overall_corr_mat,
1554
- extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
1555
- extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
1556
- ),
1557
- eda_outcome.PairwiseCorrArtifact(
1558
- level=eda_outcome.AnalysisLevel.GEO,
1559
- corr_matrix=geo_corr_mat,
1560
- extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
1561
- extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
1562
- ),
1563
- ]
1564
-
1565
1665
  return eda_outcome.EDAOutcome(
1566
1666
  check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
1567
1667
  findings=findings,
1568
- analysis_artifacts=pairwise_corr_artifacts,
1668
+ analysis_artifacts=[overall_artifact, geo_artifact],
1569
1669
  )
1570
1670
 
1571
1671
  def check_national_pairwise_corr(
@@ -1582,7 +1682,14 @@ class EDAEngine:
1582
1682
  self._stacked_national_treatment_control_scaled_da, dims=constants.TIME
1583
1683
  )
1584
1684
  extreme_corr_var_pairs_df = _find_extreme_corr_pairs(
1585
- corr_mat, _NATIONAL_PAIRWISE_CORR_THRESHOLD
1685
+ corr_mat, eda_constants.NATIONAL_PAIRWISE_CORR_THRESHOLD
1686
+ )
1687
+
1688
+ artifact = eda_outcome.PairwiseCorrArtifact(
1689
+ level=eda_outcome.AnalysisLevel.NATIONAL,
1690
+ corr_matrix=corr_mat,
1691
+ extreme_corr_var_pairs=extreme_corr_var_pairs_df,
1692
+ extreme_corr_threshold=eda_constants.NATIONAL_PAIRWISE_CORR_THRESHOLD,
1586
1693
  )
1587
1694
 
1588
1695
  if not extreme_corr_var_pairs_df.empty:
@@ -1596,33 +1703,23 @@ class EDAEngine:
1596
1703
  ' variables, please remove one of the variables from the'
1597
1704
  f' model.\nPairs with perfect correlation: {var_pairs}'
1598
1705
  ),
1706
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
1707
+ associated_artifact=artifact,
1599
1708
  )
1600
1709
  )
1601
1710
  else:
1602
1711
  findings.append(
1603
1712
  eda_outcome.EDAFinding(
1604
1713
  severity=eda_outcome.EDASeverity.INFO,
1605
- explanation=(
1606
- 'Please review the computed pairwise correlations. Note that'
1607
- ' high pairwise correlation may cause model identifiability'
1608
- ' and convergence issues. Consider combining the variables if'
1609
- ' high correlation exists.'
1610
- ),
1714
+ explanation=(eda_constants.PAIRWISE_CORRELATION_CHECK_INFO),
1715
+ finding_cause=eda_outcome.FindingCause.NONE,
1611
1716
  )
1612
1717
  )
1613
1718
 
1614
- pairwise_corr_artifacts = [
1615
- eda_outcome.PairwiseCorrArtifact(
1616
- level=eda_outcome.AnalysisLevel.NATIONAL,
1617
- corr_matrix=corr_mat,
1618
- extreme_corr_var_pairs=extreme_corr_var_pairs_df,
1619
- extreme_corr_threshold=_NATIONAL_PAIRWISE_CORR_THRESHOLD,
1620
- )
1621
- ]
1622
1719
  return eda_outcome.EDAOutcome(
1623
1720
  check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
1624
1721
  findings=findings,
1625
- analysis_artifacts=pairwise_corr_artifacts,
1722
+ analysis_artifacts=[artifact],
1626
1723
  )
1627
1724
 
1628
1725
  def check_pairwise_corr(
@@ -1643,20 +1740,14 @@ class EDAEngine:
1643
1740
  data: xr.DataArray,
1644
1741
  level: eda_outcome.AnalysisLevel,
1645
1742
  zero_std_message: str,
1743
+ outlier_message: str,
1646
1744
  ) -> tuple[
1647
- eda_outcome.EDAFinding | None, eda_outcome.StandardDeviationArtifact
1745
+ list[eda_outcome.EDAFinding], eda_outcome.StandardDeviationArtifact
1648
1746
  ]:
1649
1747
  """Helper to check standard deviation."""
1650
1748
  std_ds = _calculate_std(data)
1651
1749
  outlier_df = _calculate_outliers(data)
1652
1750
 
1653
- finding = None
1654
- if (std_ds[_STD_WITHOUT_OUTLIERS_VAR_NAME] < _STD_THRESHOLD).any():
1655
- finding = eda_outcome.EDAFinding(
1656
- severity=eda_outcome.EDASeverity.ATTENTION,
1657
- explanation=zero_std_message,
1658
- )
1659
-
1660
1751
  artifact = eda_outcome.StandardDeviationArtifact(
1661
1752
  variable=str(data.name),
1662
1753
  level=level,
@@ -1664,7 +1755,31 @@ class EDAEngine:
1664
1755
  outlier_df=outlier_df,
1665
1756
  )
1666
1757
 
1667
- return finding, artifact
1758
+ findings = []
1759
+ if (
1760
+ std_ds[eda_constants.STD_WITHOUT_OUTLIERS_VAR_NAME]
1761
+ < eda_constants.STD_THRESHOLD
1762
+ ).any():
1763
+ findings.append(
1764
+ eda_outcome.EDAFinding(
1765
+ severity=eda_outcome.EDASeverity.ATTENTION,
1766
+ explanation=zero_std_message,
1767
+ finding_cause=eda_outcome.FindingCause.VARIABILITY,
1768
+ associated_artifact=artifact,
1769
+ )
1770
+ )
1771
+
1772
+ if not outlier_df.empty:
1773
+ findings.append(
1774
+ eda_outcome.EDAFinding(
1775
+ severity=eda_outcome.EDASeverity.ATTENTION,
1776
+ explanation=outlier_message,
1777
+ finding_cause=eda_outcome.FindingCause.OUTLIER,
1778
+ associated_artifact=artifact,
1779
+ )
1780
+ )
1781
+
1782
+ return findings, artifact
1668
1783
 
1669
1784
  def check_geo_std(
1670
1785
  self,
@@ -1685,6 +1800,10 @@ class EDAEngine:
1685
1800
  ' variable for these geos. Please review the input data,'
1686
1801
  ' and/or consider grouping these geos together.'
1687
1802
  ),
1803
+ (
1804
+ 'There are outliers in the scaled KPI in certain geos.'
1805
+ ' Please check for any possible data errors.'
1806
+ ),
1688
1807
  ),
1689
1808
  (
1690
1809
  self._stacked_treatment_control_scaled_da,
@@ -1695,6 +1814,11 @@ class EDAEngine:
1695
1814
  ' consider combining them to mitigate potential model'
1696
1815
  ' identifiability and convergence issues.'
1697
1816
  ),
1817
+ (
1818
+ 'There are outliers in the scaled treatment or control'
1819
+ ' variables in certain geos. Please check for any possible data'
1820
+ ' errors.'
1821
+ ),
1698
1822
  ),
1699
1823
  (
1700
1824
  self.all_reach_scaled_da,
@@ -1705,6 +1829,11 @@ class EDAEngine:
1705
1829
  ' geos, consider modeling them as impression-based channels'
1706
1830
  ' instead by taking reach * frequency.'
1707
1831
  ),
1832
+ (
1833
+ 'There are outliers in the scaled reach values of the RF or'
1834
+ ' Organic RF channels in certain geos. Please check for any'
1835
+ ' possible data errors.'
1836
+ ),
1708
1837
  ),
1709
1838
  (
1710
1839
  self.all_freq_da,
@@ -1715,30 +1844,34 @@ class EDAEngine:
1715
1844
  ' geos, consider modeling them as impression-based channels'
1716
1845
  ' instead by taking reach * frequency.'
1717
1846
  ),
1847
+ (
1848
+ 'There are outliers in the scaled frequency values of the RF or'
1849
+ ' Organic RF channels in certain geos. Please check for any'
1850
+ ' possible data errors.'
1851
+ ),
1718
1852
  ),
1719
1853
  ]
1720
1854
 
1721
- for data_da, message in checks:
1855
+ for data_da, std_message, outlier_message in checks:
1722
1856
  if data_da is None:
1723
1857
  continue
1724
- finding, artifact = self._check_std(
1858
+ current_findings, artifact = self._check_std(
1725
1859
  level=eda_outcome.AnalysisLevel.GEO,
1726
1860
  data=data_da,
1727
- zero_std_message=message,
1861
+ zero_std_message=std_message,
1862
+ outlier_message=outlier_message,
1728
1863
  )
1729
1864
  artifacts.append(artifact)
1730
- if finding:
1731
- findings.append(finding)
1865
+ if current_findings:
1866
+ findings.extend(current_findings)
1732
1867
 
1733
1868
  # Add an INFO finding if no findings were added.
1734
1869
  if not findings:
1735
1870
  findings.append(
1736
1871
  eda_outcome.EDAFinding(
1737
1872
  severity=eda_outcome.EDASeverity.INFO,
1738
- explanation=(
1739
- 'Please review any identified outliers and the standard'
1740
- ' deviation.'
1741
- ),
1873
+ explanation='Please review the computed standard deviations.',
1874
+ finding_cause=eda_outcome.FindingCause.NONE,
1742
1875
  )
1743
1876
  )
1744
1877
 
@@ -1765,6 +1898,10 @@ class EDAEngine:
1765
1898
  ' the input data, and/or reconsider the feasibility of model'
1766
1899
  ' fitting with this dataset.'
1767
1900
  ),
1901
+ (
1902
+ 'There are outliers in the scaled KPI.'
1903
+ ' Please check for any possible data errors.'
1904
+ ),
1768
1905
  ),
1769
1906
  (
1770
1907
  self._stacked_national_treatment_control_scaled_da,
@@ -1776,6 +1913,10 @@ class EDAEngine:
1776
1913
  ' Please review the input data, and/or consider combining these'
1777
1914
  ' variables to mitigate sparsity.'
1778
1915
  ),
1916
+ (
1917
+ 'There are outliers in the scaled treatment or control'
1918
+ ' variables. Please check for any possible data errors.'
1919
+ ),
1779
1920
  ),
1780
1921
  (
1781
1922
  self.national_all_reach_scaled_da,
@@ -1785,6 +1926,11 @@ class EDAEngine:
1785
1926
  ' Consider modeling these RF channels as impression-based'
1786
1927
  ' channels instead.'
1787
1928
  ),
1929
+ (
1930
+ 'There are outliers in the scaled reach values of the RF or'
1931
+ ' Organic RF channels. Please check for any possible data'
1932
+ ' errors.'
1933
+ ),
1788
1934
  ),
1789
1935
  (
1790
1936
  self.national_all_freq_da,
@@ -1794,30 +1940,34 @@ class EDAEngine:
1794
1940
  ' Consider modeling these RF channels as impression-based'
1795
1941
  ' channels instead.'
1796
1942
  ),
1943
+ (
1944
+ 'There are outliers in the scaled frequency values of the RF or'
1945
+ ' Organic RF channels. Please check for any possible data'
1946
+ ' errors.'
1947
+ ),
1797
1948
  ),
1798
1949
  ]
1799
1950
 
1800
- for data_da, message in checks:
1951
+ for data_da, std_message, outlier_message in checks:
1801
1952
  if data_da is None:
1802
1953
  continue
1803
- finding, artifact = self._check_std(
1954
+ current_findings, artifact = self._check_std(
1804
1955
  data=data_da,
1805
1956
  level=eda_outcome.AnalysisLevel.NATIONAL,
1806
- zero_std_message=message,
1957
+ zero_std_message=std_message,
1958
+ outlier_message=outlier_message,
1807
1959
  )
1808
1960
  artifacts.append(artifact)
1809
- if finding:
1810
- findings.append(finding)
1961
+ if current_findings:
1962
+ findings.extend(current_findings)
1811
1963
 
1812
1964
  # Add an INFO finding if no findings were added.
1813
1965
  if not findings:
1814
1966
  findings.append(
1815
1967
  eda_outcome.EDAFinding(
1816
1968
  severity=eda_outcome.EDASeverity.INFO,
1817
- explanation=(
1818
- 'Please review any identified outliers and the standard'
1819
- ' deviation.'
1820
- ),
1969
+ explanation='Please review the computed standard deviations.',
1970
+ finding_cause=eda_outcome.FindingCause.NONE,
1821
1971
  )
1822
1972
  )
1823
1973
 
@@ -1841,7 +1991,15 @@ class EDAEngine:
1841
1991
  return self.check_geo_std()
1842
1992
 
1843
1993
  def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
1844
- """Computes geo-level variance inflation factor among treatments and controls."""
1994
+ """Checks geo variance inflation factor among treatments and controls.
1995
+
1996
+ The VIF calculation only focuses on multicollinearity among non-constant
1997
+ variables. Any variable with constant values will result in a NaN VIF value.
1998
+
1999
+ Returns:
2000
+ An EDAOutcome object with findings and result values.
2001
+ """
2002
+
1845
2003
  if self._is_national_data:
1846
2004
  raise ValueError(
1847
2005
  'Geo-level VIF checks are not applicable for national models.'
@@ -1851,12 +2009,12 @@ class EDAEngine:
1851
2009
  tc_da = self._stacked_treatment_control_scaled_da
1852
2010
  overall_threshold = self._spec.vif_spec.overall_threshold
1853
2011
 
1854
- overall_vif_da = _calculate_vif(tc_da, _STACK_VAR_COORD_NAME)
2012
+ overall_vif_da = _calculate_vif(tc_da, eda_constants.VARIABLE)
1855
2013
  extreme_overall_vif_da = overall_vif_da.where(
1856
2014
  overall_vif_da > overall_threshold
1857
2015
  )
1858
2016
  extreme_overall_vif_df = extreme_overall_vif_da.to_dataframe(
1859
- name=_VIF_COL_NAME
2017
+ name=eda_constants.VIF_COL_NAME
1860
2018
  ).dropna()
1861
2019
 
1862
2020
  overall_vif_artifact = eda_outcome.VIFArtifact(
@@ -1868,11 +2026,11 @@ class EDAEngine:
1868
2026
  # Geo level VIF check.
1869
2027
  geo_threshold = self._spec.vif_spec.geo_threshold
1870
2028
  geo_vif_da = tc_da.groupby(constants.GEO).map(
1871
- lambda x: _calculate_vif(x, _STACK_VAR_COORD_NAME)
2029
+ lambda x: _calculate_vif(x, eda_constants.VARIABLE)
1872
2030
  )
1873
2031
  extreme_geo_vif_da = geo_vif_da.where(geo_vif_da > geo_threshold)
1874
2032
  extreme_geo_vif_df = extreme_geo_vif_da.to_dataframe(
1875
- name=_VIF_COL_NAME
2033
+ name=eda_constants.VIF_COL_NAME
1876
2034
  ).dropna()
1877
2035
 
1878
2036
  geo_vif_artifact = eda_outcome.VIFArtifact(
@@ -1883,33 +2041,47 @@ class EDAEngine:
1883
2041
 
1884
2042
  findings = []
1885
2043
  if not extreme_overall_vif_df.empty:
1886
- high_vif_vars = extreme_overall_vif_df.index.to_list()
2044
+ high_vif_vars_message = (
2045
+ '\nVariables with extreme VIF:'
2046
+ f' {extreme_overall_vif_df.index.to_list()}'
2047
+ )
1887
2048
  findings.append(
1888
2049
  eda_outcome.EDAFinding(
1889
2050
  severity=eda_outcome.EDASeverity.ERROR,
1890
- explanation=(
1891
- 'Some variables have extreme multicollinearity (VIF'
1892
- f' >{overall_threshold}) across all times and geos. To'
1893
- ' address multicollinearity, please drop any variable that'
1894
- ' is a linear combination of other variables. Otherwise,'
1895
- ' consider combining variables.\n'
1896
- f'Variables with extreme VIF: {high_vif_vars}'
2051
+ explanation=eda_constants.MULTICOLLINEARITY_ERROR.format(
2052
+ threshold=overall_threshold,
2053
+ aggregation='times and geos',
2054
+ additional_info=high_vif_vars_message,
1897
2055
  ),
2056
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
2057
+ associated_artifact=overall_vif_artifact,
1898
2058
  )
1899
2059
  )
1900
- elif not extreme_geo_vif_df.empty:
2060
+
2061
+ # Variables that cause overall level findings are very likely to cause
2062
+ # geo-level findings as well, so we exclude them when determining
2063
+ # geo-level findings. This is to avoid over-reporting findings.
2064
+ overall_vars_index = extreme_overall_vif_df.index
2065
+ is_in_overall = extreme_geo_vif_df.index.get_level_values(
2066
+ eda_constants.VARIABLE
2067
+ ).isin(overall_vars_index)
2068
+ geo_df_for_attention = extreme_geo_vif_df[~is_in_overall]
2069
+
2070
+ if not geo_df_for_attention.empty:
1901
2071
  findings.append(
1902
2072
  eda_outcome.EDAFinding(
1903
2073
  severity=eda_outcome.EDASeverity.ATTENTION,
1904
2074
  explanation=(
1905
- 'Some variables have extreme multicollinearity (with VIF >'
1906
- f' {geo_threshold}) in certain geo(s). Consider checking your'
1907
- ' data, and/or combining these variables if they also have'
1908
- ' high VIF in other geos.'
2075
+ eda_constants.MULTICOLLINEARITY_ATTENTION.format(
2076
+ threshold=geo_threshold, additional_info=''
2077
+ )
1909
2078
  ),
2079
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
2080
+ associated_artifact=geo_vif_artifact,
1910
2081
  )
1911
2082
  )
1912
- else:
2083
+
2084
+ if not findings:
1913
2085
  findings.append(
1914
2086
  eda_outcome.EDAFinding(
1915
2087
  severity=eda_outcome.EDASeverity.INFO,
@@ -1919,6 +2091,7 @@ class EDAEngine:
1919
2091
  ' jeopardize model identifiability and model convergence.'
1920
2092
  ' Consider combining the variables if high VIF occurs.'
1921
2093
  ),
2094
+ finding_cause=eda_outcome.FindingCause.NONE,
1922
2095
  )
1923
2096
  )
1924
2097
 
@@ -1931,14 +2104,21 @@ class EDAEngine:
1931
2104
  def check_national_vif(
1932
2105
  self,
1933
2106
  ) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
1934
- """Computes national-level variance inflation factor among treatments and controls."""
2107
+ """Checks national variance inflation factor among treatments and controls.
2108
+
2109
+ The VIF calculation only focuses on multicollinearity among non-constant
2110
+ variables. Any variable with constant values will result in a NaN VIF value.
2111
+
2112
+ Returns:
2113
+ An EDAOutcome object with findings and result values.
2114
+ """
1935
2115
  national_tc_da = self._stacked_national_treatment_control_scaled_da
1936
2116
  national_threshold = self._spec.vif_spec.national_threshold
1937
- national_vif_da = _calculate_vif(national_tc_da, _STACK_VAR_COORD_NAME)
2117
+ national_vif_da = _calculate_vif(national_tc_da, eda_constants.VARIABLE)
1938
2118
 
1939
2119
  extreme_national_vif_df = (
1940
2120
  national_vif_da.where(national_vif_da > national_threshold)
1941
- .to_dataframe(name=_VIF_COL_NAME)
2121
+ .to_dataframe(name=eda_constants.VIF_COL_NAME)
1942
2122
  .dropna()
1943
2123
  )
1944
2124
  national_vif_artifact = eda_outcome.VIFArtifact(
@@ -1949,18 +2129,20 @@ class EDAEngine:
1949
2129
 
1950
2130
  findings = []
1951
2131
  if not extreme_national_vif_df.empty:
1952
- high_vif_vars = extreme_national_vif_df.index.to_list()
2132
+ high_vif_vars_message = (
2133
+ '\nVariables with extreme VIF:'
2134
+ f' {extreme_national_vif_df.index.to_list()}'
2135
+ )
1953
2136
  findings.append(
1954
2137
  eda_outcome.EDAFinding(
1955
2138
  severity=eda_outcome.EDASeverity.ERROR,
1956
- explanation=(
1957
- 'Some variables have extreme multicollinearity (with VIF >'
1958
- f' {national_threshold}) across all times. To address'
1959
- ' multicollinearity, please drop any variable that is a'
1960
- ' linear combination of other variables. Otherwise, consider'
1961
- ' combining variables.\n'
1962
- f'Variables with extreme VIF: {high_vif_vars}'
2139
+ explanation=eda_constants.MULTICOLLINEARITY_ERROR.format(
2140
+ threshold=national_threshold,
2141
+ aggregation='times',
2142
+ additional_info=high_vif_vars_message,
1963
2143
  ),
2144
+ finding_cause=eda_outcome.FindingCause.MULTICOLLINEARITY,
2145
+ associated_artifact=national_vif_artifact,
1964
2146
  )
1965
2147
  )
1966
2148
  else:
@@ -1973,6 +2155,7 @@ class EDAEngine:
1973
2155
  ' jeopardize model identifiability and model convergence.'
1974
2156
  ' Consider combining the variables if high VIF occurs.'
1975
2157
  ),
2158
+ finding_cause=eda_outcome.FindingCause.NONE,
1976
2159
  )
1977
2160
  )
1978
2161
  return eda_outcome.EDAOutcome(
@@ -1984,6 +2167,9 @@ class EDAEngine:
1984
2167
  def check_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
1985
2168
  """Computes variance inflation factor among treatments and controls.
1986
2169
 
2170
+ The VIF calculation only focuses on multicollinearity among non-constant
2171
+ variables. Any variable with constant values will result in a NaN VIF value.
2172
+
1987
2173
  Returns:
1988
2174
  An EDAOutcome object with findings and result values.
1989
2175
  """
@@ -1994,15 +2180,18 @@ class EDAEngine:
1994
2180
 
1995
2181
  @property
1996
2182
  def kpi_has_variability(self) -> bool:
1997
- """Returns True if the KPI has variability across geos and times."""
2183
+ """Whether the KPI has variability across geos and times."""
1998
2184
  return (
1999
2185
  self._overall_scaled_kpi_invariability_artifact.kpi_stdev.item()
2000
- >= _STD_THRESHOLD
2186
+ >= eda_constants.STD_THRESHOLD
2001
2187
  )
2002
2188
 
2003
- def check_overall_kpi_invariability(self) -> eda_outcome.EDAOutcome:
2189
+ def check_overall_kpi_invariability(
2190
+ self,
2191
+ ) -> eda_outcome.EDAOutcome[eda_outcome.KpiInvariabilityArtifact]:
2004
2192
  """Checks if the KPI is constant across all geos and times."""
2005
- kpi = self._overall_scaled_kpi_invariability_artifact.kpi_da.name
2193
+ artifact = self._overall_scaled_kpi_invariability_artifact
2194
+ kpi = artifact.kpi_da.name
2006
2195
  geo_text = '' if self._is_national_data else 'geos and '
2007
2196
 
2008
2197
  if not self.kpi_has_variability:
@@ -2012,6 +2201,8 @@ class EDAEngine:
2012
2201
  f'`{kpi}` is constant across all {geo_text}times, indicating no'
2013
2202
  ' signal in the data. Please fix this data error.'
2014
2203
  ),
2204
+ finding_cause=eda_outcome.FindingCause.VARIABILITY,
2205
+ associated_artifact=artifact,
2015
2206
  )
2016
2207
  else:
2017
2208
  eda_finding = eda_outcome.EDAFinding(
@@ -2019,12 +2210,13 @@ class EDAEngine:
2019
2210
  explanation=(
2020
2211
  f'The {kpi} has variability across {geo_text}times in the data.'
2021
2212
  ),
2213
+ finding_cause=eda_outcome.FindingCause.NONE,
2022
2214
  )
2023
2215
 
2024
2216
  return eda_outcome.EDAOutcome(
2025
2217
  check_type=eda_outcome.EDACheckType.KPI_INVARIABILITY,
2026
2218
  findings=[eda_finding],
2027
- analysis_artifacts=[self._overall_scaled_kpi_invariability_artifact],
2219
+ analysis_artifacts=[artifact],
2028
2220
  )
2029
2221
 
2030
2222
  def check_geo_cost_per_media_unit(
@@ -2081,30 +2273,249 @@ class EDAEngine:
2081
2273
 
2082
2274
  return self.check_geo_cost_per_media_unit()
2083
2275
 
2084
- def run_all_critical_checks(self) -> list[eda_outcome.EDAOutcome]:
2276
+ def run_all_critical_checks(self) -> eda_outcome.CriticalCheckEDAOutcomes:
2085
2277
  """Runs all critical EDA checks.
2086
2278
 
2087
2279
  Critical checks are those that can result in EDASeverity.ERROR findings.
2088
2280
 
2089
2281
  Returns:
2090
- A list of EDA outcomes, one for each check.
2282
+ A CriticalCheckEDAOutcomes object containing the results of all critical
2283
+ checks.
2091
2284
  """
2092
- outcomes = []
2285
+ outcomes = {}
2093
2286
  for check, check_type in self._critical_checks:
2094
2287
  try:
2095
- outcomes.append(check())
2288
+ outcomes[check_type] = check()
2096
2289
  except Exception as e: # pylint: disable=broad-except
2097
2290
  error_finding = eda_outcome.EDAFinding(
2098
2291
  severity=eda_outcome.EDASeverity.ERROR,
2099
2292
  explanation=(
2100
2293
  f'An error occurred during running {check.__name__}: {e!r}'
2101
2294
  ),
2295
+ finding_cause=eda_outcome.FindingCause.RUNTIME_ERROR,
2296
+ )
2297
+ outcomes[check_type] = eda_outcome.EDAOutcome(
2298
+ check_type=check_type,
2299
+ findings=[error_finding],
2300
+ analysis_artifacts=[],
2102
2301
  )
2103
- outcomes.append(
2104
- eda_outcome.EDAOutcome(
2105
- check_type=check_type,
2106
- findings=[error_finding],
2107
- analysis_artifacts=[],
2302
+
2303
+ return eda_outcome.CriticalCheckEDAOutcomes(
2304
+ kpi_invariability=outcomes[eda_outcome.EDACheckType.KPI_INVARIABILITY],
2305
+ multicollinearity=outcomes[eda_outcome.EDACheckType.MULTICOLLINEARITY],
2306
+ pairwise_correlation=outcomes[
2307
+ eda_outcome.EDACheckType.PAIRWISE_CORRELATION
2308
+ ],
2309
+ )
2310
+
2311
+ def check_variable_geo_time_collinearity(
2312
+ self,
2313
+ ) -> eda_outcome.EDAOutcome[eda_outcome.VariableGeoTimeCollinearityArtifact]:
2314
+ """Compute adjusted R-squared for treatments and controls vs geo and time.
2315
+
2316
+ These checks are applied to geo-level dataset only.
2317
+
2318
+ Returns:
2319
+ An EDAOutcome object containing a VariableGeoTimeCollinearityArtifact.
2320
+ The artifact includes a Dataset with 'rsquared_geo' and 'rsquared_time',
2321
+ showing the adjusted R-squared values for each treatment/control variable
2322
+ when regressed against 'geo' and 'time', respectively. If a variable is
2323
+ constant across geos or times, the corresponding 'rsquared_geo' or
2324
+ 'rsquared_time' value will be NaN.
2325
+ """
2326
+ if self._is_national_data:
2327
+ raise ValueError(
2328
+ 'check_variable_geo_time_collinearity is not supported for national'
2329
+ ' models.'
2330
+ )
2331
+
2332
+ grouped_da = self._stacked_treatment_control_scaled_da.groupby(
2333
+ eda_constants.VARIABLE
2334
+ )
2335
+ rsq_geo = grouped_da.map(_calc_adj_r2, args=(constants.GEO,))
2336
+ rsq_time = grouped_da.map(_calc_adj_r2, args=(constants.TIME,))
2337
+
2338
+ rsquared_ds = xr.Dataset({
2339
+ eda_constants.RSQUARED_GEO: rsq_geo,
2340
+ eda_constants.RSQUARED_TIME: rsq_time,
2341
+ })
2342
+
2343
+ artifact = eda_outcome.VariableGeoTimeCollinearityArtifact(
2344
+ level=eda_outcome.AnalysisLevel.OVERALL,
2345
+ rsquared_ds=rsquared_ds,
2346
+ )
2347
+ findings = [
2348
+ eda_outcome.EDAFinding(
2349
+ severity=eda_outcome.EDASeverity.INFO,
2350
+ explanation=eda_constants.R_SQUARED_TIME_INFO,
2351
+ finding_cause=eda_outcome.FindingCause.NONE,
2352
+ ),
2353
+ eda_outcome.EDAFinding(
2354
+ severity=eda_outcome.EDASeverity.INFO,
2355
+ explanation=eda_constants.R_SQUARED_GEO_INFO,
2356
+ finding_cause=eda_outcome.FindingCause.NONE,
2357
+ ),
2358
+ ]
2359
+
2360
+ return eda_outcome.EDAOutcome(
2361
+ check_type=eda_outcome.EDACheckType.VARIABLE_GEO_TIME_COLLINEARITY,
2362
+ findings=findings,
2363
+ analysis_artifacts=[artifact],
2364
+ )
2365
+
2366
+ def _calculate_population_corr(
2367
+ self, ds: xr.Dataset, *, explanation: str, check_name: str
2368
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
2369
+ """Calculates Spearman correlation between population and data variables.
2370
+
2371
+ Args:
2372
+ ds: An xr.Dataset containing the data variables for which to calculate the
2373
+ correlation with population. The Dataset is expected to have a 'geo'
2374
+ dimension.
2375
+ explanation: A string providing an explanation for the EDA finding.
2376
+ check_name: A string representing the name of the calling check function,
2377
+ used in error messages.
2378
+
2379
+ Returns:
2380
+ An EDAOutcome object containing a PopulationCorrelationArtifact. The
2381
+ artifact includes a Dataset with the Spearman correlation coefficients
2382
+ between each variable in `ds` and the geo population.
2383
+
2384
+ Raises:
2385
+ GeoLevelCheckOnNationalModelError: If the model is national or if
2386
+ `self.geo_population_da` is None.
2387
+ """
2388
+
2389
+ # self.geo_population_da can never be None if the model is geo-level. Adding
2390
+ # this check to make pytype happy.
2391
+ if self._is_national_data or self.geo_population_da is None:
2392
+ raise GeoLevelCheckOnNationalModelError(
2393
+ f'{check_name} is not supported for national models.'
2394
+ )
2395
+
2396
+ corr_ds: xr.Dataset = xr.apply_ufunc(
2397
+ _spearman_coeff,
2398
+ ds.mean(dim=constants.TIME),
2399
+ self.geo_population_da,
2400
+ input_core_dims=[[constants.GEO], [constants.GEO]],
2401
+ vectorize=True,
2402
+ output_dtypes=[float],
2403
+ )
2404
+
2405
+ artifact = eda_outcome.PopulationCorrelationArtifact(
2406
+ level=eda_outcome.AnalysisLevel.OVERALL,
2407
+ correlation_ds=corr_ds,
2408
+ )
2409
+
2410
+ return eda_outcome.EDAOutcome(
2411
+ check_type=eda_outcome.EDACheckType.POPULATION_CORRELATION,
2412
+ findings=[
2413
+ eda_outcome.EDAFinding(
2414
+ severity=eda_outcome.EDASeverity.INFO,
2415
+ explanation=explanation,
2416
+ finding_cause=eda_outcome.FindingCause.NONE,
2417
+ associated_artifact=artifact,
2108
2418
  )
2419
+ ],
2420
+ analysis_artifacts=[artifact],
2421
+ )
2422
+
2423
+ def check_population_corr_scaled_treatment_control(
2424
+ self,
2425
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
2426
+ """Checks Spearman correlation between population and treatments/controls.
2427
+
2428
+ Calculates correlation between population and time-averaged
2429
+ treatments and controls. High correlation for controls or non-media
2430
+ channels may indicate a need for population-scaling. High
2431
+ correlation for other media channels may indicate double-scaling.
2432
+
2433
+ Returns:
2434
+ An EDAOutcome object with findings and result values.
2435
+
2436
+ Raises:
2437
+ GeoLevelCheckOnNationalModelError: If the model is national or geo
2438
+ population data is missing.
2439
+ """
2440
+ return self._calculate_population_corr(
2441
+ ds=self.treatment_control_scaled_ds,
2442
+ explanation=eda_constants.POPULATION_CORRELATION_SCALED_TREATMENT_CONTROL_INFO,
2443
+ check_name='check_population_corr_scaled_treatment_control',
2444
+ )
2445
+
2446
+ def check_population_corr_raw_media(
2447
+ self,
2448
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PopulationCorrelationArtifact]:
2449
+ """Checks Spearman correlation between population and raw media executions.
2450
+
2451
+ Calculates correlation between population and time-averaged raw
2452
+ media executions (paid/organic impressions/reach). These are
2453
+ expected to have reasonably high correlation with population.
2454
+
2455
+ Returns:
2456
+ An EDAOutcome object with findings and result values.
2457
+
2458
+ Raises:
2459
+ GeoLevelCheckOnNationalModelError: If the model is national or geo
2460
+ population data is missing.
2461
+ """
2462
+ to_merge = (
2463
+ da
2464
+ for da in [
2465
+ self.media_raw_da,
2466
+ self.organic_media_raw_da,
2467
+ self.reach_raw_da,
2468
+ self.organic_reach_raw_da,
2469
+ ]
2470
+ if da is not None
2471
+ )
2472
+ # Handle the case where there are no media channels.
2473
+
2474
+ return self._calculate_population_corr(
2475
+ ds=xr.merge(to_merge, join='inner'),
2476
+ explanation=eda_constants.POPULATION_CORRELATION_RAW_MEDIA_INFO,
2477
+ check_name='check_population_corr_raw_media',
2478
+ )
2479
+
2480
+ def _check_prior_probability(
2481
+ self,
2482
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PriorProbabilityArtifact]:
2483
+ """Checks the prior probability of a negative baseline.
2484
+
2485
+ Returns:
2486
+ An EDAOutcome object containing a PriorProbabilityArtifact. The artifact
2487
+ includes a mock prior negative baseline probability and a DataArray of
2488
+ mock mean prior contributions per channel.
2489
+ """
2490
+ # TODO: b/476128592 - currently, this check is blocked. for the meantime,
2491
+ # we will return mock data for the report.
2492
+ channel_names = self._model_context.input_data.get_all_channels()
2493
+ mean_prior_contribution = np.random.uniform(
2494
+ size=len(channel_names), low=0.0, high=0.05
2495
+ )
2496
+ mean_prior_contribution_da = xr.DataArray(
2497
+ mean_prior_contribution,
2498
+ coords={constants.CHANNEL: channel_names},
2499
+ dims=[constants.CHANNEL],
2500
+ )
2501
+
2502
+ artifact = eda_outcome.PriorProbabilityArtifact(
2503
+ level=eda_outcome.AnalysisLevel.OVERALL,
2504
+ prior_negative_baseline_prob=0.123,
2505
+ mean_prior_contribution_da=mean_prior_contribution_da,
2506
+ )
2507
+
2508
+ findings = [
2509
+ eda_outcome.EDAFinding(
2510
+ severity=eda_outcome.EDASeverity.INFO,
2511
+ explanation=eda_constants.PRIOR_PROBABILITY_REPORT_INFO,
2512
+ finding_cause=eda_outcome.FindingCause.NONE,
2513
+ associated_artifact=artifact,
2109
2514
  )
2110
- return outcomes
2515
+ ]
2516
+
2517
+ return eda_outcome.EDAOutcome(
2518
+ check_type=eda_outcome.EDACheckType.PRIOR_PROBABILITY,
2519
+ findings=findings,
2520
+ analysis_artifacts=[artifact],
2521
+ )