google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -14,19 +14,111 @@
14
14
 
15
15
  """Meridian EDA Engine."""
16
16
 
17
+ from __future__ import annotations
18
+
17
19
  import dataclasses
18
20
  import functools
19
- from typing import Callable, Dict, Optional, TypeAlias
21
+ import typing
22
+ from typing import Optional, Protocol, Sequence
23
+
24
+ from meridian import backend
20
25
  from meridian import constants
21
- from meridian.model import model
22
26
  from meridian.model import transformers
27
+ from meridian.model.eda import constants as eda_constants
28
+ from meridian.model.eda import eda_outcome
29
+ from meridian.model.eda import eda_spec
23
30
  import numpy as np
24
- import tensorflow as tf
31
+ import pandas as pd
32
+ import statsmodels.api as sm
33
+ from statsmodels.stats import outliers_influence
25
34
  import xarray as xr
26
35
 
27
36
 
37
+ if typing.TYPE_CHECKING:
38
+ from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
39
+
40
+ __all__ = ['EDAEngine', 'GeoLevelCheckOnNationalModelError']
41
+
28
42
  _DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
29
- AggregationMap: TypeAlias = Dict[str, Callable[[xr.DataArray], np.ndarray]]
43
+ _CORRELATION_COL_NAME = eda_constants.CORRELATION
44
+ _STACK_VAR_COORD_NAME = eda_constants.VARIABLE
45
+ _CORR_VAR1 = eda_constants.VARIABLE_1
46
+ _CORR_VAR2 = eda_constants.VARIABLE_2
47
+ _CORRELATION_MATRIX_NAME = 'correlation_matrix'
48
+ _OVERALL_PAIRWISE_CORR_THRESHOLD = 0.999
49
+ _GEO_PAIRWISE_CORR_THRESHOLD = 0.999
50
+ _NATIONAL_PAIRWISE_CORR_THRESHOLD = 0.999
51
+ _EMPTY_DF_FOR_EXTREME_CORR_PAIRS = pd.DataFrame(
52
+ columns=[_CORR_VAR1, _CORR_VAR2, _CORRELATION_COL_NAME]
53
+ )
54
+ _Q1_THRESHOLD = 0.25
55
+ _Q3_THRESHOLD = 0.75
56
+ _IQR_MULTIPLIER = 1.5
57
+ _STD_WITH_OUTLIERS_VAR_NAME = 'std_with_outliers'
58
+ _STD_WITHOUT_OUTLIERS_VAR_NAME = 'std_without_outliers'
59
+ _STD_THRESHOLD = 1e-4
60
+ _OUTLIERS_COL_NAME = 'outliers'
61
+ _ABS_OUTLIERS_COL_NAME = 'abs_outliers'
62
+ _VIF_COL_NAME = 'VIF'
63
+
64
+
65
+ class _NamedEDACheckCallable(Protocol):
66
+ """A callable that returns an EDAOutcome and has a __name__ attribute."""
67
+
68
+ __name__: str
69
+
70
+ def __call__(self) -> eda_outcome.EDAOutcome:
71
+ ...
72
+
73
+
74
+ class GeoLevelCheckOnNationalModelError(Exception):
75
+ """Raised when a geo-level check is called on a national model."""
76
+
77
+ pass
78
+
79
+
80
+ @dataclasses.dataclass(frozen=True)
81
+ class _RFNames:
82
+ """Holds constant names for reach and frequency data arrays."""
83
+
84
+ reach: str
85
+ reach_scaled: str
86
+ frequency: str
87
+ impressions: str
88
+ impressions_scaled: str
89
+ national_reach: str
90
+ national_reach_scaled: str
91
+ national_frequency: str
92
+ national_impressions: str
93
+ national_impressions_scaled: str
94
+
95
+
96
+ _ORGANIC_RF_NAMES = _RFNames(
97
+ reach=constants.ORGANIC_REACH,
98
+ reach_scaled=constants.ORGANIC_REACH_SCALED,
99
+ frequency=constants.ORGANIC_FREQUENCY,
100
+ impressions=constants.ORGANIC_RF_IMPRESSIONS,
101
+ impressions_scaled=constants.ORGANIC_RF_IMPRESSIONS_SCALED,
102
+ national_reach=constants.NATIONAL_ORGANIC_REACH,
103
+ national_reach_scaled=constants.NATIONAL_ORGANIC_REACH_SCALED,
104
+ national_frequency=constants.NATIONAL_ORGANIC_FREQUENCY,
105
+ national_impressions=constants.NATIONAL_ORGANIC_RF_IMPRESSIONS,
106
+ national_impressions_scaled=constants.NATIONAL_ORGANIC_RF_IMPRESSIONS_SCALED,
107
+ )
108
+
109
+
110
+ _RF_NAMES = _RFNames(
111
+ reach=constants.REACH,
112
+ reach_scaled=constants.REACH_SCALED,
113
+ frequency=constants.FREQUENCY,
114
+ impressions=constants.RF_IMPRESSIONS,
115
+ impressions_scaled=constants.RF_IMPRESSIONS_SCALED,
116
+ national_reach=constants.NATIONAL_REACH,
117
+ national_reach_scaled=constants.NATIONAL_REACH_SCALED,
118
+ national_frequency=constants.NATIONAL_FREQUENCY,
119
+ national_impressions=constants.NATIONAL_RF_IMPRESSIONS,
120
+ national_impressions_scaled=constants.NATIONAL_RF_IMPRESSIONS_SCALED,
121
+ )
30
122
 
31
123
 
32
124
  @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -36,45 +128,220 @@ class ReachFrequencyData:
36
128
  Attributes:
37
129
  reach_raw_da: Raw reach data.
38
130
  reach_scaled_da: Scaled reach data.
39
- reach_raw_da_national: National raw reach data.
40
- reach_scaled_da_national: National scaled reach data.
131
+ national_reach_raw_da: National raw reach data.
132
+ national_reach_scaled_da: National scaled reach data.
41
133
  frequency_da: Frequency data.
42
- frequency_da_national: National frequency data.
134
+ national_frequency_da: National frequency data.
43
135
  rf_impressions_scaled_da: Scaled reach * frequency impressions data.
44
- rf_impressions_scaled_da_national: National scaled reach * frequency
136
+ national_rf_impressions_scaled_da: National scaled reach * frequency
45
137
  impressions data.
46
138
  rf_impressions_raw_da: Raw reach * frequency impressions data.
47
- rf_impressions_raw_da_national: National raw reach * frequency impressions
139
+ national_rf_impressions_raw_da: National raw reach * frequency impressions
48
140
  data.
49
141
  """
50
142
 
51
143
  reach_raw_da: xr.DataArray
52
144
  reach_scaled_da: xr.DataArray
53
- reach_raw_da_national: xr.DataArray
54
- reach_scaled_da_national: xr.DataArray
145
+ national_reach_raw_da: xr.DataArray
146
+ national_reach_scaled_da: xr.DataArray
55
147
  frequency_da: xr.DataArray
56
- frequency_da_national: xr.DataArray
148
+ national_frequency_da: xr.DataArray
57
149
  rf_impressions_scaled_da: xr.DataArray
58
- rf_impressions_scaled_da_national: xr.DataArray
150
+ national_rf_impressions_scaled_da: xr.DataArray
59
151
  rf_impressions_raw_da: xr.DataArray
60
- rf_impressions_raw_da_national: xr.DataArray
152
+ national_rf_impressions_raw_da: xr.DataArray
61
153
 
62
154
 
63
- @dataclasses.dataclass(frozen=True, kw_only=True)
64
- class AggregationConfig:
65
- """Configuration for custom aggregation functions.
155
+ def _data_array_like(
156
+ *, da: xr.DataArray, values: np.ndarray | backend.Tensor
157
+ ) -> xr.DataArray:
158
+ """Returns a DataArray from `values` with the same structure as `da`.
66
159
 
67
- Attributes:
68
- control_variables: A dictionary mapping control variable names to
69
- aggregation functions. Defaults to `np.sum` if a variable is not
70
- specified.
71
- non_media_treatments: A dictionary mapping non-media variable names to
72
- aggregation functions. Defaults to `np.sum` if a variable is not
73
- specified.
160
+ Args:
161
+ da: The DataArray whose structure (dimensions, coordinates, name, and attrs)
162
+ will be used for the new DataArray.
163
+ values: The numpy array or backend tensor to use as the values for the new
164
+ DataArray.
165
+
166
+ Returns:
167
+ A new DataArray with the provided `values` and the same structure as `da`.
74
168
  """
