google-meridian 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. google_meridian-1.3.2.dist-info/METADATA +209 -0
  2. google_meridian-1.3.2.dist-info/RECORD +76 -0
  3. {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +1 -2
  5. meridian/analysis/analyzer.py +0 -1
  6. meridian/analysis/optimizer.py +5 -3
  7. meridian/analysis/review/checks.py +81 -30
  8. meridian/analysis/review/constants.py +4 -0
  9. meridian/analysis/review/results.py +40 -9
  10. meridian/analysis/summarizer.py +1 -1
  11. meridian/analysis/visualizer.py +1 -1
  12. meridian/backend/__init__.py +229 -24
  13. meridian/backend/test_utils.py +194 -0
  14. meridian/constants.py +1 -0
  15. meridian/data/load.py +2 -0
  16. meridian/model/eda/__init__.py +0 -1
  17. meridian/model/eda/constants.py +12 -2
  18. meridian/model/eda/eda_engine.py +353 -45
  19. meridian/model/eda/eda_outcome.py +21 -1
  20. meridian/model/knots.py +17 -0
  21. meridian/model/model_test_data.py +15 -0
  22. meridian/{analysis/templates → templates}/card.html.jinja +1 -1
  23. meridian/{analysis/templates → templates}/chart.html.jinja +1 -1
  24. meridian/{analysis/templates → templates}/chips.html.jinja +1 -1
  25. meridian/{analysis → templates}/formatter.py +12 -1
  26. meridian/templates/formatter_test.py +216 -0
  27. meridian/{analysis/templates → templates}/insights.html.jinja +1 -1
  28. meridian/{analysis/templates → templates}/stats.html.jinja +1 -1
  29. meridian/{analysis/templates → templates}/style.css +1 -1
  30. meridian/{analysis/templates → templates}/style.scss +1 -1
  31. meridian/{analysis/templates → templates}/summary.html.jinja +4 -2
  32. meridian/{analysis/templates → templates}/table.html.jinja +1 -1
  33. meridian/version.py +1 -1
  34. schema/__init__.py +30 -0
  35. schema/serde/__init__.py +26 -0
  36. schema/serde/constants.py +48 -0
  37. schema/serde/distribution.py +515 -0
  38. schema/serde/eda_spec.py +192 -0
  39. schema/serde/function_registry.py +143 -0
  40. schema/serde/hyperparameters.py +363 -0
  41. schema/serde/inference_data.py +105 -0
  42. schema/serde/marketing_data.py +1321 -0
  43. schema/serde/meridian_serde.py +413 -0
  44. schema/serde/serde.py +47 -0
  45. schema/serde/test_data.py +4608 -0
  46. schema/utils/__init__.py +17 -0
  47. schema/utils/time_record.py +156 -0
  48. google_meridian-1.3.0.dist-info/METADATA +0 -409
  49. google_meridian-1.3.0.dist-info/RECORD +0 -62
  50. meridian/model/eda/meridian_eda.py +0 -220
  51. {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/WHEEL +0 -0
  52. {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -48,9 +48,6 @@ _CORRELATION_MATRIX_NAME = 'correlation_matrix'
48
48
  _OVERALL_PAIRWISE_CORR_THRESHOLD = 0.999
49
49
  _GEO_PAIRWISE_CORR_THRESHOLD = 0.999
50
50
  _NATIONAL_PAIRWISE_CORR_THRESHOLD = 0.999
51
- _EMPTY_DF_FOR_EXTREME_CORR_PAIRS = pd.DataFrame(
52
- columns=[_CORR_VAR1, _CORR_VAR2, _CORRELATION_COL_NAME]
53
- )
54
51
  _Q1_THRESHOLD = 0.25
55
52
  _Q3_THRESHOLD = 0.75
56
53
  _IQR_MULTIPLIER = 1.5
@@ -175,10 +172,22 @@ def _data_array_like(
175
172
  )
176
173
 
177
174
 
178
- def _stack_variables(
175
+ def stack_variables(
179
176
  ds: xr.Dataset, coord_name: str = _STACK_VAR_COORD_NAME
180
177
  ) -> xr.DataArray:
181
- """Stacks data variables other than time and geo into a single variable."""
178
+ """Stacks data variables of a Dataset into a single DataArray.
179
+
180
+ This function is designed to work with Datasets that have 'time' or 'geo'
181
+ dimensions, which are preserved. Other dimensions are stacked into a new
182
+ dimension.
183
+
184
+ Args:
185
+ ds: The input xarray.Dataset to stack.
186
+ coord_name: The name of the new coordinate for the stacked dimension.
187
+
188
+ Returns:
189
+ An xarray.DataArray with the specified dimensions stacked.
190
+ """
182
191
  dims = []
183
192
  coords = []
184
193
  sample_dims = []
@@ -249,37 +258,57 @@ def _find_extreme_corr_pairs(
249
258
  corr_tri = _get_upper_triangle_corr_mat(extreme_corr_da)
250
259
  extreme_corr_da = corr_tri.where(abs(corr_tri) > extreme_corr_threshold)
251
260
 
252
- df = extreme_corr_da.to_dataframe(name=_CORRELATION_COL_NAME).dropna()
253
- if df.empty:
254
- return _EMPTY_DF_FOR_EXTREME_CORR_PAIRS.copy()
255
- return df.sort_values(
256
- by=_CORRELATION_COL_NAME, ascending=False, inplace=False
261
+ return (
262
+ extreme_corr_da.to_dataframe(name=_CORRELATION_COL_NAME)
263
+ .dropna()
264
+ .assign(**{
265
+ eda_constants.ABS_CORRELATION_COL_NAME: (
266
+ lambda x: x[_CORRELATION_COL_NAME].abs()
267
+ )
268
+ })
269
+ .sort_values(
270
+ by=eda_constants.ABS_CORRELATION_COL_NAME,
271
+ ascending=False,
272
+ inplace=False,
273
+ )
257
274
  )
258
275
 
259
276
 
260
- def _calculate_std(
277
+ def _get_outlier_bounds(
261
278
  input_da: xr.DataArray,
262
- ) -> tuple[xr.Dataset, pd.DataFrame]:
263
- """Helper function to compute std with and without outliers.
279
+ ) -> tuple[xr.DataArray, xr.DataArray]:
280
+ """Computes lower and upper bounds for outliers across time using the IQR method.
264
281
 
265
282
  Args:
266
- input_da: A DataArray for which to calculate the std.
283
+ input_da: A DataArray for which to calculate outlier bounds.
267
284
 
268
285
  Returns:
269
- A tuple where the first element is a Dataset with two data variables:
270
- 'std_incl_outliers' and 'std_excl_outliers'. The second element is a
271
- DataFrame with columns for variables, geo (if applicable), time, and
272
- outlier values.
286
+ A tuple containing the lower and upper bounds of outliers as DataArrays.
273
287
  """
274
- std_with_outliers = input_da.std(dim=constants.TIME, ddof=1)
275
-
276
288
  # TODO: Allow users to specify custom outlier definitions.
277
289
  q1 = input_da.quantile(_Q1_THRESHOLD, dim=constants.TIME)
278
290
  q3 = input_da.quantile(_Q3_THRESHOLD, dim=constants.TIME)
279
291
  iqr = q3 - q1
280
292
  lower_bound = q1 - _IQR_MULTIPLIER * iqr
281
293
  upper_bound = q3 + _IQR_MULTIPLIER * iqr
294
+ return lower_bound, upper_bound
295
+
296
+
297
+ def _calculate_std(
298
+ input_da: xr.DataArray,
299
+ ) -> xr.Dataset:
300
+ """Helper function to compute std with and without outliers.
282
301
 
302
+ Args:
303
+ input_da: A DataArray for which to calculate the std.
304
+
305
+ Returns:
306
+ A Dataset with two data variables: 'std_with_outliers' and
307
+ 'std_without_outliers'.
308
+ """
309
+ std_with_outliers = input_da.std(dim=constants.TIME, ddof=1)
310
+
311
+ lower_bound, upper_bound = _get_outlier_bounds(input_da)
283
312
  da_no_outlier = input_da.where(
284
313
  (input_da >= lower_bound) & (input_da <= upper_bound)
285
314
  )
@@ -289,17 +318,34 @@ def _calculate_std(
289
318
  _STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
290
319
  _STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
291
320
  })
321
+ return std_ds
322
+
292
323
 
324
+ def _calculate_outliers(
325
+ input_da: xr.DataArray,
326
+ ) -> pd.DataFrame:
327
+ """Helper function to extract outliers from a DataArray across time.
328
+
329
+ Args:
330
+ input_da: A DataArray from which to extract outliers.
331
+
332
+ Returns:
333
+ A DataFrame with columns for variables, geo (if applicable), time, and
334
+ outlier values.
335
+ """
336
+ lower_bound, upper_bound = _get_outlier_bounds(input_da)
293
337
  outlier_da = input_da.where(
294
338
  (input_da < lower_bound) | (input_da > upper_bound)
295
339
  )
296
-
297
- outlier_df = outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME).dropna()
298
- outlier_df = outlier_df.assign(
299
- **{_ABS_OUTLIERS_COL_NAME: np.abs(outlier_df[_OUTLIERS_COL_NAME])}
300
- ).sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
301
-
302
- return std_ds, outlier_df
340
+ outlier_df = (
341
+ outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME)
342
+ .dropna()
343
+ .assign(
344
+ **{_ABS_OUTLIERS_COL_NAME: lambda x: np.abs(x[_OUTLIERS_COL_NAME])}
345
+ )
346
+ .sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
347
+ )
348
+ return outlier_df
303
349
 
304
350
 
305
351
  def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
@@ -315,7 +361,9 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
315
361
  """
316
362
  num_vars = input_da.sizes[var_dim]
317
363
  np_data = input_da.values.reshape(-1, num_vars)
318
- np_data_with_const = sm.add_constant(np_data, prepend=True)
364
+ np_data_with_const = sm.add_constant(
365
+ np_data, prepend=True, has_constant='add'
366
+ )
319
367
 
320
368
  # Compute VIF for each variable excluding const which is the first one in the
321
369
  # 'variable' dimension.
@@ -332,6 +380,117 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
332
380
  return vif_da
333
381
 
334
382
 
383
+ def _check_cost_media_unit_inconsistency(
384
+ cost_da: xr.DataArray,
385
+ media_units_da: xr.DataArray,
386
+ ) -> pd.DataFrame:
387
+ """Checks for inconsistencies between cost and media units.
388
+
389
+ Args:
390
+ cost_da: DataArray containing cost data.
391
+ media_units_da: DataArray containing media unit data.
392
+
393
+ Returns:
394
+ A DataFrame of inconsistencies where either cost is zero and media units
395
+ are
396
+ positive, or cost is positive and media units are zero.
397
+ """
398
+ cost_media_units_ds = xr.merge([cost_da, media_units_da])
399
+
400
+ # Condition 1: cost == 0 and media unit > 0
401
+ zero_cost_positive_mask = (cost_da == 0) & (media_units_da > 0)
402
+ zero_cost_positive_media_unit_df = (
403
+ cost_media_units_ds.where(zero_cost_positive_mask).to_dataframe().dropna()
404
+ )
405
+
406
+ # Condition 2: cost > 0 and media unit == 0
407
+ positive_cost_zero_mask = (cost_da > 0) & (media_units_da == 0)
408
+ positive_cost_zero_media_unit_df = (
409
+ cost_media_units_ds.where(positive_cost_zero_mask).to_dataframe().dropna()
410
+ )
411
+
412
+ return pd.concat(
413
+ [zero_cost_positive_media_unit_df, positive_cost_zero_media_unit_df]
414
+ )
415
+
416
+
417
+ def _check_cost_per_media_unit(
418
+ cost_ds: xr.Dataset,
419
+ media_units_ds: xr.Dataset,
420
+ level: eda_outcome.AnalysisLevel,
421
+ ) -> eda_outcome.EDAOutcome[eda_outcome.CostPerMediaUnitArtifact]:
422
+ """Helper to check if the cost per media unit is valid."""
423
+ findings = []
424
+ # Stack variables with the same dimension name, so that they can be operated
425
+ # on together.
426
+ cost_da = stack_variables(cost_ds, constants.CHANNEL).rename(constants.SPEND)
427
+ media_units_da = stack_variables(media_units_ds, constants.CHANNEL).rename(
428
+ constants.MEDIA_UNITS
429
+ )
430
+
431
+ cost_media_unit_inconsistency_df = _check_cost_media_unit_inconsistency(
432
+ cost_da,
433
+ media_units_da,
434
+ )
435
+ if not cost_media_unit_inconsistency_df.empty:
436
+ findings.append(
437
+ eda_outcome.EDAFinding(
438
+ severity=eda_outcome.EDASeverity.ATTENTION,
439
+ explanation=(
440
+ 'There are instances of inconsistent cost and media units.'
441
+ ' This occurs when cost is zero but media units are positive,'
442
+ ' or when cost is positive but media units are zero. Please'
443
+ ' review the outcome artifact for more details.'
444
+ ),
445
+ )
446
+ )
447
+
448
+ # Calculate cost per media unit
449
+ # Avoid division by zero by setting cost to NaN where media units are 0.
450
+ # Note that both (cost == media unit == 0) and (cost > 0 and media unit ==
451
+ # 0) result in NaN, while the latter one is not desired.
452
+ cost_per_media_unit_da = xr.where(
453
+ media_units_da == 0,
454
+ np.nan,
455
+ cost_da / media_units_da,
456
+ )
457
+ cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
458
+
459
+ outlier_df = _calculate_outliers(cost_per_media_unit_da)
460
+ if not outlier_df.empty:
461
+ findings.append(
462
+ eda_outcome.EDAFinding(
463
+ severity=eda_outcome.EDASeverity.ATTENTION,
464
+ explanation=(
465
+ 'There are outliers in cost per media unit across time.'
466
+ ' Please review the outcome artifact for more details.'
467
+ ),
468
+ )
469
+ )
470
+
471
+ # If no specific findings, add an INFO finding.
472
+ if not findings:
473
+ findings.append(
474
+ eda_outcome.EDAFinding(
475
+ severity=eda_outcome.EDASeverity.INFO,
476
+ explanation='Please review the cost per media unit data.',
477
+ )
478
+ )
479
+
480
+ artifact = eda_outcome.CostPerMediaUnitArtifact(
481
+ level=level,
482
+ cost_per_media_unit_da=cost_per_media_unit_da,
483
+ cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
484
+ outlier_df=outlier_df,
485
+ )
486
+
487
+ return eda_outcome.EDAOutcome(
488
+ check_type=eda_outcome.EDACheckType.COST_PER_MEDIA_UNIT,
489
+ findings=findings,
490
+ analysis_artifacts=[artifact],
491
+ )
492
+
493
+
335
494
  class EDAEngine:
336
495
  """Meridian EDA Engine."""
337
496
 
@@ -354,6 +513,7 @@ class EDAEngine:
354
513
 
355
514
  @functools.cached_property
356
515
  def controls_scaled_da(self) -> xr.DataArray | None:
516
+ """Returns the scaled controls data array."""
357
517
  if self._meridian.input_data.controls is None:
358
518
  return None
359
519
  controls_scaled_da = _data_array_like(
@@ -388,6 +548,7 @@ class EDAEngine:
388
548
 
389
549
  @functools.cached_property
390
550
  def media_raw_da(self) -> xr.DataArray | None:
551
+ """Returns the raw media data array."""
391
552
  if self._meridian.input_data.media is None:
392
553
  return None
393
554
  raw_media_da = self._truncate_media_time(self._meridian.input_data.media)
@@ -396,6 +557,7 @@ class EDAEngine:
396
557
 
397
558
  @functools.cached_property
398
559
  def media_scaled_da(self) -> xr.DataArray | None:
560
+ """Returns the scaled media data array."""
399
561
  if self._meridian.input_data.media is None:
400
562
  return None
401
563
  media_scaled_da = _data_array_like(
@@ -423,10 +585,11 @@ class EDAEngine:
423
585
  @functools.cached_property
424
586
  def national_media_spend_da(self) -> xr.DataArray | None:
425
587
  """Returns the national media spend data array."""
426
- if self.media_spend_da is None:
588
+ media_spend = self.media_spend_da
589
+ if media_spend is None:
427
590
  return None
428
591
  if self._is_national_data:
429
- national_da = self.media_spend_da.squeeze(constants.GEO, drop=True)
592
+ national_da = media_spend.squeeze(constants.GEO, drop=True)
430
593
  national_da.name = constants.NATIONAL_MEDIA_SPEND
431
594
  else:
432
595
  national_da = self._aggregate_and_scale_geo_da(
@@ -472,6 +635,7 @@ class EDAEngine:
472
635
 
473
636
  @functools.cached_property
474
637
  def organic_media_raw_da(self) -> xr.DataArray | None:
638
+ """Returns the raw organic media data array."""
475
639
  if self._meridian.input_data.organic_media is None:
476
640
  return None
477
641
  raw_organic_media_da = self._truncate_media_time(
@@ -482,6 +646,7 @@ class EDAEngine:
482
646
 
483
647
  @functools.cached_property
484
648
  def organic_media_scaled_da(self) -> xr.DataArray | None:
649
+ """Returns the scaled organic media data array."""
485
650
  if self._meridian.input_data.organic_media is None:
486
651
  return None
487
652
  organic_media_scaled_da = _data_array_like(
@@ -527,6 +692,7 @@ class EDAEngine:
527
692
 
528
693
  @functools.cached_property
529
694
  def non_media_scaled_da(self) -> xr.DataArray | None:
695
+ """Returns the scaled non-media treatments data array."""
530
696
  if self._meridian.input_data.non_media_treatments is None:
531
697
  return None
532
698
  non_media_scaled_da = _data_array_like(
@@ -576,10 +742,11 @@ class EDAEngine:
576
742
  @functools.cached_property
577
743
  def national_rf_spend_da(self) -> xr.DataArray | None:
578
744
  """Returns the national RF spend data array."""
579
- if self.rf_spend_da is None:
745
+ rf_spend = self.rf_spend_da
746
+ if rf_spend is None:
580
747
  return None
581
748
  if self._is_national_data:
582
- national_da = self.rf_spend_da.squeeze(constants.GEO, drop=True)
749
+ national_da = rf_spend.squeeze(constants.GEO, drop=True)
583
750
  national_da.name = constants.NATIONAL_RF_SPEND
584
751
  else:
585
752
  national_da = self._aggregate_and_scale_geo_da(
@@ -601,12 +768,14 @@ class EDAEngine:
601
768
 
602
769
  @property
603
770
  def reach_raw_da(self) -> xr.DataArray | None:
771
+ """Returns the raw reach data array."""
604
772
  if self._rf_data is None:
605
773
  return None
606
774
  return self._rf_data.reach_raw_da
607
775
 
608
776
  @property
609
777
  def reach_scaled_da(self) -> xr.DataArray | None:
778
+ """Returns the scaled reach data array."""
610
779
  if self._rf_data is None:
611
780
  return None
612
781
  return self._rf_data.reach_scaled_da # pytype: disable=attribute-error
@@ -627,6 +796,7 @@ class EDAEngine:
627
796
 
628
797
  @property
629
798
  def frequency_da(self) -> xr.DataArray | None:
799
+ """Returns the frequency data array."""
630
800
  if self._rf_data is None:
631
801
  return None
632
802
  return self._rf_data.frequency_da # pytype: disable=attribute-error
@@ -640,19 +810,21 @@ class EDAEngine:
640
810
 
641
811
  @property
642
812
  def rf_impressions_raw_da(self) -> xr.DataArray | None:
813
+ """Returns the raw RF impressions data array."""
643
814
  if self._rf_data is None:
644
815
  return None
645
- return self._rf_data.rf_impressions_raw_da
816
+ return self._rf_data.rf_impressions_raw_da # pytype: disable=attribute-error
646
817
 
647
818
  @property
648
819
  def national_rf_impressions_raw_da(self) -> xr.DataArray | None:
649
820
  """Returns the national raw RF impressions data array."""
650
821
  if self._rf_data is None:
651
822
  return None
652
- return self._rf_data.national_rf_impressions_raw_da
823
+ return self._rf_data.national_rf_impressions_raw_da # pytype: disable=attribute-error
653
824
 
654
825
  @property
655
826
  def rf_impressions_scaled_da(self) -> xr.DataArray | None:
827
+ """Returns the scaled RF impressions data array."""
656
828
  if self._rf_data is None:
657
829
  return None
658
830
  return self._rf_data.rf_impressions_scaled_da
@@ -676,12 +848,14 @@ class EDAEngine:
676
848
 
677
849
  @property
678
850
  def organic_reach_raw_da(self) -> xr.DataArray | None:
851
+ """Returns the raw organic reach data array."""
679
852
  if self._organic_rf_data is None:
680
853
  return None
681
854
  return self._organic_rf_data.reach_raw_da
682
855
 
683
856
  @property
684
857
  def organic_reach_scaled_da(self) -> xr.DataArray | None:
858
+ """Returns the scaled organic reach data array."""
685
859
  if self._organic_rf_data is None:
686
860
  return None
687
861
  return self._organic_rf_data.reach_scaled_da # pytype: disable=attribute-error
@@ -702,6 +876,7 @@ class EDAEngine:
702
876
 
703
877
  @property
704
878
  def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
879
+ """Returns the scaled organic RF impressions data array."""
705
880
  if self._organic_rf_data is None:
706
881
  return None
707
882
  return self._organic_rf_data.rf_impressions_scaled_da
@@ -715,6 +890,7 @@ class EDAEngine:
715
890
 
716
891
  @property
717
892
  def organic_frequency_da(self) -> xr.DataArray | None:
893
+ """Returns the organic frequency data array."""
718
894
  if self._organic_rf_data is None:
719
895
  return None
720
896
  return self._organic_rf_data.frequency_da # pytype: disable=attribute-error
@@ -728,6 +904,7 @@ class EDAEngine:
728
904
 
729
905
  @property
730
906
  def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
907
+ """Returns the raw organic RF impressions data array."""
731
908
  if self._organic_rf_data is None:
732
909
  return None
733
910
  return self._organic_rf_data.rf_impressions_raw_da
@@ -741,6 +918,7 @@ class EDAEngine:
741
918
 
742
919
  @functools.cached_property
743
920
  def geo_population_da(self) -> xr.DataArray | None:
921
+ """Returns the geo population data array."""
744
922
  if self._is_national_data:
745
923
  return None
746
924
  return xr.DataArray(
@@ -752,6 +930,7 @@ class EDAEngine:
752
930
 
753
931
  @functools.cached_property
754
932
  def kpi_scaled_da(self) -> xr.DataArray:
933
+ """Returns the scaled KPI data array."""
755
934
  scaled_kpi_da = _data_array_like(
756
935
  da=self._meridian.input_data.kpi,
757
936
  values=self._meridian.kpi_scaled,
@@ -806,10 +985,42 @@ class EDAEngine:
806
985
  ]
807
986
  return xr.merge(to_merge, join='inner')
808
987
 
988
+ @functools.cached_property
989
+ def all_spend_ds(self) -> xr.Dataset:
990
+ """Returns a Dataset containing all spend data.
991
+
992
+ This includes media spend and rf spend.
993
+ """
994
+ to_merge = [
995
+ da
996
+ for da in [
997
+ self.media_spend_da,
998
+ self.rf_spend_da,
999
+ ]
1000
+ if da is not None
1001
+ ]
1002
+ return xr.merge(to_merge, join='inner')
1003
+
1004
+ @functools.cached_property
1005
+ def national_all_spend_ds(self) -> xr.Dataset:
1006
+ """Returns a Dataset containing all national spend data.
1007
+
1008
+ This includes media spend and rf spend.
1009
+ """
1010
+ to_merge = [
1011
+ da
1012
+ for da in [
1013
+ self.national_media_spend_da,
1014
+ self.national_rf_spend_da,
1015
+ ]
1016
+ if da is not None
1017
+ ]
1018
+ return xr.merge(to_merge, join='inner')
1019
+
809
1020
  @functools.cached_property
810
1021
  def _stacked_treatment_control_scaled_da(self) -> xr.DataArray:
811
1022
  """Returns a stacked DataArray of treatment_control_scaled_ds."""
812
- da = _stack_variables(self.treatment_control_scaled_ds)
1023
+ da = stack_variables(self.treatment_control_scaled_ds)
813
1024
  da.name = constants.TREATMENT_CONTROL_SCALED
814
1025
  return da
815
1026
 
@@ -837,10 +1048,26 @@ class EDAEngine:
837
1048
  @functools.cached_property
838
1049
  def _stacked_national_treatment_control_scaled_da(self) -> xr.DataArray:
839
1050
  """Returns a stacked DataArray of national_treatment_control_scaled_ds."""
840
- da = _stack_variables(self.national_treatment_control_scaled_ds)
1051
+ da = stack_variables(self.national_treatment_control_scaled_ds)
841
1052
  da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
842
1053
  return da
843
1054
 
1055
+ @functools.cached_property
1056
+ def treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
1057
+ """Returns a Dataset of scaled treatments excluding non-media."""
1058
+ return self.treatment_control_scaled_ds.drop_dims(
1059
+ [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
1060
+ errors='ignore',
1061
+ )
1062
+
1063
+ @functools.cached_property
1064
+ def national_treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
1065
+ """Returns a Dataset of national scaled treatments excluding non-media."""
1066
+ return self.national_treatment_control_scaled_ds.drop_dims(
1067
+ [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
1068
+ errors='ignore',
1069
+ )
1070
+
844
1071
  @functools.cached_property
845
1072
  def all_reach_scaled_da(self) -> xr.DataArray | None:
846
1073
  """Returns a DataArray containing all scaled reach data.
@@ -947,6 +1174,30 @@ class EDAEngine:
947
1174
  da.name = constants.NATIONAL_ALL_FREQUENCY
948
1175
  return da
949
1176
 
1177
+ @functools.cached_property
1178
+ def paid_raw_media_units_ds(self) -> xr.Dataset:
1179
+ to_merge = [
1180
+ da
1181
+ for da in [
1182
+ self.media_raw_da,
1183
+ self.rf_impressions_raw_da,
1184
+ ]
1185
+ if da is not None
1186
+ ]
1187
+ return xr.merge(to_merge, join='inner')
1188
+
1189
+ @functools.cached_property
1190
+ def national_paid_raw_media_units_ds(self) -> xr.Dataset:
1191
+ to_merge = [
1192
+ da
1193
+ for da in [
1194
+ self.national_media_raw_da,
1195
+ self.national_rf_impressions_raw_da,
1196
+ ]
1197
+ if da is not None
1198
+ ]
1199
+ return xr.merge(to_merge, join='inner')
1200
+
950
1201
  @property
951
1202
  def _critical_checks(
952
1203
  self,
@@ -1384,8 +1635,8 @@ class EDAEngine:
1384
1635
  """
1385
1636
  if self._is_national_data:
1386
1637
  return self.check_national_pairwise_corr()
1387
- else:
1388
- return self.check_geo_pairwise_corr()
1638
+
1639
+ return self.check_geo_pairwise_corr()
1389
1640
 
1390
1641
  def _check_std(
1391
1642
  self,
@@ -1393,10 +1644,11 @@ class EDAEngine:
1393
1644
  level: eda_outcome.AnalysisLevel,
1394
1645
  zero_std_message: str,
1395
1646
  ) -> tuple[
1396
- Optional[eda_outcome.EDAFinding], eda_outcome.StandardDeviationArtifact
1647
+ eda_outcome.EDAFinding | None, eda_outcome.StandardDeviationArtifact
1397
1648
  ]:
1398
1649
  """Helper to check standard deviation."""
1399
- std_ds, outlier_df = _calculate_std(data)
1650
+ std_ds = _calculate_std(data)
1651
+ outlier_df = _calculate_outliers(data)
1400
1652
 
1401
1653
  finding = None
1402
1654
  if (std_ds[_STD_WITHOUT_OUTLIERS_VAR_NAME] < _STD_THRESHOLD).any():
@@ -1585,8 +1837,8 @@ class EDAEngine:
1585
1837
  """
1586
1838
  if self._is_national_data:
1587
1839
  return self.check_national_std()
1588
- else:
1589
- return self.check_geo_std()
1840
+
1841
+ return self.check_geo_std()
1590
1842
 
1591
1843
  def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
1592
1844
  """Computes geo-level variance inflation factor among treatments and controls."""
@@ -1737,8 +1989,8 @@ class EDAEngine:
1737
1989
  """
1738
1990
  if self._is_national_data:
1739
1991
  return self.check_national_vif()
1740
- else:
1741
- return self.check_geo_vif()
1992
+
1993
+ return self.check_geo_vif()
1742
1994
 
1743
1995
  @property
1744
1996
  def kpi_has_variability(self) -> bool:
@@ -1775,6 +2027,60 @@ class EDAEngine:
1775
2027
  analysis_artifacts=[self._overall_scaled_kpi_invariability_artifact],
1776
2028
  )
1777
2029
 
2030
+ def check_geo_cost_per_media_unit(
2031
+ self,
2032
+ ) -> eda_outcome.EDAOutcome[eda_outcome.CostPerMediaUnitArtifact]:
2033
+ """Checks if the cost per media unit is valid for geo data.
2034
+
2035
+ Returns:
2036
+ An EDAOutcome object with findings and result values.
2037
+
2038
+ Raises:
2039
+ GeoLevelCheckOnNationalModelError: If the check is called for a national
2040
+ model.
2041
+ """
2042
+ if self._is_national_data:
2043
+ raise GeoLevelCheckOnNationalModelError(
2044
+ 'check_geo_cost_per_media_unit is not supported for national models.'
2045
+ )
2046
+ return _check_cost_per_media_unit(
2047
+ self.all_spend_ds,
2048
+ self.paid_raw_media_units_ds,
2049
+ eda_outcome.AnalysisLevel.GEO,
2050
+ )
2051
+
2052
+ def check_national_cost_per_media_unit(
2053
+ self,
2054
+ ) -> eda_outcome.EDAOutcome[eda_outcome.CostPerMediaUnitArtifact]:
2055
+ """Checks if the cost per media unit is valid for national data.
2056
+
2057
+ Returns:
2058
+ An EDAOutcome object with findings and result values.
2059
+ """
2060
+ return _check_cost_per_media_unit(
2061
+ self.national_all_spend_ds,
2062
+ self.national_paid_raw_media_units_ds,
2063
+ eda_outcome.AnalysisLevel.NATIONAL,
2064
+ )
2065
+
2066
+ def check_cost_per_media_unit(
2067
+ self,
2068
+ ) -> eda_outcome.EDAOutcome[eda_outcome.CostPerMediaUnitArtifact]:
2069
+ """Checks if the cost per media unit is valid.
2070
+
2071
+ This function checks the following conditions:
2072
+ 1. cost == 0 and media unit > 0.
2073
+ 2. cost > 0 and media unit == 0.
2074
+ 3. cost_per_media_unit has outliers.
2075
+
2076
+ Returns:
2077
+ An EDAOutcome object with findings and result values.
2078
+ """
2079
+ if self._is_national_data:
2080
+ return self.check_national_cost_per_media_unit()
2081
+
2082
+ return self.check_geo_cost_per_media_unit()
2083
+
1778
2084
  def run_all_critical_checks(self) -> list[eda_outcome.EDAOutcome]:
1779
2085
  """Runs all critical EDA checks.
1780
2086
 
@@ -1790,7 +2096,9 @@ class EDAEngine:
1790
2096
  except Exception as e: # pylint: disable=broad-except
1791
2097
  error_finding = eda_outcome.EDAFinding(
1792
2098
  severity=eda_outcome.EDASeverity.ERROR,
1793
- explanation=f'An error occurred during check {check.__name__}: {e}',
2099
+ explanation=(
2100
+ f'An error occurred during running {check.__name__}: {e!r}'
2101
+ ),
1794
2102
  )
1795
2103
  outcomes.append(
1796
2104
  eda_outcome.EDAOutcome(