169
+ return xr.DataArray(
170
+ values,
171
+ coords=da.coords,
172
+ dims=da.dims,
173
+ name=da.name,
174
+ attrs=da.attrs,
175
+ )
176
+
177
+
178
+ def stack_variables(
179
+ ds: xr.Dataset, coord_name: str = _STACK_VAR_COORD_NAME
180
+ ) -> xr.DataArray:
181
+ """Stacks data variables of a Dataset into a single DataArray.
75
182
 
76
- control_variables: AggregationMap = dataclasses.field(default_factory=dict)
77
- non_media_treatments: AggregationMap = dataclasses.field(default_factory=dict)
183
+ This function is designed to work with Datasets that have 'time' or 'geo'
184
+ dimensions, which are preserved. Other dimensions are stacked into a new
185
+ dimension.
186
+
187
+ Args:
188
+ ds: The input xarray.Dataset to stack.
189
+ coord_name: The name of the new coordinate for the stacked dimension.
190
+
191
+ Returns:
192
+ An xarray.DataArray with the specified dimensions stacked.
193
+ """
194
+ dims = []
195
+ coords = []
196
+ sample_dims = []
197
+ # Dimensions have the same names as the coordinates.
198
+ for dim in ds.dims:
199
+ if dim in [constants.TIME, constants.GEO]:
200
+ sample_dims.append(dim)
201
+ continue
202
+ dims.append(dim)
203
+ coords.extend(ds.coords[dim].values.tolist())
204
+
205
+ da = ds.to_stacked_array(coord_name, sample_dims=sample_dims)
206
+ da = da.reset_index(dims, drop=True).assign_coords({coord_name: coords})
207
+ return da
208
+
209
+
210
+ def _compute_correlation_matrix(
211
+ input_da: xr.DataArray, dims: str | Sequence[str]
212
+ ) -> xr.DataArray:
213
+ """Computes the correlation matrix for variables in a DataArray.
214
+
215
+ Args:
216
+ input_da: An xr.DataArray containing variables for which to compute
217
+ correlations.
218
+ dims: Dimensions along which to compute correlations. Can only be TIME or
219
+ GEO.
220
+
221
+ Returns:
222
+ An xr.DataArray containing the correlation matrix.
223
+ """
224
+ # Create two versions for correlation
225
+ da1 = input_da.rename({_STACK_VAR_COORD_NAME: _CORR_VAR1})
226
+ da2 = input_da.rename({_STACK_VAR_COORD_NAME: _CORR_VAR2})
227
+
228
+ # Compute pairwise correlation across dims. Other dims are broadcasted.
229
+ corr_mat_da = xr.corr(da1, da2, dim=dims)
230
+ corr_mat_da.name = _CORRELATION_MATRIX_NAME
231
+ return corr_mat_da
232
+
233
+
234
+ def _get_upper_triangle_corr_mat(corr_mat_da: xr.DataArray) -> xr.DataArray:
235
+ """Gets the upper triangle of a correlation matrix.
236
+
237
+ Args:
238
+ corr_mat_da: An xr.DataArray containing the correlation matrix.
239
+
240
+ Returns:
241
+ An xr.DataArray containing only the elements in the upper triangle of the
242
+ correlation matrix, with other elements masked as NaN.
243
+ """
244
+ n_vars = corr_mat_da.sizes[_CORR_VAR1]
245
+ mask_np = np.triu(np.ones((n_vars, n_vars), dtype=bool), k=1)
246
+ mask = xr.DataArray(
247
+ mask_np,
248
+ dims=[_CORR_VAR1, _CORR_VAR2],
249
+ coords={
250
+ _CORR_VAR1: corr_mat_da[_CORR_VAR1],
251
+ _CORR_VAR2: corr_mat_da[_CORR_VAR2],
252
+ },
253
+ )
254
+ return corr_mat_da.where(mask)
255
+
256
+
257
+ def _find_extreme_corr_pairs(
258
+ extreme_corr_da: xr.DataArray, extreme_corr_threshold: float
259
+ ) -> pd.DataFrame:
260
+ """Finds extreme correlation pairs in a correlation matrix."""
261
+ corr_tri = _get_upper_triangle_corr_mat(extreme_corr_da)
262
+ extreme_corr_da = corr_tri.where(abs(corr_tri) > extreme_corr_threshold)
263
+
264
+ df = extreme_corr_da.to_dataframe(name=_CORRELATION_COL_NAME).dropna()
265
+ if df.empty:
266
+ return _EMPTY_DF_FOR_EXTREME_CORR_PAIRS.copy()
267
+ return df.sort_values(
268
+ by=_CORRELATION_COL_NAME, ascending=False, inplace=False
269
+ )
270
+
271
+
272
+ def _calculate_std(
273
+ input_da: xr.DataArray,
274
+ ) -> tuple[xr.Dataset, pd.DataFrame]:
275
+ """Helper function to compute std with and without outliers.
276
+
277
+ Args:
278
+ input_da: A DataArray for which to calculate the std.
279
+
280
+ Returns:
281
+ A tuple where the first element is a Dataset with two data variables:
282
+ 'std_incl_outliers' and 'std_excl_outliers'. The second element is a
283
+ DataFrame with columns for variables, geo (if applicable), time, and
284
+ outlier values.
285
+ """
286
+ std_with_outliers = input_da.std(dim=constants.TIME, ddof=1)
287
+
288
+ # TODO: Allow users to specify custom outlier definitions.
289
+ q1 = input_da.quantile(_Q1_THRESHOLD, dim=constants.TIME)
290
+ q3 = input_da.quantile(_Q3_THRESHOLD, dim=constants.TIME)
291
+ iqr = q3 - q1
292
+ lower_bound = q1 - _IQR_MULTIPLIER * iqr
293
+ upper_bound = q3 + _IQR_MULTIPLIER * iqr
294
+
295
+ da_no_outlier = input_da.where(
296
+ (input_da >= lower_bound) & (input_da <= upper_bound)
297
+ )
298
+ std_without_outliers = da_no_outlier.std(dim=constants.TIME, ddof=1)
299
+
300
+ std_ds = xr.Dataset({
301
+ _STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
302
+ _STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
303
+ })
304
+
305
+ outlier_da = input_da.where(
306
+ (input_da < lower_bound) | (input_da > upper_bound)
307
+ )
308
+
309
+ outlier_df = outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME).dropna()
310
+ outlier_df = outlier_df.assign(
311
+ **{_ABS_OUTLIERS_COL_NAME: np.abs(outlier_df[_OUTLIERS_COL_NAME])}
312
+ ).sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
313
+
314
+ return std_ds, outlier_df
315
+
316
+
317
+ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
318
+ """Helper function to compute variance inflation factor.
319
+
320
+ Args:
321
+ input_da: A DataArray for which to calculate the VIF over sample dimensions
322
+ (e.g. time and geo if applicable).
323
+ var_dim: The dimension name of the variable to compute VIF for.
324
+
325
+ Returns:
326
+ A DataArray containing the VIF for each variable in the variable dimension.
327
+ """
328
+ num_vars = input_da.sizes[var_dim]
329
+ np_data = input_da.values.reshape(-1, num_vars)
330
+ np_data_with_const = sm.add_constant(np_data, prepend=True)
331
+
332
+ # Compute VIF for each variable excluding const which is the first one in the
333
+ # 'variable' dimension.
334
+ vifs = [
335
+ outliers_influence.variance_inflation_factor(np_data_with_const, i)
336
+ for i in range(1, num_vars + 1)
337
+ ]
338
+
339
+ vif_da = xr.DataArray(
340
+ vifs,
341
+ coords={var_dim: input_da[var_dim].values},
342
+ dims=[var_dim],
343
+ )
344
+ return vif_da
78
345
 
79
346
 
80
347
  class EDAEngine:
@@ -83,10 +350,19 @@ class EDAEngine:
83
350
  def __init__(
84
351
  self,
85
352
  meridian: model.Meridian,
86
- agg_config: AggregationConfig = AggregationConfig(),
353
+ spec: eda_spec.EDASpec = eda_spec.EDASpec(),
87
354
  ):
88
355
  self._meridian = meridian
89
- self._agg_config = agg_config
356
+ self._spec = spec
357
+ self._agg_config = self._spec.aggregation_config
358
+
359
+ @property
360
+ def spec(self) -> eda_spec.EDASpec:
361
+ return self._spec
362
+
363
+ @property
364
+ def _is_national_data(self) -> bool:
365
+ return self._meridian.is_national
90
366
 
91
367
  @functools.cached_property
92
368
  def controls_scaled_da(self) -> xr.DataArray | None:
@@ -96,33 +372,39 @@ class EDAEngine:
96
372
  da=self._meridian.input_data.controls,
97
373
  values=self._meridian.controls_scaled,
98
374
  )
375
+ controls_scaled_da.name = constants.CONTROLS_SCALED
99
376
  return controls_scaled_da
100
377
 
101
378
  @functools.cached_property
102
- def controls_scaled_da_national(self) -> xr.DataArray | None:
103
- """Returns the national controls data array."""
379
+ def national_controls_scaled_da(self) -> xr.DataArray | None:
380
+ """Returns the national scaled controls data array."""
104
381
  if self._meridian.input_data.controls is None:
105
382
  return None
106
- if self._meridian.is_national:
383
+ if self._is_national_data:
107
384
  if self.controls_scaled_da is None:
108
385
  # This case should be impossible given the check above.
109
386
  raise RuntimeError(
110
387
  'controls_scaled_da is None when controls is not None.'
111
388
  )
112
- return self.controls_scaled_da.squeeze(constants.GEO)
389
+ national_da = self.controls_scaled_da.squeeze(constants.GEO, drop=True)
390
+ national_da.name = constants.NATIONAL_CONTROLS_SCALED
113
391
  else:
114
- return self._aggregate_and_scale_geo_da(
392
+ national_da = self._aggregate_and_scale_geo_da(
115
393
  self._meridian.input_data.controls,
394
+ constants.NATIONAL_CONTROLS_SCALED,
116
395
  transformers.CenteringAndScalingTransformer,
117
396
  constants.CONTROL_VARIABLE,
118
397
  self._agg_config.control_variables,
119
398
  )
399
+ return national_da
120
400
 
121
401
  @functools.cached_property
122
402
  def media_raw_da(self) -> xr.DataArray | None:
123
403
  if self._meridian.input_data.media is None:
124
404
  return None
125
- return self._truncate_media_time(self._meridian.input_data.media)
405
+ raw_media_da = self._truncate_media_time(self._meridian.input_data.media)
406
+ raw_media_da.name = constants.MEDIA
407
+ return raw_media_da
126
408
 
127
409
  @functools.cached_property
128
410
  def media_scaled_da(self) -> xr.DataArray | None:
@@ -132,68 +414,84 @@ class EDAEngine:
132
414
  da=self._meridian.input_data.media,
133
415
  values=self._meridian.media_tensors.media_scaled,
134
416
  )
417
+ media_scaled_da.name = constants.MEDIA_SCALED
135
418
  return self._truncate_media_time(media_scaled_da)
136
419
 
137
420
  @functools.cached_property
138
421
  def media_spend_da(self) -> xr.DataArray | None:
139
- if self._meridian.input_data.media_spend is None:
140
- return None
141
- media_spend_da = _data_array_like(
142
- da=self._meridian.input_data.media_spend,
143
- values=self._meridian.media_tensors.media_spend,
144
- )
422
+ """Returns media spend.
423
+
424
+ If the input spend is aggregated, it is allocated across geo and time
425
+ proportionally to media units.
426
+ """
145
427
  # No need to truncate the media time for media spend.
146
- return media_spend_da
428
+ da = self._meridian.input_data.allocated_media_spend
429
+ if da is None:
430
+ return None
431
+ da = da.copy()
432
+ da.name = constants.MEDIA_SPEND
433
+ return da
147
434
 
148
435
  @functools.cached_property
149
- def media_spend_da_national(self) -> xr.DataArray | None:
436
+ def national_media_spend_da(self) -> xr.DataArray | None:
150
437
  """Returns the national media spend data array."""
151
- if self._meridian.input_data.media_spend is None:
438
+ media_spend = self.media_spend_da
439
+ if media_spend is None:
152
440
  return None
153
- if self._meridian.is_national:
154
- if self.media_spend_da is None:
155
- # This case should be impossible given the check above.
156
- raise RuntimeError(
157
- 'media_spend_da is None when media_spend is not None.'
158
- )
159
- return self.media_spend_da.squeeze(constants.GEO)
441
+ if self._is_national_data:
442
+ national_da = media_spend.squeeze(constants.GEO, drop=True)
443
+ national_da.name = constants.NATIONAL_MEDIA_SPEND
160
444
  else:
161
- return self._aggregate_and_scale_geo_da(
162
- self._meridian.input_data.media_spend,
445
+ national_da = self._aggregate_and_scale_geo_da(
446
+ self._meridian.input_data.allocated_media_spend,
447
+ constants.NATIONAL_MEDIA_SPEND,
163
448
  None,
164
449
  )
450
+ return national_da
165
451
 
166
452
  @functools.cached_property
167
- def media_raw_da_national(self) -> xr.DataArray | None:
453
+ def national_media_raw_da(self) -> xr.DataArray | None:
454
+ """Returns the national raw media data array."""
168
455
  if self.media_raw_da is None:
169
456
  return None
170
- if self._meridian.is_national:
171
- return self.media_raw_da.squeeze(constants.GEO)
457
+ if self._is_national_data:
458
+ national_da = self.media_raw_da.squeeze(constants.GEO, drop=True)
459
+ national_da.name = constants.NATIONAL_MEDIA
172
460
  else:
173
461
  # Note that media is summable by assumption.
174
- return self._aggregate_and_scale_geo_da(
462
+ national_da = self._aggregate_and_scale_geo_da(
175
463
  self.media_raw_da,
464
+ constants.NATIONAL_MEDIA,
176
465
  None,
177
466
  )
467
+ return national_da
178
468
 
179
469
  @functools.cached_property
180
- def media_scaled_da_national(self) -> xr.DataArray | None:
470
+ def national_media_scaled_da(self) -> xr.DataArray | None:
471
+ """Returns the national scaled media data array."""
181
472
  if self.media_scaled_da is None:
182
473
  return None
183
- if self._meridian.is_national:
184
- return self.media_scaled_da.squeeze(constants.GEO)
474
+ if self._is_national_data:
475
+ national_da = self.media_scaled_da.squeeze(constants.GEO, drop=True)
476
+ national_da.name = constants.NATIONAL_MEDIA_SCALED
185
477
  else:
186
478
  # Note that media is summable by assumption.
187
- return self._aggregate_and_scale_geo_da(
479
+ national_da = self._aggregate_and_scale_geo_da(
188
480
  self.media_raw_da,
481
+ constants.NATIONAL_MEDIA_SCALED,
189
482
  transformers.MediaTransformer,
190
483
  )
484
+ return national_da
191
485
 
192
486
  @functools.cached_property
193
487
  def organic_media_raw_da(self) -> xr.DataArray | None:
194
488
  if self._meridian.input_data.organic_media is None:
195
489
  return None
196
- return self._truncate_media_time(self._meridian.input_data.organic_media)
490
+ raw_organic_media_da = self._truncate_media_time(
491
+ self._meridian.input_data.organic_media
492
+ )
493
+ raw_organic_media_da.name = constants.ORGANIC_MEDIA
494
+ return raw_organic_media_da
197
495
 
198
496
  @functools.cached_property
199
497
  def organic_media_scaled_da(self) -> xr.DataArray | None:
@@ -203,30 +501,42 @@ class EDAEngine:
203
501
  da=self._meridian.input_data.organic_media,
204
502
  values=self._meridian.organic_media_tensors.organic_media_scaled,
205
503
  )
504
+ organic_media_scaled_da.name = constants.ORGANIC_MEDIA_SCALED
206
505
  return self._truncate_media_time(organic_media_scaled_da)
207
506
 
208
507
  @functools.cached_property
209
- def organic_media_raw_da_national(self) -> xr.DataArray | None:
508
+ def national_organic_media_raw_da(self) -> xr.DataArray | None:
509
+ """Returns the national raw organic media data array."""
210
510
  if self.organic_media_raw_da is None:
211
511
  return None
212
- if self._meridian.is_national:
213
- return self.organic_media_raw_da.squeeze(constants.GEO)
512
+ if self._is_national_data:
513
+ national_da = self.organic_media_raw_da.squeeze(constants.GEO, drop=True)
514
+ national_da.name = constants.NATIONAL_ORGANIC_MEDIA
214
515
  else:
215
516
  # Note that organic media is summable by assumption.
216
- return self._aggregate_and_scale_geo_da(self.organic_media_raw_da, None)
517
+ national_da = self._aggregate_and_scale_geo_da(
518
+ self.organic_media_raw_da, constants.NATIONAL_ORGANIC_MEDIA, None
519
+ )
520
+ return national_da
217
521
 
218
522
  @functools.cached_property
219
- def organic_media_scaled_da_national(self) -> xr.DataArray | None:
523
+ def national_organic_media_scaled_da(self) -> xr.DataArray | None:
524
+ """Returns the national scaled organic media data array."""
220
525
  if self.organic_media_scaled_da is None:
221
526
  return None
222
- if self._meridian.is_national:
223
- return self.organic_media_scaled_da.squeeze(constants.GEO)
527
+ if self._is_national_data:
528
+ national_da = self.organic_media_scaled_da.squeeze(
529
+ constants.GEO, drop=True
530
+ )
531
+ national_da.name = constants.NATIONAL_ORGANIC_MEDIA_SCALED
224
532
  else:
225
533
  # Note that organic media is summable by assumption.
226
- return self._aggregate_and_scale_geo_da(
534
+ national_da = self._aggregate_and_scale_geo_da(
227
535
  self.organic_media_raw_da,
536
+ constants.NATIONAL_ORGANIC_MEDIA_SCALED,
228
537
  transformers.MediaTransformer,
229
538
  )
539
+ return national_da
230
540
 
231
541
  @functools.cached_property
232
542
  def non_media_scaled_da(self) -> xr.DataArray | None:
@@ -236,51 +546,62 @@ class EDAEngine:
236
546
  da=self._meridian.input_data.non_media_treatments,
237
547
  values=self._meridian.non_media_treatments_normalized,
238
548
  )
549
+ non_media_scaled_da.name = constants.NON_MEDIA_TREATMENTS_SCALED
239
550
  return non_media_scaled_da
240
551
 
241
552
  @functools.cached_property
242
- def non_media_scaled_da_national(self) -> xr.DataArray | None:
243
- """Returns the national non-media treatment data array."""
553
+ def national_non_media_scaled_da(self) -> xr.DataArray | None:
554
+ """Returns the national scaled non-media treatment data array."""
244
555
  if self._meridian.input_data.non_media_treatments is None:
245
556
  return None
246
- if self._meridian.is_national:
557
+ if self._is_national_data:
247
558
  if self.non_media_scaled_da is None:
248
559
  # This case should be impossible given the check above.
249
560
  raise RuntimeError(
250
561
  'non_media_scaled_da is None when non_media_treatments is not None.'
251
562
  )
252
- return self.non_media_scaled_da.squeeze(constants.GEO)
563
+ national_da = self.non_media_scaled_da.squeeze(constants.GEO, drop=True)
564
+ national_da.name = constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED
253
565
  else:
254
- return self._aggregate_and_scale_geo_da(
566
+ national_da = self._aggregate_and_scale_geo_da(
255
567
  self._meridian.input_data.non_media_treatments,
568
+ constants.NATIONAL_NON_MEDIA_TREATMENTS_SCALED,
256
569
  transformers.CenteringAndScalingTransformer,
257
570
  constants.NON_MEDIA_CHANNEL,
258
571
  self._agg_config.non_media_treatments,
259
572
  )
573
+ return national_da
260
574
 
261
575
  @functools.cached_property
262
576
  def rf_spend_da(self) -> xr.DataArray | None:
263
- if self._meridian.input_data.rf_spend is None:
577
+ """Returns RF spend.
578
+
579
+ If the input spend is aggregated, it is allocated across geo and time
580
+ proportionally to RF impressions (reach * frequency).
581
+ """
582
+ da = self._meridian.input_data.allocated_rf_spend
583
+ if da is None:
264
584
  return None
265
- rf_spend_da = _data_array_like(
266
- da=self._meridian.input_data.rf_spend,
267
- values=self._meridian.rf_tensors.rf_spend,
268
- )
269
- return rf_spend_da
585
+ da = da.copy()
586
+ da.name = constants.RF_SPEND
587
+ return da
270
588
 
271
589
  @functools.cached_property
272
- def rf_spend_da_national(self) -> xr.DataArray | None:
273
- if self._meridian.input_data.rf_spend is None:
590
+ def national_rf_spend_da(self) -> xr.DataArray | None:
591
+ """Returns the national RF spend data array."""
592
+ rf_spend = self.rf_spend_da
593
+ if rf_spend is None:
274
594
  return None
275
- if self._meridian.is_national:
276
- if self.rf_spend_da is None:
277
- # This case should be impossible given the check above.
278
- raise RuntimeError('rf_spend_da is None when rf_spend is not None.')
279
- return self.rf_spend_da.squeeze(constants.GEO)
595
+ if self._is_national_data:
596
+ national_da = rf_spend.squeeze(constants.GEO, drop=True)
597
+ national_da.name = constants.NATIONAL_RF_SPEND
280
598
  else:
281
- return self._aggregate_and_scale_geo_da(
282
- self._meridian.input_data.rf_spend, None
599
+ national_da = self._aggregate_and_scale_geo_da(
600
+ self._meridian.input_data.allocated_rf_spend,
601
+ constants.NATIONAL_RF_SPEND,
602
+ None,
283
603
  )
604
+ return national_da
284
605
 
285
606
  @functools.cached_property
286
607
  def _rf_data(self) -> ReachFrequencyData | None:
@@ -302,31 +623,34 @@ class EDAEngine:
302
623
  def reach_scaled_da(self) -> xr.DataArray | None:
303
624
  if self._rf_data is None:
304
625
  return None
305
- return self._rf_data.reach_scaled_da
626
+ return self._rf_data.reach_scaled_da # pytype: disable=attribute-error
306
627
 
307
628
  @property
308
- def reach_raw_da_national(self) -> xr.DataArray | None:
629
+ def national_reach_raw_da(self) -> xr.DataArray | None:
630
+ """Returns the national raw reach data array."""
309
631
  if self._rf_data is None:
310
632
  return None
311
- return self._rf_data.reach_raw_da_national
633
+ return self._rf_data.national_reach_raw_da
312
634
 
313
635
  @property
314
- def reach_scaled_da_national(self) -> xr.DataArray | None:
636
+ def national_reach_scaled_da(self) -> xr.DataArray | None:
637
+ """Returns the national scaled reach data array."""
315
638
  if self._rf_data is None:
316
639
  return None
317
- return self._rf_data.reach_scaled_da_national
640
+ return self._rf_data.national_reach_scaled_da # pytype: disable=attribute-error
318
641
 
319
642
  @property
320
643
  def frequency_da(self) -> xr.DataArray | None:
321
644
  if self._rf_data is None:
322
645
  return None
323
- return self._rf_data.frequency_da
646
+ return self._rf_data.frequency_da # pytype: disable=attribute-error
324
647
 
325
648
  @property
326
- def frequency_da_national(self) -> xr.DataArray | None:
649
+ def national_frequency_da(self) -> xr.DataArray | None:
650
+ """Returns the national frequency data array."""
327
651
  if self._rf_data is None:
328
652
  return None
329
- return self._rf_data.frequency_da_national
653
+ return self._rf_data.national_frequency_da # pytype: disable=attribute-error
330
654
 
331
655
  @property
332
656
  def rf_impressions_raw_da(self) -> xr.DataArray | None:
@@ -335,10 +659,11 @@ class EDAEngine:
335
659
  return self._rf_data.rf_impressions_raw_da
336
660
 
337
661
  @property
338
- def rf_impressions_raw_da_national(self) -> xr.DataArray | None:
662
+ def national_rf_impressions_raw_da(self) -> xr.DataArray | None:
663
+ """Returns the national raw RF impressions data array."""
339
664
  if self._rf_data is None:
340
665
  return None
341
- return self._rf_data.rf_impressions_raw_da_national
666
+ return self._rf_data.national_rf_impressions_raw_da
342
667
 
343
668
  @property
344
669
  def rf_impressions_scaled_da(self) -> xr.DataArray | None:
@@ -347,10 +672,11 @@ class EDAEngine:
347
672
  return self._rf_data.rf_impressions_scaled_da
348
673
 
349
674
  @property
350
- def rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
675
+ def national_rf_impressions_scaled_da(self) -> xr.DataArray | None:
676
+ """Returns the national scaled RF impressions data array."""
351
677
  if self._rf_data is None:
352
678
  return None
353
- return self._rf_data.rf_impressions_scaled_da_national
679
+ return self._rf_data.national_rf_impressions_scaled_da
354
680
 
355
681
  @functools.cached_property
356
682
  def _organic_rf_data(self) -> ReachFrequencyData | None:
@@ -372,19 +698,21 @@ class EDAEngine:
372
698
  def organic_reach_scaled_da(self) -> xr.DataArray | None:
373
699
  if self._organic_rf_data is None:
374
700
  return None
375
- return self._organic_rf_data.reach_scaled_da
701
+ return self._organic_rf_data.reach_scaled_da # pytype: disable=attribute-error
376
702
 
377
703
  @property
378
- def organic_reach_raw_da_national(self) -> xr.DataArray | None:
704
+ def national_organic_reach_raw_da(self) -> xr.DataArray | None:
705
+ """Returns the national raw organic reach data array."""
379
706
  if self._organic_rf_data is None:
380
707
  return None
381
- return self._organic_rf_data.reach_raw_da_national
708
+ return self._organic_rf_data.national_reach_raw_da
382
709
 
383
710
  @property
384
- def organic_reach_scaled_da_national(self) -> xr.DataArray | None:
711
+ def national_organic_reach_scaled_da(self) -> xr.DataArray | None:
712
+ """Returns the national scaled organic reach data array."""
385
713
  if self._organic_rf_data is None:
386
714
  return None
387
- return self._organic_rf_data.reach_scaled_da_national
715
+ return self._organic_rf_data.national_reach_scaled_da # pytype: disable=attribute-error
388
716
 
389
717
  @property
390
718
  def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
@@ -393,22 +721,24 @@ class EDAEngine:
393
721
  return self._organic_rf_data.rf_impressions_scaled_da
394
722
 
395
723
  @property
396
- def organic_rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
724
+ def national_organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
725
+ """Returns the national scaled organic RF impressions data array."""
397
726
  if self._organic_rf_data is None:
398
727
  return None
399
- return self._organic_rf_data.rf_impressions_scaled_da_national
728
+ return self._organic_rf_data.national_rf_impressions_scaled_da
400
729
 
401
730
  @property
402
731
  def organic_frequency_da(self) -> xr.DataArray | None:
403
732
  if self._organic_rf_data is None:
404
733
  return None
405
- return self._organic_rf_data.frequency_da
734
+ return self._organic_rf_data.frequency_da # pytype: disable=attribute-error
406
735
 
407
736
  @property
408
- def organic_frequency_da_national(self) -> xr.DataArray | None:
737
+ def national_organic_frequency_da(self) -> xr.DataArray | None:
738
+ """Returns the national organic frequency data array."""
409
739
  if self._organic_rf_data is None:
410
740
  return None
411
- return self._organic_rf_data.frequency_da_national
741
+ return self._organic_rf_data.national_frequency_da # pytype: disable=attribute-error
412
742
 
413
743
  @property
414
744
  def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
@@ -417,14 +747,15 @@ class EDAEngine:
417
747
  return self._organic_rf_data.rf_impressions_raw_da
418
748
 
419
749
  @property
420
- def organic_rf_impressions_raw_da_national(self) -> xr.DataArray | None:
750
+ def national_organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
751
+ """Returns the national raw organic RF impressions data array."""
421
752
  if self._organic_rf_data is None:
422
753
  return None
423
- return self._organic_rf_data.rf_impressions_raw_da_national
754
+ return self._organic_rf_data.national_rf_impressions_raw_da
424
755
 
425
756
  @functools.cached_property
426
757
  def geo_population_da(self) -> xr.DataArray | None:
427
- if self._meridian.is_national:
758
+ if self._is_national_data:
428
759
  return None
429
760
  return xr.DataArray(
430
761
  self._meridian.population,
@@ -435,21 +766,38 @@ class EDAEngine:
435
766
 
436
767
  @functools.cached_property
437
768
  def kpi_scaled_da(self) -> xr.DataArray:
438
- return _data_array_like(
769
+ scaled_kpi_da = _data_array_like(
439
770
  da=self._meridian.input_data.kpi,
440
771
  values=self._meridian.kpi_scaled,
441
772
  )
773
+ scaled_kpi_da.name = constants.KPI_SCALED
774
+ return scaled_kpi_da
442
775
 
443
776
  @functools.cached_property
444
- def kpi_scaled_da_national(self) -> xr.DataArray:
445
- if self._meridian.is_national:
446
- return self.kpi_scaled_da.squeeze(constants.GEO)
777
+ def _overall_scaled_kpi_invariability_artifact(
778
+ self,
779
+ ) -> eda_outcome.KpiInvariabilityArtifact:
780
+ """Returns an artifact of overall scaled KPI invariability."""
781
+ return eda_outcome.KpiInvariabilityArtifact(
782
+ level=eda_outcome.AnalysisLevel.OVERALL,
783
+ kpi_da=self.kpi_scaled_da,
784
+ kpi_stdev=self.kpi_scaled_da.std(ddof=1),
785
+ )
786
+
787
+ @functools.cached_property
788
+ def national_kpi_scaled_da(self) -> xr.DataArray:
789
+ """Returns the national scaled KPI data array."""
790
+ if self._is_national_data:
791
+ national_da = self.kpi_scaled_da.squeeze(constants.GEO, drop=True)
792
+ national_da.name = constants.NATIONAL_KPI_SCALED
447
793
  else:
448
794
  # Note that kpi is summable by assumption.
449
- return self._aggregate_and_scale_geo_da(
795
+ national_da = self._aggregate_and_scale_geo_da(
450
796
  self._meridian.input_data.kpi,
797
+ constants.NATIONAL_KPI_SCALED,
451
798
  transformers.CenteringAndScalingTransformer,
452
799
  )
800
+ return national_da
453
801
 
454
802
  @functools.cached_property
455
803
  def treatment_control_scaled_ds(self) -> xr.Dataset:
@@ -473,7 +821,46 @@ class EDAEngine:
473
821
  return xr.merge(to_merge, join='inner')
474
822
 
475
823
  @functools.cached_property
476
- def treatment_control_scaled_ds_national(self) -> xr.Dataset:
824
+ def all_spend_ds(self) -> xr.Dataset:
825
+ """Returns a Dataset containing all spend data.
826
+
827
+ This includes media spend and rf spend.
828
+ """
829
+ to_merge = [
830
+ da
831
+ for da in [
832
+ self.media_spend_da,
833
+ self.rf_spend_da,
834
+ ]
835
+ if da is not None
836
+ ]
837
+ return xr.merge(to_merge, join='inner')
838
+
839
+ @functools.cached_property
840
+ def national_all_spend_ds(self) -> xr.Dataset:
841
+ """Returns a Dataset containing all national spend data.
842
+
843
+ This includes media spend and rf spend.
844
+ """
845
+ to_merge = [
846
+ da
847
+ for da in [
848
+ self.national_media_spend_da,
849
+ self.national_rf_spend_da,
850
+ ]
851
+ if da is not None
852
+ ]
853
+ return xr.merge(to_merge, join='inner')
854
+
855
+ @functools.cached_property
856
+ def _stacked_treatment_control_scaled_da(self) -> xr.DataArray:
857
+ """Returns a stacked DataArray of treatment_control_scaled_ds."""
858
+ da = stack_variables(self.treatment_control_scaled_ds)
859
+ da.name = constants.TREATMENT_CONTROL_SCALED
860
+ return da
861
+
862
+ @functools.cached_property
863
+ def national_treatment_control_scaled_ds(self) -> xr.Dataset:
477
864
  """Returns a Dataset containing all scaled treatments and controls.
478
865
 
479
866
  This includes media, RF impressions, organic media, organic RF impressions,
@@ -482,17 +869,148 @@ class EDAEngine:
482
869
  to_merge_national = [
483
870
  da
484
871
  for da in [
485
- self.media_scaled_da_national,
486
- self.rf_impressions_scaled_da_national,
487
- self.organic_media_scaled_da_national,
488
- self.organic_rf_impressions_scaled_da_national,
489
- self.controls_scaled_da_national,
490
- self.non_media_scaled_da_national,
872
+ self.national_media_scaled_da,
873
+ self.national_rf_impressions_scaled_da,
874
+ self.national_organic_media_scaled_da,
875
+ self.national_organic_rf_impressions_scaled_da,
876
+ self.national_controls_scaled_da,
877
+ self.national_non_media_scaled_da,
491
878
  ]
492
879
  if da is not None
493
880
  ]
494
881
  return xr.merge(to_merge_national, join='inner')
495
882
 
883
+ @functools.cached_property
884
+ def _stacked_national_treatment_control_scaled_da(self) -> xr.DataArray:
885
+ """Returns a stacked DataArray of national_treatment_control_scaled_ds."""
886
+ da = stack_variables(self.national_treatment_control_scaled_ds)
887
+ da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
888
+ return da
889
+
890
+ @functools.cached_property
891
+ def all_reach_scaled_da(self) -> xr.DataArray | None:
892
+ """Returns a DataArray containing all scaled reach data.
893
+
894
+ This includes both paid and organic reach, concatenated along the RF_CHANNEL
895
+ dimension.
896
+
897
+ Returns:
898
+ A DataArray containing all scaled reach data, or None if no RF or organic
899
+ RF channels are present.
900
+ """
901
+ reach_das = []
902
+ if self.reach_scaled_da is not None:
903
+ reach_das.append(self.reach_scaled_da)
904
+ if self.organic_reach_scaled_da is not None:
905
+ reach_das.append(
906
+ self.organic_reach_scaled_da.rename(
907
+ {constants.ORGANIC_RF_CHANNEL: constants.RF_CHANNEL}
908
+ )
909
+ )
910
+ if not reach_das:
911
+ return None
912
+ da = xr.concat(reach_das, dim=constants.RF_CHANNEL)
913
+ da.name = constants.ALL_REACH_SCALED
914
+ return da
915
+
916
+ @functools.cached_property
917
+ def all_freq_da(self) -> xr.DataArray | None:
918
+ """Returns a DataArray containing all frequency data.
919
+
920
+ This includes both paid and organic frequency, concatenated along the
921
+ RF_CHANNEL dimension.
922
+
923
+ Returns:
924
+ A DataArray containing all frequency data, or None if no RF or organic
925
+ RF channels are present.
926
+ """
927
+ freq_das = []
928
+ if self.frequency_da is not None:
929
+ freq_das.append(self.frequency_da)
930
+ if self.organic_frequency_da is not None:
931
+ freq_das.append(
932
+ self.organic_frequency_da.rename(
933
+ {constants.ORGANIC_RF_CHANNEL: constants.RF_CHANNEL}
934
+ )
935
+ )
936
+ if not freq_das:
937
+ return None
938
+ da = xr.concat(freq_das, dim=constants.RF_CHANNEL)
939
+ da.name = constants.ALL_FREQUENCY
940
+ return da
941
+
942
+ @functools.cached_property
943
+ def national_all_reach_scaled_da(self) -> xr.DataArray | None:
944
+ """Returns a DataArray containing all national-level scaled reach data.
945
+
946
+ This includes both paid and organic reach, concatenated along the
947
+ RF_CHANNEL dimension.
948
+
949
+ Returns:
950
+ A DataArray containing all national-level scaled reach data, or None if
951
+ no RF or organic RF channels are present.
952
+ """
953
+ national_reach_das = []
954
+ if self.national_reach_scaled_da is not None:
955
+ national_reach_das.append(self.national_reach_scaled_da)
956
+ national_organic_reach_scaled_da = self.national_organic_reach_scaled_da
957
+ if national_organic_reach_scaled_da is not None:
958
+ national_reach_das.append(
959
+ national_organic_reach_scaled_da.rename(
960
+ {constants.ORGANIC_RF_CHANNEL: constants.RF_CHANNEL}
961
+ )
962
+ )
963
+ if not national_reach_das:
964
+ return None
965
+ da = xr.concat(national_reach_das, dim=constants.RF_CHANNEL)
966
+ da.name = constants.NATIONAL_ALL_REACH_SCALED
967
+ return da
968
+
969
+ @functools.cached_property
970
+ def national_all_freq_da(self) -> xr.DataArray | None:
971
+ """Returns a DataArray containing all national-level frequency data.
972
+
973
+ This includes both paid and organic frequency, concatenated along the
974
+ RF_CHANNEL dimension.
975
+
976
+ Returns:
977
+ A DataArray containing all national-level frequency data, or None if no
978
+ RF or organic RF channels are present.
979
+ """
980
+ national_freq_das = []
981
+ if self.national_frequency_da is not None:
982
+ national_freq_das.append(self.national_frequency_da)
983
+ national_organic_frequency_da = self.national_organic_frequency_da
984
+ if national_organic_frequency_da is not None:
985
+ national_freq_das.append(
986
+ national_organic_frequency_da.rename(
987
+ {constants.ORGANIC_RF_CHANNEL: constants.RF_CHANNEL}
988
+ )
989
+ )
990
+ if not national_freq_das:
991
+ return None
992
+ da = xr.concat(national_freq_das, dim=constants.RF_CHANNEL)
993
+ da.name = constants.NATIONAL_ALL_FREQUENCY
994
+ return da
995
+
996
+ @property
997
+ def _critical_checks(
998
+ self,
999
+ ) -> list[tuple[_NamedEDACheckCallable, eda_outcome.EDACheckType]]:
1000
+ """Returns a list of critical checks to be performed."""
1001
+ checks = [
1002
+ (
1003
+ self.check_overall_kpi_invariability,
1004
+ eda_outcome.EDACheckType.KPI_INVARIABILITY,
1005
+ ),
1006
+ (self.check_vif, eda_outcome.EDACheckType.MULTICOLLINEARITY),
1007
+ (
1008
+ self.check_pairwise_corr,
1009
+ eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
1010
+ ),
1011
+ ]
1012
+ return checks
1013
+
496
1014
  def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray:
497
1015
  """Truncates the first `start` elements of the media time of a variable."""
498
1016
  # This should not happen. If it does, it means this function is mis-used.
@@ -510,14 +1028,16 @@ class EDAEngine:
510
1028
  self,
511
1029
  xarray: xr.DataArray,
512
1030
  transformer_class: Optional[type[transformers.TensorTransformer]],
513
- population: tf.Tensor = tf.constant([1.0], dtype=tf.float32),
514
- ):
1031
+ population: Optional[backend.Tensor] = None,
1032
+ ) -> xr.DataArray:
515
1033
  """Scales xarray values with a TensorTransformer."""
516
1034
  da = xarray.copy()
517
1035
 
518
1036
  if transformer_class is None:
519
1037
  return da
520
- elif transformer_class is transformers.CenteringAndScalingTransformer:
1038
+ if population is None:
1039
+ population = backend.ones([1], dtype=backend.float32)
1040
+ if transformer_class is transformers.CenteringAndScalingTransformer:
521
1041
  xarray_transformer = transformers.CenteringAndScalingTransformer(
522
1042
  tensor=da.values, population=population
523
1043
  )
@@ -537,15 +1057,15 @@ class EDAEngine:
537
1057
 
538
1058
  def _aggregate_variables(
539
1059
  self,
540
- da_geo: xr.DataArray,
1060
+ geo_da: xr.DataArray,
541
1061
  channel_dim: str,
542
- da_var_agg_map: AggregationMap,
1062
+ da_var_agg_map: eda_spec.AggregationMap,
543
1063
  keepdims: bool = True,
544
1064
  ) -> xr.DataArray:
545
1065
  """Aggregates variables within a DataArray based on user-defined functions.
546
1066
 
547
1067
  Args:
548
- da_geo: The geo-level DataArray containing multiple variables along
1068
+ geo_da: The geo-level DataArray containing multiple variables along
549
1069
  channel_dim.
550
1070
  channel_dim: The name of the dimension coordinate to aggregate over (e.g.,
551
1071
  constants.CONTROL_VARIABLE).
@@ -558,8 +1078,8 @@ class EDAEngine:
558
1078
  aggregated according to the da_var_agg_map.
559
1079
  """
560
1080
  agg_results = []
561
- for var_name in da_geo[channel_dim].values:
562
- var_data = da_geo.sel({channel_dim: var_name})
1081
+ for var_name in geo_da[channel_dim].values:
1082
+ var_data = geo_da.sel({channel_dim: var_name})
563
1083
  agg_func = da_var_agg_map.get(var_name, _DEFAULT_DA_VAR_AGG_FUNCTION)
564
1084
  # Apply the aggregation function over the GEO dimension
565
1085
  aggregated_data = var_data.reduce(
@@ -572,15 +1092,17 @@ class EDAEngine:
572
1092
 
573
1093
  def _aggregate_and_scale_geo_da(
574
1094
  self,
575
- da_geo: xr.DataArray,
1095
+ geo_da: xr.DataArray,
1096
+ national_da_name: str,
576
1097
  transformer_class: Optional[type[transformers.TensorTransformer]],
577
1098
  channel_dim: Optional[str] = None,
578
- da_var_agg_map: Optional[AggregationMap] = None,
1099
+ da_var_agg_map: Optional[eda_spec.AggregationMap] = None,
579
1100
  ) -> xr.DataArray:
580
1101
  """Aggregate geo-level xr.DataArray to national level and then scale values.
581
1102
 
582
1103
  Args:
583
- da_geo: The geo-level DataArray to convert.
1104
+ geo_da: The geo-level DataArray to convert.
1105
+ national_da_name: The name for the returned national DataArray.
584
1106
  transformer_class: The TensorTransformer class to apply after summing to
585
1107
  national level. Must be None, CenteringAndScalingTransformer, or
586
1108
  MediaTransformer.
@@ -599,20 +1121,21 @@ class EDAEngine:
599
1121
  da_var_agg_map = {}
600
1122
 
601
1123
  if channel_dim is not None:
602
- da_national = self._aggregate_variables(
603
- da_geo, channel_dim, da_var_agg_map
1124
+ national_da = self._aggregate_variables(
1125
+ geo_da, channel_dim, da_var_agg_map
604
1126
  )
605
1127
  else:
606
- # Default to sum aggregation if no channel dimension is provided
607
- da_national = da_geo.sum(
1128
+ national_da = geo_da.sum(
608
1129
  dim=constants.GEO, keepdims=True, skipna=False, keep_attrs=True
609
1130
  )
610
1131
 
611
- da_national = da_national.assign_coords({constants.GEO: [temp_geo_dim]})
612
- da_national.values = tf.cast(da_national.values, tf.float32)
613
- da_national = self._scale_xarray(da_national, transformer_class)
1132
+ national_da = national_da.assign_coords({constants.GEO: [temp_geo_dim]})
1133
+ national_da.values = backend.cast(national_da.values, dtype=backend.float32)
1134
+ national_da = self._scale_xarray(national_da, transformer_class)
614
1135
 
615
- return da_national.sel({constants.GEO: temp_geo_dim}, drop=True)
1136
+ national_da = national_da.sel({constants.GEO: temp_geo_dim}, drop=True)
1137
+ national_da.name = national_da_name
1138
+ return national_da
616
1139
 
617
1140
  def _get_rf_data(
618
1141
  self,
@@ -625,64 +1148,80 @@ class EDAEngine:
625
1148
  scaled_reach_values = (
626
1149
  self._meridian.organic_rf_tensors.organic_reach_scaled
627
1150
  )
1151
+ names = _ORGANIC_RF_NAMES
628
1152
  else:
629
1153
  scaled_reach_values = self._meridian.rf_tensors.reach_scaled
1154
+ names = _RF_NAMES
1155
+
630
1156
  reach_scaled_da = _data_array_like(
631
1157
  da=reach_raw_da, values=scaled_reach_values
632
1158
  )
1159
+ reach_scaled_da.name = names.reach_scaled
633
1160
  # Truncate the media time for reach and scaled reach.
634
1161
  reach_raw_da = self._truncate_media_time(reach_raw_da)
1162
+ reach_raw_da.name = names.reach
635
1163
  reach_scaled_da = self._truncate_media_time(reach_scaled_da)
636
1164
 
637
1165
  # The geo level frequency
638
1166
  frequency_da = self._truncate_media_time(freq_raw_da)
1167
+ frequency_da.name = names.frequency
639
1168
 
640
1169
  # The raw geo level impression
641
1170
  # It's equal to reach * frequency.
642
1171
  impressions_raw_da = reach_raw_da * frequency_da
643
- impressions_raw_da.name = (
644
- constants.ORGANIC_RF_IMPRESSIONS
645
- if is_organic
646
- else constants.RF_IMPRESSIONS
1172
+ impressions_raw_da.name = names.impressions
1173
+ impressions_raw_da.values = backend.cast(
1174
+ impressions_raw_da.values, dtype=backend.float32
647
1175
  )
648
- impressions_raw_da.values = tf.cast(impressions_raw_da.values, tf.float32)
649
1176
 
650
- if self._meridian.is_national:
651
- reach_raw_da_national = reach_raw_da.squeeze(constants.GEO)
652
- reach_scaled_da_national = reach_scaled_da.squeeze(constants.GEO)
653
- impressions_raw_da_national = impressions_raw_da.squeeze(constants.GEO)
654
- frequency_da_national = frequency_da.squeeze(constants.GEO)
1177
+ if self._is_national_data:
1178
+ national_reach_raw_da = reach_raw_da.squeeze(constants.GEO, drop=True)
1179
+ national_reach_raw_da.name = names.national_reach
1180
+ national_reach_scaled_da = reach_scaled_da.squeeze(
1181
+ constants.GEO, drop=True
1182
+ )
1183
+ national_reach_scaled_da.name = names.national_reach_scaled
1184
+ national_impressions_raw_da = impressions_raw_da.squeeze(
1185
+ constants.GEO, drop=True
1186
+ )
1187
+ national_impressions_raw_da.name = names.national_impressions
1188
+ national_frequency_da = frequency_da.squeeze(constants.GEO, drop=True)
1189
+ national_frequency_da.name = names.national_frequency
655
1190
 
656
1191
  # Scaled impressions
657
1192
  impressions_scaled_da = self._scale_xarray(
658
1193
  impressions_raw_da, transformers.MediaTransformer
659
1194
  )
660
- impressions_scaled_da_national = impressions_scaled_da.squeeze(
661
- constants.GEO
1195
+ impressions_scaled_da.name = names.impressions_scaled
1196
+ national_impressions_scaled_da = impressions_scaled_da.squeeze(
1197
+ constants.GEO, drop=True
662
1198
  )
1199
+ national_impressions_scaled_da.name = names.national_impressions_scaled
663
1200
  else:
664
- reach_raw_da_national = self._aggregate_and_scale_geo_da(
665
- reach_raw_da, None
1201
+ national_reach_raw_da = self._aggregate_and_scale_geo_da(
1202
+ reach_raw_da, names.national_reach, None
666
1203
  )
667
- reach_scaled_da_national = self._aggregate_and_scale_geo_da(
668
- reach_raw_da, transformers.MediaTransformer
1204
+ national_reach_scaled_da = self._aggregate_and_scale_geo_da(
1205
+ reach_raw_da,
1206
+ names.national_reach_scaled,
1207
+ transformers.MediaTransformer,
669
1208
  )
670
- impressions_raw_da_national = self._aggregate_and_scale_geo_da(
671
- impressions_raw_da, None
1209
+ national_impressions_raw_da = self._aggregate_and_scale_geo_da(
1210
+ impressions_raw_da,
1211
+ names.national_impressions,
1212
+ None,
672
1213
  )
673
1214
 
674
1215
  # National frequency is a weighted average of geo frequencies,
675
1216
  # weighted by reach.
676
- frequency_da_national = xr.where(
677
- reach_raw_da_national == 0.0,
1217
+ national_frequency_da = xr.where(
1218
+ national_reach_raw_da == 0.0,
678
1219
  0.0,
679
- impressions_raw_da_national / reach_raw_da_national,
1220
+ national_impressions_raw_da / national_reach_raw_da,
680
1221
  )
681
- frequency_da_national.name = (
682
- constants.ORGANIC_PREFIX if is_organic else ''
683
- ) + constants.FREQUENCY
684
- frequency_da_national.values = tf.cast(
685
- frequency_da_national.values, tf.float32
1222
+ national_frequency_da.name = names.national_frequency
1223
+ national_frequency_da.values = backend.cast(
1224
+ national_frequency_da.values, dtype=backend.float32
686
1225
  )
687
1226
 
688
1227
  # Scale the impressions by population
@@ -691,45 +1230,619 @@ class EDAEngine:
691
1230
  transformers.MediaTransformer,
692
1231
  population=self._meridian.population,
693
1232
  )
1233
+ impressions_scaled_da.name = names.impressions_scaled
694
1234
 
695
1235
  # Scale the national impressions
696
- impressions_scaled_da_national = self._aggregate_and_scale_geo_da(
1236
+ national_impressions_scaled_da = self._aggregate_and_scale_geo_da(
697
1237
  impressions_raw_da,
1238
+ names.national_impressions_scaled,
698
1239
  transformers.MediaTransformer,
699
1240
  )
700
1241
 
701
1242
  return ReachFrequencyData(
702
1243
  reach_raw_da=reach_raw_da,
703
1244
  reach_scaled_da=reach_scaled_da,
704
- reach_raw_da_national=reach_raw_da_national,
705
- reach_scaled_da_national=reach_scaled_da_national,
1245
+ national_reach_raw_da=national_reach_raw_da,
1246
+ national_reach_scaled_da=national_reach_scaled_da,
706
1247
  frequency_da=frequency_da,
707
- frequency_da_national=frequency_da_national,
1248
+ national_frequency_da=national_frequency_da,
708
1249
  rf_impressions_scaled_da=impressions_scaled_da,
709
- rf_impressions_scaled_da_national=impressions_scaled_da_national,
1250
+ national_rf_impressions_scaled_da=national_impressions_scaled_da,
710
1251
  rf_impressions_raw_da=impressions_raw_da,
711
- rf_impressions_raw_da_national=impressions_raw_da_national,
1252
+ national_rf_impressions_raw_da=national_impressions_raw_da,
712
1253
  )
713
1254
 
1255
+ def _pairwise_corr_for_geo_data(
1256
+ self, dims: str | Sequence[str], extreme_corr_threshold: float
1257
+ ) -> tuple[xr.DataArray, pd.DataFrame]:
1258
+ """Get pairwise correlation among treatments and controls for geo data."""
1259
+ corr_mat = _compute_correlation_matrix(
1260
+ self._stacked_treatment_control_scaled_da, dims=dims
1261
+ )
1262
+ extreme_corr_var_pairs_df = _find_extreme_corr_pairs(
1263
+ corr_mat, extreme_corr_threshold
1264
+ )
1265
+ return corr_mat, extreme_corr_var_pairs_df
714
1266
 
715
- def _data_array_like(
716
- *, da: xr.DataArray, values: np.ndarray | tf.Tensor
717
- ) -> xr.DataArray:
718
- """Returns a DataArray from `values` with the same structure as `da`.
1267
+ def check_geo_pairwise_corr(
1268
+ self,
1269
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PairwiseCorrArtifact]:
1270
+ """Checks pairwise correlation among treatments and controls for geo data.
719
1271
 
720
- Args:
721
- da: The DataArray whose structure (dimensions, coordinates, name, and attrs)
722
- will be used for the new DataArray.
723
- values: The numpy array or tensorflow tensor to use as the values for the
724
- new DataArray.
1272
+ Returns:
1273
+ An EDAOutcome object with findings and result values.
725
1274
 
726
- Returns:
727
- A new DataArray with the provided `values` and the same structure as `da`.
728
- """
729
- return xr.DataArray(
730
- values,
731
- coords=da.coords,
732
- dims=da.dims,
733
- name=da.name,
734
- attrs=da.attrs,
735
- )
1275
+ Raises:
1276
+ GeoLevelCheckOnNationalModelError: If the model is national.
1277
+ """
1278
+ # If the model is national, raise an error.
1279
+ if self._is_national_data:
1280
+ raise GeoLevelCheckOnNationalModelError(
1281
+ 'check_geo_pairwise_corr is not supported for national models.'
1282
+ )
1283
+
1284
+ findings = []
1285
+
1286
+ overall_corr_mat, overall_extreme_corr_var_pairs_df = (
1287
+ self._pairwise_corr_for_geo_data(
1288
+ dims=[constants.GEO, constants.TIME],
1289
+ extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
1290
+ )
1291
+ )
1292
+ if not overall_extreme_corr_var_pairs_df.empty:
1293
+ var_pairs = overall_extreme_corr_var_pairs_df.index.to_list()
1294
+ findings.append(
1295
+ eda_outcome.EDAFinding(
1296
+ severity=eda_outcome.EDASeverity.ERROR,
1297
+ explanation=(
1298
+ 'Some variables have perfect pairwise correlation across all'
1299
+ ' times and geos. For each pair of perfectly-correlated'
1300
+ ' variables, please remove one of the variables from the'
1301
+ f' model.\nPairs with perfect correlation: {var_pairs}'
1302
+ ),
1303
+ )
1304
+ )
1305
+
1306
+ geo_corr_mat, geo_extreme_corr_var_pairs_df = (
1307
+ self._pairwise_corr_for_geo_data(
1308
+ dims=constants.TIME,
1309
+ extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
1310
+ )
1311
+ )
1312
+ # Overall correlation and per-geo correlation findings are mutually
1313
+ # exclusive, and overall correlation finding takes precedence.
1314
+ if (
1315
+ overall_extreme_corr_var_pairs_df.empty
1316
+ and not geo_extreme_corr_var_pairs_df.empty
1317
+ ):
1318
+ findings.append(
1319
+ eda_outcome.EDAFinding(
1320
+ severity=eda_outcome.EDASeverity.ATTENTION,
1321
+ explanation=(
1322
+ 'Some variables have perfect pairwise correlation in certain'
1323
+ ' geo(s). Consider checking your data, and/or combining these'
1324
+ ' variables if they also have high pairwise correlations in'
1325
+ ' other geos.'
1326
+ ),
1327
+ )
1328
+ )
1329
+
1330
+ # If there are no findings, add a INFO level finding indicating that no
1331
+ # severe correlations were found and what it means for user's data.
1332
+ if not findings:
1333
+ findings.append(
1334
+ eda_outcome.EDAFinding(
1335
+ severity=eda_outcome.EDASeverity.INFO,
1336
+ explanation=(
1337
+ 'Please review the computed pairwise correlations. Note that'
1338
+ ' high pairwise correlation may cause model identifiability'
1339
+ ' and convergence issues. Consider combining the variables if'
1340
+ ' high correlation exists.'
1341
+ ),
1342
+ )
1343
+ )
1344
+
1345
+ pairwise_corr_artifacts = [
1346
+ eda_outcome.PairwiseCorrArtifact(
1347
+ level=eda_outcome.AnalysisLevel.OVERALL,
1348
+ corr_matrix=overall_corr_mat,
1349
+ extreme_corr_var_pairs=overall_extreme_corr_var_pairs_df,
1350
+ extreme_corr_threshold=_OVERALL_PAIRWISE_CORR_THRESHOLD,
1351
+ ),
1352
+ eda_outcome.PairwiseCorrArtifact(
1353
+ level=eda_outcome.AnalysisLevel.GEO,
1354
+ corr_matrix=geo_corr_mat,
1355
+ extreme_corr_var_pairs=geo_extreme_corr_var_pairs_df,
1356
+ extreme_corr_threshold=_GEO_PAIRWISE_CORR_THRESHOLD,
1357
+ ),
1358
+ ]
1359
+
1360
+ return eda_outcome.EDAOutcome(
1361
+ check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
1362
+ findings=findings,
1363
+ analysis_artifacts=pairwise_corr_artifacts,
1364
+ )
1365
+
1366
+ def check_national_pairwise_corr(
1367
+ self,
1368
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PairwiseCorrArtifact]:
1369
+ """Checks pairwise correlation among treatments and controls for national data.
1370
+
1371
+ Returns:
1372
+ An EDAOutcome object with findings and result values.
1373
+ """
1374
+ findings = []
1375
+
1376
+ corr_mat = _compute_correlation_matrix(
1377
+ self._stacked_national_treatment_control_scaled_da, dims=constants.TIME
1378
+ )
1379
+ extreme_corr_var_pairs_df = _find_extreme_corr_pairs(
1380
+ corr_mat, _NATIONAL_PAIRWISE_CORR_THRESHOLD
1381
+ )
1382
+
1383
+ if not extreme_corr_var_pairs_df.empty:
1384
+ var_pairs = extreme_corr_var_pairs_df.index.to_list()
1385
+ findings.append(
1386
+ eda_outcome.EDAFinding(
1387
+ severity=eda_outcome.EDASeverity.ERROR,
1388
+ explanation=(
1389
+ 'Some variables have perfect pairwise correlation across all'
1390
+ ' times. For each pair of perfectly-correlated'
1391
+ ' variables, please remove one of the variables from the'
1392
+ f' model.\nPairs with perfect correlation: {var_pairs}'
1393
+ ),
1394
+ )
1395
+ )
1396
+ else:
1397
+ findings.append(
1398
+ eda_outcome.EDAFinding(
1399
+ severity=eda_outcome.EDASeverity.INFO,
1400
+ explanation=(
1401
+ 'Please review the computed pairwise correlations. Note that'
1402
+ ' high pairwise correlation may cause model identifiability'
1403
+ ' and convergence issues. Consider combining the variables if'
1404
+ ' high correlation exists.'
1405
+ ),
1406
+ )
1407
+ )
1408
+
1409
+ pairwise_corr_artifacts = [
1410
+ eda_outcome.PairwiseCorrArtifact(
1411
+ level=eda_outcome.AnalysisLevel.NATIONAL,
1412
+ corr_matrix=corr_mat,
1413
+ extreme_corr_var_pairs=extreme_corr_var_pairs_df,
1414
+ extreme_corr_threshold=_NATIONAL_PAIRWISE_CORR_THRESHOLD,
1415
+ )
1416
+ ]
1417
+ return eda_outcome.EDAOutcome(
1418
+ check_type=eda_outcome.EDACheckType.PAIRWISE_CORRELATION,
1419
+ findings=findings,
1420
+ analysis_artifacts=pairwise_corr_artifacts,
1421
+ )
1422
+
1423
+ def check_pairwise_corr(
1424
+ self,
1425
+ ) -> eda_outcome.EDAOutcome[eda_outcome.PairwiseCorrArtifact]:
1426
+ """Checks pairwise correlation among treatments and controls.
1427
+
1428
+ Returns:
1429
+ An EDAOutcome object with findings and result values.
1430
+ """
1431
+ if self._is_national_data:
1432
+ return self.check_national_pairwise_corr()
1433
+ else:
1434
+ return self.check_geo_pairwise_corr()
1435
+
1436
+ def _check_std(
1437
+ self,
1438
+ data: xr.DataArray,
1439
+ level: eda_outcome.AnalysisLevel,
1440
+ zero_std_message: str,
1441
+ ) -> tuple[
1442
+ Optional[eda_outcome.EDAFinding], eda_outcome.StandardDeviationArtifact
1443
+ ]:
1444
+ """Helper to check standard deviation."""
1445
+ std_ds, outlier_df = _calculate_std(data)
1446
+
1447
+ finding = None
1448
+ if (std_ds[_STD_WITHOUT_OUTLIERS_VAR_NAME] < _STD_THRESHOLD).any():
1449
+ finding = eda_outcome.EDAFinding(
1450
+ severity=eda_outcome.EDASeverity.ATTENTION,
1451
+ explanation=zero_std_message,
1452
+ )
1453
+
1454
+ artifact = eda_outcome.StandardDeviationArtifact(
1455
+ variable=str(data.name),
1456
+ level=level,
1457
+ std_ds=std_ds,
1458
+ outlier_df=outlier_df,
1459
+ )
1460
+
1461
+ return finding, artifact
1462
+
1463
def check_geo_std(
    self,
) -> eda_outcome.EDAOutcome[eda_outcome.StandardDeviationArtifact]:
  """Checks std for geo-level KPI, treatments, R&F, and controls.

  Each available data array is passed to `_check_std` at the GEO analysis
  level. Data arrays that are `None` are skipped.

  Returns:
    An `EDAOutcome` of type STANDARD_DEVIATION holding one artifact per
    checked data array, plus the findings raised by those checks. When no
    check raises a finding, a single INFO finding is returned instead.

  Raises:
    ValueError: If the underlying data is national.
  """
  if self._is_national_data:
    raise ValueError('check_geo_std is not applicable for national models.')

  kpi_message = (
      'KPI has zero standard deviation after removing outliers'
      ' in certain geos, indicating weak or no signal in the response'
      ' variable for these geos. Please review the input data,'
      ' and/or consider grouping these geos together.'
  )
  treatment_control_message = (
      'Some treatment or control variables have zero standard'
      ' deviation after removing outliers in certain geo(s). Please'
      ' review the input data. If these variables are sparse,'
      ' consider combining them to mitigate potential model'
      ' identifiability and convergence issues.'
  )
  reach_message = (
      'There are RF or Organic RF channels with zero variation of'
      ' reach across time at a geo after outliers are removed. If'
      ' these channels also have low variation of reach in other'
      ' geos, consider modeling them as impression-based channels'
      ' instead by taking reach * frequency.'
  )
  frequency_message = (
      'There are RF or Organic RF channels with zero variation of'
      ' frequency across time at a geo after outliers are removed. If'
      ' these channels also have low variation of frequency in other'
      ' geos, consider modeling them as impression-based channels'
      ' instead by taking reach * frequency.'
  )

  # Pairs of (data array, message used when its std is near zero).
  std_checks = (
      (self.kpi_scaled_da, kpi_message),
      (self._stacked_treatment_control_scaled_da, treatment_control_message),
      (self.all_reach_scaled_da, reach_message),
      (self.all_freq_da, frequency_message),
  )

  raised_findings = []
  std_artifacts = []
  for scaled_da, explanation in std_checks:
    if scaled_da is not None:
      finding, artifact = self._check_std(
          data=scaled_da,
          level=eda_outcome.AnalysisLevel.GEO,
          zero_std_message=explanation,
      )
      std_artifacts.append(artifact)
      if finding:
        raised_findings.append(finding)

  # Fall back to a generic INFO finding when nothing was flagged.
  raised_findings = raised_findings or [
      eda_outcome.EDAFinding(
          severity=eda_outcome.EDASeverity.INFO,
          explanation=(
              'Please review any identified outliers and the standard'
              ' deviation.'
          ),
      )
  ]

  return eda_outcome.EDAOutcome(
      check_type=eda_outcome.EDACheckType.STANDARD_DEVIATION,
      findings=raised_findings,
      analysis_artifacts=std_artifacts,
  )
1544
+
1545
def check_national_std(
    self,
) -> eda_outcome.EDAOutcome[eda_outcome.StandardDeviationArtifact]:
  """Checks std for national-level KPI, treatments, R&F, and controls.

  Each available national data array is passed to `_check_std` at the
  NATIONAL analysis level. Data arrays that are `None` are skipped.

  Returns:
    An `EDAOutcome` of type STANDARD_DEVIATION holding one artifact per
    checked data array, plus the findings raised by those checks. When no
    check raises a finding, a single INFO finding is returned instead.
  """
  kpi_message = (
      'The standard deviation of the scaled KPI drops from positive'
      ' to zero after removing outliers, indicating sparsity of KPI'
      ' i.e. lack of signal in the response variable. Please review'
      ' the input data, and/or reconsider the feasibility of model'
      ' fitting with this dataset.'
  )
  treatment_control_message = (
      'The standard deviation of these scaled treatment or control'
      ' variables drops from positive to zero after removing'
      ' outliers. This indicates sparsity of these variables, which'
      ' may cause model identifiability and convergence issues.'
      ' Please review the input data, and/or consider combining these'
      ' variables to mitigate sparsity.'
  )
  reach_message = (
      'There are RF channels with totally zero variation of reach'
      ' across time at the national level after outliers are removed.'
      ' Consider modeling these RF channels as impression-based'
      ' channels instead.'
  )
  frequency_message = (
      'There are RF channels with totally zero variation of frequency'
      ' across time at the national level after outliers are removed.'
      ' Consider modeling these RF channels as impression-based'
      ' channels instead.'
  )

  # Pairs of (data array, message used when its std is near zero).
  std_checks = (
      (self.national_kpi_scaled_da, kpi_message),
      (
          self._stacked_national_treatment_control_scaled_da,
          treatment_control_message,
      ),
      (self.national_all_reach_scaled_da, reach_message),
      (self.national_all_freq_da, frequency_message),
  )

  raised_findings = []
  std_artifacts = []
  for national_da, explanation in std_checks:
    if national_da is not None:
      finding, artifact = self._check_std(
          data=national_da,
          level=eda_outcome.AnalysisLevel.NATIONAL,
          zero_std_message=explanation,
      )
      std_artifacts.append(artifact)
      if finding:
        raised_findings.append(finding)

  # Fall back to a generic INFO finding when nothing was flagged.
  raised_findings = raised_findings or [
      eda_outcome.EDAFinding(
          severity=eda_outcome.EDASeverity.INFO,
          explanation=(
              'Please review any identified outliers and the standard'
              ' deviation.'
          ),
      )
  ]

  return eda_outcome.EDAOutcome(
      check_type=eda_outcome.EDACheckType.STANDARD_DEVIATION,
      findings=raised_findings,
      analysis_artifacts=std_artifacts,
  )
1623
+
1624
def check_std(
    self,
) -> eda_outcome.EDAOutcome[eda_outcome.StandardDeviationArtifact]:
  """Runs the standard deviation check that matches the data granularity.

  Delegates to `check_national_std` for national data and to
  `check_geo_std` otherwise.

  Returns:
    An EDAOutcome object with findings and result values.
  """
  selected_check = (
      self.check_national_std
      if self._is_national_data
      else self.check_geo_std
  )
  return selected_check()
1636
+
1637
def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
  """Computes geo-level variance inflation factor among treatments and controls.

  VIF is computed twice on the stacked, scaled treatment/control data: once
  over the whole array (OVERALL level) and once per geo (GEO level). Values
  above the corresponding threshold from the VIF spec are collected into
  each artifact's `outlier_df`.

  Returns:
    An `EDAOutcome` of type MULTICOLLINEARITY with the overall and geo
    artifacts and exactly one finding: ERROR when any overall VIF exceeds
    its threshold, otherwise ATTENTION when any per-geo VIF exceeds its
    threshold, otherwise INFO.

  Raises:
    ValueError: If the underlying data is national.
  """
  if self._is_national_data:
    raise ValueError(
        'Geo-level VIF checks are not applicable for national models.'
    )

  stacked_da = self._stacked_treatment_control_scaled_da
  vif_spec = self._spec.vif_spec
  overall_threshold = vif_spec.overall_threshold
  geo_threshold = vif_spec.geo_threshold

  # Overall level VIF check for geo data.
  overall_vif_da = _calculate_vif(stacked_da, _STACK_VAR_COORD_NAME)
  extreme_overall_vif_df = (
      overall_vif_da.where(overall_vif_da > overall_threshold)
      .to_dataframe(name=_VIF_COL_NAME)
      .dropna()
  )
  overall_artifact = eda_outcome.VIFArtifact(
      level=eda_outcome.AnalysisLevel.OVERALL,
      vif_da=overall_vif_da,
      outlier_df=extreme_overall_vif_df,
  )

  # Geo level VIF check.
  geo_vif_da = stacked_da.groupby(constants.GEO).map(
      lambda geo_slice: _calculate_vif(geo_slice, _STACK_VAR_COORD_NAME)
  )
  extreme_geo_vif_df = (
      geo_vif_da.where(geo_vif_da > geo_threshold)
      .to_dataframe(name=_VIF_COL_NAME)
      .dropna()
  )
  geo_artifact = eda_outcome.VIFArtifact(
      level=eda_outcome.AnalysisLevel.GEO,
      vif_da=geo_vif_da,
      outlier_df=extreme_geo_vif_df,
  )

  # Exactly one finding is reported; an overall violation takes precedence
  # over a per-geo violation.
  if extreme_overall_vif_df.empty:
    if extreme_geo_vif_df.empty:
      finding = eda_outcome.EDAFinding(
          severity=eda_outcome.EDASeverity.INFO,
          explanation=(
              'Please review the computed VIFs. Note that high VIF suggests'
              ' multicollinearity issues in the dataset, which may'
              ' jeopardize model identifiability and model convergence.'
              ' Consider combining the variables if high VIF occurs.'
          ),
      )
    else:
      finding = eda_outcome.EDAFinding(
          severity=eda_outcome.EDASeverity.ATTENTION,
          explanation=(
              'Some variables have extreme multicollinearity (with VIF >'
              f' {geo_threshold}) in certain geo(s). Consider checking your'
              ' data, and/or combining these variables if they also have'
              ' high VIF in other geos.'
          ),
      )
  else:
    high_vif_vars = extreme_overall_vif_df.index.to_list()
    finding = eda_outcome.EDAFinding(
        severity=eda_outcome.EDASeverity.ERROR,
        explanation=(
            'Some variables have extreme multicollinearity (VIF'
            f' >{overall_threshold}) across all times and geos. To'
            ' address multicollinearity, please drop any variable that'
            ' is a linear combination of other variables. Otherwise,'
            ' consider combining variables.\n'
            f'Variables with extreme VIF: {high_vif_vars}'
        ),
    )

  return eda_outcome.EDAOutcome(
      check_type=eda_outcome.EDACheckType.MULTICOLLINEARITY,
      findings=[finding],
      analysis_artifacts=[overall_artifact, geo_artifact],
  )
1724
+
1725
def check_national_vif(
    self,
) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
  """Computes national-level variance inflation factor among treatments and controls.

  VIF is computed on the stacked, scaled national treatment/control data.
  Values above the national threshold from the VIF spec are collected into
  the artifact's `outlier_df`.

  Returns:
    An `EDAOutcome` of type MULTICOLLINEARITY with a single NATIONAL-level
    artifact and exactly one finding: ERROR when any VIF exceeds the
    threshold, otherwise INFO.
  """
  stacked_national_da = self._stacked_national_treatment_control_scaled_da
  national_threshold = self._spec.vif_spec.national_threshold

  national_vif_da = _calculate_vif(stacked_national_da, _STACK_VAR_COORD_NAME)
  above_threshold_da = national_vif_da.where(
      national_vif_da > national_threshold
  )
  extreme_national_vif_df = above_threshold_da.to_dataframe(
      name=_VIF_COL_NAME
  ).dropna()

  artifact = eda_outcome.VIFArtifact(
      level=eda_outcome.AnalysisLevel.NATIONAL,
      vif_da=national_vif_da,
      outlier_df=extreme_national_vif_df,
  )

  if extreme_national_vif_df.empty:
    finding = eda_outcome.EDAFinding(
        severity=eda_outcome.EDASeverity.INFO,
        explanation=(
            'Please review the computed VIFs. Note that high VIF suggests'
            ' multicollinearity issues in the dataset, which may'
            ' jeopardize model identifiability and model convergence.'
            ' Consider combining the variables if high VIF occurs.'
        ),
    )
  else:
    high_vif_vars = extreme_national_vif_df.index.to_list()
    finding = eda_outcome.EDAFinding(
        severity=eda_outcome.EDASeverity.ERROR,
        explanation=(
            'Some variables have extreme multicollinearity (with VIF >'
            f' {national_threshold}) across all times. To address'
            ' multicollinearity, please drop any variable that is a'
            ' linear combination of other variables. Otherwise, consider'
            ' combining variables.\n'
            f'Variables with extreme VIF: {high_vif_vars}'
        ),
    )

  return eda_outcome.EDAOutcome(
      check_type=eda_outcome.EDACheckType.MULTICOLLINEARITY,
      findings=[finding],
      analysis_artifacts=[artifact],
  )
1777
+
1778
def check_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
  """Runs the variance inflation factor check matching the data granularity.

  Delegates to `check_national_vif` for national data and to
  `check_geo_vif` otherwise.

  Returns:
    An EDAOutcome object with findings and result values.
  """
  selected_check = (
      self.check_national_vif
      if self._is_national_data
      else self.check_geo_vif
  )
  return selected_check()
1788
+
1789
@property
def kpi_has_variability(self) -> bool:
  """Whether the overall scaled KPI standard deviation reaches the threshold.

  Returns:
    True when the `kpi_stdev` value of the overall scaled KPI invariability
    artifact is at least `_STD_THRESHOLD`, False otherwise.
  """
  invariability_artifact = self._overall_scaled_kpi_invariability_artifact
  overall_kpi_stdev = invariability_artifact.kpi_stdev.item()
  return overall_kpi_stdev >= _STD_THRESHOLD
1796
+
1797
def check_overall_kpi_invariability(self) -> eda_outcome.EDAOutcome:
  """Checks if the KPI is constant across all geos and times.

  Returns:
    An `EDAOutcome` of type KPI_INVARIABILITY carrying the overall scaled
    KPI invariability artifact and a single finding: INFO when
    `kpi_has_variability` is true, ERROR otherwise.
  """
  kpi = self._overall_scaled_kpi_invariability_artifact.kpi_da.name
  # National data has no geo dimension, so the wording drops it.
  geo_text = 'geos and ' if not self._is_national_data else ''

  if self.kpi_has_variability:
    severity = eda_outcome.EDASeverity.INFO
    explanation = (
        f'The {kpi} has variability across {geo_text}times in the data.'
    )
  else:
    severity = eda_outcome.EDASeverity.ERROR
    explanation = (
        f'`{kpi}` is constant across all {geo_text}times, indicating no'
        ' signal in the data. Please fix this data error.'
    )

  return eda_outcome.EDAOutcome(
      check_type=eda_outcome.EDACheckType.KPI_INVARIABILITY,
      findings=[
          eda_outcome.EDAFinding(severity=severity, explanation=explanation)
      ],
      analysis_artifacts=[self._overall_scaled_kpi_invariability_artifact],
  )
1823
+
1824
def run_all_critical_checks(self) -> list[eda_outcome.EDAOutcome]:
  """Runs all critical EDA checks.

  Critical checks are those that can result in EDASeverity.ERROR findings.
  A check that raises does not abort the run: the exception is converted
  into an outcome with a single ERROR finding and no artifacts.

  Returns:
    A list of EDA outcomes, one for each check, in the order given by
    `self._critical_checks`.
  """
  collected = []
  for check, check_type in self._critical_checks:
    try:
      outcome = check()
    except Exception as e:  # pylint: disable=broad-except
      # Report the failure as an outcome so the remaining checks still run.
      outcome = eda_outcome.EDAOutcome(
          check_type=check_type,
          findings=[
              eda_outcome.EDAFinding(
                  severity=eda_outcome.EDASeverity.ERROR,
                  explanation=(
                      f'An error occurred during check {check.__name__}: {e}'
                  ),
              )
          ],
          analysis_artifacts=[],
      )
    collected.append(outcome)
  return collected