google-meridian 1.3.1__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/METADATA +7 -7
  2. {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/RECORD +35 -35
  3. meridian/analysis/__init__.py +1 -2
  4. meridian/analysis/analyzer.py +0 -1
  5. meridian/analysis/optimizer.py +5 -3
  6. meridian/analysis/review/checks.py +81 -30
  7. meridian/analysis/review/constants.py +4 -0
  8. meridian/analysis/review/results.py +40 -9
  9. meridian/analysis/summarizer.py +1 -1
  10. meridian/analysis/visualizer.py +1 -1
  11. meridian/backend/__init__.py +53 -5
  12. meridian/backend/test_utils.py +72 -0
  13. meridian/constants.py +1 -0
  14. meridian/data/load.py +2 -0
  15. meridian/model/eda/__init__.py +0 -1
  16. meridian/model/eda/constants.py +12 -2
  17. meridian/model/eda/eda_engine.py +299 -37
  18. meridian/model/eda/eda_outcome.py +21 -1
  19. meridian/model/knots.py +17 -0
  20. meridian/{analysis/templates → templates}/card.html.jinja +1 -1
  21. meridian/{analysis/templates → templates}/chart.html.jinja +1 -1
  22. meridian/{analysis/templates → templates}/chips.html.jinja +1 -1
  23. meridian/{analysis → templates}/formatter.py +12 -1
  24. meridian/templates/formatter_test.py +216 -0
  25. meridian/{analysis/templates → templates}/insights.html.jinja +1 -1
  26. meridian/{analysis/templates → templates}/stats.html.jinja +1 -1
  27. meridian/{analysis/templates → templates}/style.css +1 -1
  28. meridian/{analysis/templates → templates}/style.scss +1 -1
  29. meridian/{analysis/templates → templates}/summary.html.jinja +4 -2
  30. meridian/{analysis/templates → templates}/table.html.jinja +1 -1
  31. meridian/version.py +1 -1
  32. schema/__init__.py +12 -0
  33. meridian/model/eda/meridian_eda.py +0 -220
  34. {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/WHEEL +0 -0
  35. {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/licenses/LICENSE +0 -0
  36. {google_meridian-1.3.1.dist-info → google_meridian-1.3.2.dist-info}/top_level.txt +0 -0
@@ -48,9 +48,6 @@ _CORRELATION_MATRIX_NAME = 'correlation_matrix'
48
48
  _OVERALL_PAIRWISE_CORR_THRESHOLD = 0.999
49
49
  _GEO_PAIRWISE_CORR_THRESHOLD = 0.999
50
50
  _NATIONAL_PAIRWISE_CORR_THRESHOLD = 0.999
51
- _EMPTY_DF_FOR_EXTREME_CORR_PAIRS = pd.DataFrame(
52
- columns=[_CORR_VAR1, _CORR_VAR2, _CORRELATION_COL_NAME]
53
- )
54
51
  _Q1_THRESHOLD = 0.25
55
52
  _Q3_THRESHOLD = 0.75
56
53
  _IQR_MULTIPLIER = 1.5
@@ -261,37 +258,57 @@ def _find_extreme_corr_pairs(
261
258
  corr_tri = _get_upper_triangle_corr_mat(extreme_corr_da)
262
259
  extreme_corr_da = corr_tri.where(abs(corr_tri) > extreme_corr_threshold)
263
260
 
264
- df = extreme_corr_da.to_dataframe(name=_CORRELATION_COL_NAME).dropna()
265
- if df.empty:
266
- return _EMPTY_DF_FOR_EXTREME_CORR_PAIRS.copy()
267
- return df.sort_values(
268
- by=_CORRELATION_COL_NAME, ascending=False, inplace=False
261
+ return (
262
+ extreme_corr_da.to_dataframe(name=_CORRELATION_COL_NAME)
263
+ .dropna()
264
+ .assign(**{
265
+ eda_constants.ABS_CORRELATION_COL_NAME: (
266
+ lambda x: x[_CORRELATION_COL_NAME].abs()
267
+ )
268
+ })
269
+ .sort_values(
270
+ by=eda_constants.ABS_CORRELATION_COL_NAME,
271
+ ascending=False,
272
+ inplace=False,
273
+ )
269
274
  )
270
275
 
271
276
 
272
- def _calculate_std(
277
+ def _get_outlier_bounds(
273
278
  input_da: xr.DataArray,
274
- ) -> tuple[xr.Dataset, pd.DataFrame]:
275
- """Helper function to compute std with and without outliers.
279
+ ) -> tuple[xr.DataArray, xr.DataArray]:
280
+ """Computes lower and upper bounds for outliers across time using the IQR method.
276
281
 
277
282
  Args:
278
- input_da: A DataArray for which to calculate the std.
283
+ input_da: A DataArray for which to calculate outlier bounds.
279
284
 
280
285
  Returns:
281
- A tuple where the first element is a Dataset with two data variables:
282
- 'std_incl_outliers' and 'std_excl_outliers'. The second element is a
283
- DataFrame with columns for variables, geo (if applicable), time, and
284
- outlier values.
286
+ A tuple containing the lower and upper bounds of outliers as DataArrays.
285
287
  """
286
- std_with_outliers = input_da.std(dim=constants.TIME, ddof=1)
287
-
288
288
  # TODO: Allow users to specify custom outlier definitions.
289
289
  q1 = input_da.quantile(_Q1_THRESHOLD, dim=constants.TIME)
290
290
  q3 = input_da.quantile(_Q3_THRESHOLD, dim=constants.TIME)
291
291
  iqr = q3 - q1
292
292
  lower_bound = q1 - _IQR_MULTIPLIER * iqr
293
293
  upper_bound = q3 + _IQR_MULTIPLIER * iqr
294
+ return lower_bound, upper_bound
295
+
296
+
297
+ def _calculate_std(
298
+ input_da: xr.DataArray,
299
+ ) -> xr.Dataset:
300
+ """Helper function to compute std with and without outliers.
301
+
302
+ Args:
303
+ input_da: A DataArray for which to calculate the std.
304
+
305
+ Returns:
306
+ A Dataset with two data variables: 'std_with_outliers' and
307
+ 'std_without_outliers'.
308
+ """
309
+ std_with_outliers = input_da.std(dim=constants.TIME, ddof=1)
294
310
 
311
+ lower_bound, upper_bound = _get_outlier_bounds(input_da)
295
312
  da_no_outlier = input_da.where(
296
313
  (input_da >= lower_bound) & (input_da <= upper_bound)
297
314
  )
@@ -301,17 +318,34 @@ def _calculate_std(
301
318
  _STD_WITH_OUTLIERS_VAR_NAME: std_with_outliers,
302
319
  _STD_WITHOUT_OUTLIERS_VAR_NAME: std_without_outliers,
303
320
  })
321
+ return std_ds
322
+
323
+
324
+ def _calculate_outliers(
325
+ input_da: xr.DataArray,
326
+ ) -> pd.DataFrame:
327
+ """Helper function to extract outliers from a DataArray across time.
328
+
329
+ Args:
330
+ input_da: A DataArray from which to extract outliers.
304
331
 
332
+ Returns:
333
+ A DataFrame with columns for variables, geo (if applicable), time, and
334
+ outlier values.
335
+ """
336
+ lower_bound, upper_bound = _get_outlier_bounds(input_da)
305
337
  outlier_da = input_da.where(
306
338
  (input_da < lower_bound) | (input_da > upper_bound)
307
339
  )
308
-
309
- outlier_df = outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME).dropna()
310
- outlier_df = outlier_df.assign(
311
- **{_ABS_OUTLIERS_COL_NAME: np.abs(outlier_df[_OUTLIERS_COL_NAME])}
312
- ).sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
313
-
314
- return std_ds, outlier_df
340
+ outlier_df = (
341
+ outlier_da.to_dataframe(name=_OUTLIERS_COL_NAME)
342
+ .dropna()
343
+ .assign(
344
+ **{_ABS_OUTLIERS_COL_NAME: lambda x: np.abs(x[_OUTLIERS_COL_NAME])}
345
+ )
346
+ .sort_values(by=_ABS_OUTLIERS_COL_NAME, ascending=False, inplace=False)
347
+ )
348
+ return outlier_df
315
349
 
316
350
 
317
351
  def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
@@ -327,7 +361,9 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
327
361
  """
328
362
  num_vars = input_da.sizes[var_dim]
329
363
  np_data = input_da.values.reshape(-1, num_vars)
330
- np_data_with_const = sm.add_constant(np_data, prepend=True)
364
+ np_data_with_const = sm.add_constant(
365
+ np_data, prepend=True, has_constant='add'
366
+ )
331
367
 
332
368
  # Compute VIF for each variable excluding const which is the first one in the
333
369
  # 'variable' dimension.
@@ -344,6 +380,117 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
344
380
  return vif_da
345
381
 
346
382
 
383
+ def _check_cost_media_unit_inconsistency(
384
+ cost_da: xr.DataArray,
385
+ media_units_da: xr.DataArray,
386
+ ) -> pd.DataFrame:
387
+ """Checks for inconsistencies between cost and media units.
388
+
389
+ Args:
390
+ cost_da: DataArray containing cost data.
391
+ media_units_da: DataArray containing media unit data.
392
+
393
+ Returns:
394
+ A DataFrame of inconsistencies where either cost is zero and media units
395
+ are
396
+ positive, or cost is positive and media units are zero.
397
+ """
398
+ cost_media_units_ds = xr.merge([cost_da, media_units_da])
399
+
400
+ # Condition 1: cost == 0 and media unit > 0
401
+ zero_cost_positive_mask = (cost_da == 0) & (media_units_da > 0)
402
+ zero_cost_positive_media_unit_df = (
403
+ cost_media_units_ds.where(zero_cost_positive_mask).to_dataframe().dropna()
404
+ )
405
+
406
+ # Condition 2: cost > 0 and media unit == 0
407
+ positive_cost_zero_mask = (cost_da > 0) & (media_units_da == 0)
408
+ positive_cost_zero_media_unit_df = (
409
+ cost_media_units_ds.where(positive_cost_zero_mask).to_dataframe().dropna()
410
+ )
411
+
412
+ return pd.concat(
413
+ [zero_cost_positive_media_unit_df, positive_cost_zero_media_unit_df]
414
+ )
415
+
416
+
417
+ def _check_cost_per_media_unit(
418
+ cost_ds: xr.Dataset,
419
+ media_units_ds: xr.Dataset,
420
+ level: eda_outcome.AnalysisLevel,
421
+ ) -> eda_outcome.EDAOutcome[eda_outcome.CostPerMediaUnitArtifact]:
422
+ """Helper to check if the cost per media unit is valid."""
423
+ findings = []
424
+ # Stack variables with the same dimension name, so that they can be operated
425
+ # on together.
426
+ cost_da = stack_variables(cost_ds, constants.CHANNEL).rename(constants.SPEND)
427
+ media_units_da = stack_variables(media_units_ds, constants.CHANNEL).rename(
428
+ constants.MEDIA_UNITS
429
+ )
430
+
431
+ cost_media_unit_inconsistency_df = _check_cost_media_unit_inconsistency(
432
+ cost_da,
433
+ media_units_da,
434
+ )
435
+ if not cost_media_unit_inconsistency_df.empty:
436
+ findings.append(
437
+ eda_outcome.EDAFinding(
438
+ severity=eda_outcome.EDASeverity.ATTENTION,
439
+ explanation=(
440
+ 'There are instances of inconsistent cost and media units.'
441
+ ' This occurs when cost is zero but media units are positive,'
442
+ ' or when cost is positive but media units are zero. Please'
443
+ ' review the outcome artifact for more details.'
444
+ ),
445
+ )
446
+ )
447
+
448
+ # Calculate cost per media unit
449
+ # Avoid division by zero by setting cost to NaN where media units are 0.
450
+ # Note that both (cost == media unit == 0) and (cost > 0 and media unit ==
451
+ # 0) result in NaN, while the latter one is not desired.
452
+ cost_per_media_unit_da = xr.where(
453
+ media_units_da == 0,
454
+ np.nan,
455
+ cost_da / media_units_da,
456
+ )
457
+ cost_per_media_unit_da.name = eda_constants.COST_PER_MEDIA_UNIT
458
+
459
+ outlier_df = _calculate_outliers(cost_per_media_unit_da)
460
+ if not outlier_df.empty:
461
+ findings.append(
462
+ eda_outcome.EDAFinding(
463
+ severity=eda_outcome.EDASeverity.ATTENTION,
464
+ explanation=(
465
+ 'There are outliers in cost per media unit across time.'
466
+ ' Please review the outcome artifact for more details.'
467
+ ),
468
+ )
469
+ )
470
+
471
+ # If no specific findings, add an INFO finding.
472
+ if not findings:
473
+ findings.append(
474
+ eda_outcome.EDAFinding(
475
+ severity=eda_outcome.EDASeverity.INFO,
476
+ explanation='Please review the cost per media unit data.',
477
+ )
478
+ )
479
+
480
+ artifact = eda_outcome.CostPerMediaUnitArtifact(
481
+ level=level,
482
+ cost_per_media_unit_da=cost_per_media_unit_da,
483
+ cost_media_unit_inconsistency_df=cost_media_unit_inconsistency_df,
484
+ outlier_df=outlier_df,
485
+ )
486
+
487
+ return eda_outcome.EDAOutcome(
488
+ check_type=eda_outcome.EDACheckType.COST_PER_MEDIA_UNIT,
489
+ findings=findings,
490
+ analysis_artifacts=[artifact],
491
+ )
492
+
493
+
347
494
  class EDAEngine:
348
495
  """Meridian EDA Engine."""
349
496
 
@@ -366,6 +513,7 @@ class EDAEngine:
366
513
 
367
514
  @functools.cached_property
368
515
  def controls_scaled_da(self) -> xr.DataArray | None:
516
+ """Returns the scaled controls data array."""
369
517
  if self._meridian.input_data.controls is None:
370
518
  return None
371
519
  controls_scaled_da = _data_array_like(
@@ -400,6 +548,7 @@ class EDAEngine:
400
548
 
401
549
  @functools.cached_property
402
550
  def media_raw_da(self) -> xr.DataArray | None:
551
+ """Returns the raw media data array."""
403
552
  if self._meridian.input_data.media is None:
404
553
  return None
405
554
  raw_media_da = self._truncate_media_time(self._meridian.input_data.media)
@@ -408,6 +557,7 @@ class EDAEngine:
408
557
 
409
558
  @functools.cached_property
410
559
  def media_scaled_da(self) -> xr.DataArray | None:
560
+ """Returns the scaled media data array."""
411
561
  if self._meridian.input_data.media is None:
412
562
  return None
413
563
  media_scaled_da = _data_array_like(
@@ -485,6 +635,7 @@ class EDAEngine:
485
635
 
486
636
  @functools.cached_property
487
637
  def organic_media_raw_da(self) -> xr.DataArray | None:
638
+ """Returns the raw organic media data array."""
488
639
  if self._meridian.input_data.organic_media is None:
489
640
  return None
490
641
  raw_organic_media_da = self._truncate_media_time(
@@ -495,6 +646,7 @@ class EDAEngine:
495
646
 
496
647
  @functools.cached_property
497
648
  def organic_media_scaled_da(self) -> xr.DataArray | None:
649
+ """Returns the scaled organic media data array."""
498
650
  if self._meridian.input_data.organic_media is None:
499
651
  return None
500
652
  organic_media_scaled_da = _data_array_like(
@@ -540,6 +692,7 @@ class EDAEngine:
540
692
 
541
693
  @functools.cached_property
542
694
  def non_media_scaled_da(self) -> xr.DataArray | None:
695
+ """Returns the scaled non-media treatments data array."""
543
696
  if self._meridian.input_data.non_media_treatments is None:
544
697
  return None
545
698
  non_media_scaled_da = _data_array_like(
@@ -615,12 +768,14 @@ class EDAEngine:
615
768
 
616
769
  @property
617
770
  def reach_raw_da(self) -> xr.DataArray | None:
771
+ """Returns the raw reach data array."""
618
772
  if self._rf_data is None:
619
773
  return None
620
774
  return self._rf_data.reach_raw_da
621
775
 
622
776
  @property
623
777
  def reach_scaled_da(self) -> xr.DataArray | None:
778
+ """Returns the scaled reach data array."""
624
779
  if self._rf_data is None:
625
780
  return None
626
781
  return self._rf_data.reach_scaled_da # pytype: disable=attribute-error
@@ -641,6 +796,7 @@ class EDAEngine:
641
796
 
642
797
  @property
643
798
  def frequency_da(self) -> xr.DataArray | None:
799
+ """Returns the frequency data array."""
644
800
  if self._rf_data is None:
645
801
  return None
646
802
  return self._rf_data.frequency_da # pytype: disable=attribute-error
@@ -654,19 +810,21 @@ class EDAEngine:
654
810
 
655
811
  @property
656
812
  def rf_impressions_raw_da(self) -> xr.DataArray | None:
813
+ """Returns the raw RF impressions data array."""
657
814
  if self._rf_data is None:
658
815
  return None
659
- return self._rf_data.rf_impressions_raw_da
816
+ return self._rf_data.rf_impressions_raw_da # pytype: disable=attribute-error
660
817
 
661
818
  @property
662
819
  def national_rf_impressions_raw_da(self) -> xr.DataArray | None:
663
820
  """Returns the national raw RF impressions data array."""
664
821
  if self._rf_data is None:
665
822
  return None
666
- return self._rf_data.national_rf_impressions_raw_da
823
+ return self._rf_data.national_rf_impressions_raw_da # pytype: disable=attribute-error
667
824
 
668
825
  @property
669
826
  def rf_impressions_scaled_da(self) -> xr.DataArray | None:
827
+ """Returns the scaled RF impressions data array."""
670
828
  if self._rf_data is None:
671
829
  return None
672
830
  return self._rf_data.rf_impressions_scaled_da
@@ -690,12 +848,14 @@ class EDAEngine:
690
848
 
691
849
  @property
692
850
  def organic_reach_raw_da(self) -> xr.DataArray | None:
851
+ """Returns the raw organic reach data array."""
693
852
  if self._organic_rf_data is None:
694
853
  return None
695
854
  return self._organic_rf_data.reach_raw_da
696
855
 
697
856
  @property
698
857
  def organic_reach_scaled_da(self) -> xr.DataArray | None:
858
+ """Returns the scaled organic reach data array."""
699
859
  if self._organic_rf_data is None:
700
860
  return None
701
861
  return self._organic_rf_data.reach_scaled_da # pytype: disable=attribute-error
@@ -716,6 +876,7 @@ class EDAEngine:
716
876
 
717
877
  @property
718
878
  def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
879
+ """Returns the scaled organic RF impressions data array."""
719
880
  if self._organic_rf_data is None:
720
881
  return None
721
882
  return self._organic_rf_data.rf_impressions_scaled_da
@@ -729,6 +890,7 @@ class EDAEngine:
729
890
 
730
891
  @property
731
892
  def organic_frequency_da(self) -> xr.DataArray | None:
893
+ """Returns the organic frequency data array."""
732
894
  if self._organic_rf_data is None:
733
895
  return None
734
896
  return self._organic_rf_data.frequency_da # pytype: disable=attribute-error
@@ -742,6 +904,7 @@ class EDAEngine:
742
904
 
743
905
  @property
744
906
  def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
907
+ """Returns the raw organic RF impressions data array."""
745
908
  if self._organic_rf_data is None:
746
909
  return None
747
910
  return self._organic_rf_data.rf_impressions_raw_da
@@ -755,6 +918,7 @@ class EDAEngine:
755
918
 
756
919
  @functools.cached_property
757
920
  def geo_population_da(self) -> xr.DataArray | None:
921
+ """Returns the geo population data array."""
758
922
  if self._is_national_data:
759
923
  return None
760
924
  return xr.DataArray(
@@ -766,6 +930,7 @@ class EDAEngine:
766
930
 
767
931
  @functools.cached_property
768
932
  def kpi_scaled_da(self) -> xr.DataArray:
933
+ """Returns the scaled KPI data array."""
769
934
  scaled_kpi_da = _data_array_like(
770
935
  da=self._meridian.input_data.kpi,
771
936
  values=self._meridian.kpi_scaled,
@@ -887,6 +1052,22 @@ class EDAEngine:
887
1052
  da.name = constants.NATIONAL_TREATMENT_CONTROL_SCALED
888
1053
  return da
889
1054
 
1055
+ @functools.cached_property
1056
+ def treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
1057
+ """Returns a Dataset of scaled treatments excluding non-media."""
1058
+ return self.treatment_control_scaled_ds.drop_dims(
1059
+ [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
1060
+ errors='ignore',
1061
+ )
1062
+
1063
+ @functools.cached_property
1064
+ def national_treatments_without_non_media_scaled_ds(self) -> xr.Dataset:
1065
+ """Returns a Dataset of national scaled treatments excluding non-media."""
1066
+ return self.national_treatment_control_scaled_ds.drop_dims(
1067
+ [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
1068
+ errors='ignore',
1069
+ )
1070
+
890
1071
  @functools.cached_property
891
1072
  def all_reach_scaled_da(self) -> xr.DataArray | None:
892
1073
  """Returns a DataArray containing all scaled reach data.
@@ -993,6 +1174,30 @@ class EDAEngine:
993
1174
  da.name = constants.NATIONAL_ALL_FREQUENCY
994
1175
  return da
995
1176
 
1177
+ @functools.cached_property
1178
+ def paid_raw_media_units_ds(self) -> xr.Dataset:
1179
+ to_merge = [
1180
+ da
1181
+ for da in [
1182
+ self.media_raw_da,
1183
+ self.rf_impressions_raw_da,
1184
+ ]
1185
+ if da is not None
1186
+ ]
1187
+ return xr.merge(to_merge, join='inner')
1188
+
1189
+ @functools.cached_property
1190
+ def national_paid_raw_media_units_ds(self) -> xr.Dataset:
1191
+ to_merge = [
1192
+ da
1193
+ for da in [
1194
+ self.national_media_raw_da,
1195
+ self.national_rf_impressions_raw_da,
1196
+ ]
1197
+ if da is not None
1198
+ ]
1199
+ return xr.merge(to_merge, join='inner')
1200
+
996
1201
  @property
997
1202
  def _critical_checks(
998
1203
  self,
@@ -1430,8 +1635,8 @@ class EDAEngine:
1430
1635
  """
1431
1636
  if self._is_national_data:
1432
1637
  return self.check_national_pairwise_corr()
1433
- else:
1434
- return self.check_geo_pairwise_corr()
1638
+
1639
+ return self.check_geo_pairwise_corr()
1435
1640
 
1436
1641
  def _check_std(
1437
1642
  self,
@@ -1439,10 +1644,11 @@ class EDAEngine:
1439
1644
  level: eda_outcome.AnalysisLevel,
1440
1645
  zero_std_message: str,
1441
1646
  ) -> tuple[
1442
- Optional[eda_outcome.EDAFinding], eda_outcome.StandardDeviationArtifact
1647
+ eda_outcome.EDAFinding | None, eda_outcome.StandardDeviationArtifact
1443
1648
  ]:
1444
1649
  """Helper to check standard deviation."""
1445
- std_ds, outlier_df = _calculate_std(data)
1650
+ std_ds = _calculate_std(data)
1651
+ outlier_df = _calculate_outliers(data)
1446
1652
 
1447
1653
  finding = None
1448
1654
  if (std_ds[_STD_WITHOUT_OUTLIERS_VAR_NAME] < _STD_THRESHOLD).any():
@@ -1631,8 +1837,8 @@ class EDAEngine:
1631
1837
  """
1632
1838
  if self._is_national_data:
1633
1839
  return self.check_national_std()
1634
- else:
1635
- return self.check_geo_std()
1840
+
1841
+ return self.check_geo_std()
1636
1842
 
1637
1843
  def check_geo_vif(self) -> eda_outcome.EDAOutcome[eda_outcome.VIFArtifact]:
1638
1844
  """Computes geo-level variance inflation factor among treatments and controls."""
@@ -1783,8 +1989,8 @@ class EDAEngine:
1783
1989
  """
1784
1990
  if self._is_national_data:
1785
1991
  return self.check_national_vif()
1786
- else:
1787
- return self.check_geo_vif()
1992
+
1993
+ return self.check_geo_vif()
1788
1994
 
1789
1995
  @property
1790
1996
  def kpi_has_variability(self) -> bool:
@@ -1821,6 +2027,60 @@ class EDAEngine:
1821
2027
  analysis_artifacts=[self._overall_scaled_kpi_invariability_artifact],
1822
2028
  )
1823
2029
 
2030
+ def check_geo_cost_per_media_unit(
2031
+ self,
2032
+ ) -> eda_outcome.EDAOutcome[eda_outcome.CostPerMediaUnitArtifact]:
2033
+ """Checks if the cost per media unit is valid for geo data.
2034
+
2035
+ Returns:
2036
+ An EDAOutcome object with findings and result values.
2037
+
2038
+ Raises:
2039
+ GeoLevelCheckOnNationalModelError: If the check is called for a national
2040
+ model.
2041
+ """
2042
+ if self._is_national_data:
2043
+ raise GeoLevelCheckOnNationalModelError(
2044
+ 'check_geo_cost_per_media_unit is not supported for national models.'
2045
+ )
2046
+ return _check_cost_per_media_unit(
2047
+ self.all_spend_ds,
2048
+ self.paid_raw_media_units_ds,
2049
+ eda_outcome.AnalysisLevel.GEO,
2050
+ )
2051
+
2052
+ def check_national_cost_per_media_unit(
2053
+ self,
2054
+ ) -> eda_outcome.EDAOutcome[eda_outcome.CostPerMediaUnitArtifact]:
2055
+ """Checks if the cost per media unit is valid for national data.
2056
+
2057
+ Returns:
2058
+ An EDAOutcome object with findings and result values.
2059
+ """
2060
+ return _check_cost_per_media_unit(
2061
+ self.national_all_spend_ds,
2062
+ self.national_paid_raw_media_units_ds,
2063
+ eda_outcome.AnalysisLevel.NATIONAL,
2064
+ )
2065
+
2066
+ def check_cost_per_media_unit(
2067
+ self,
2068
+ ) -> eda_outcome.EDAOutcome[eda_outcome.CostPerMediaUnitArtifact]:
2069
+ """Checks if the cost per media unit is valid.
2070
+
2071
+ This function checks the following conditions:
2072
+ 1. cost == 0 and media unit > 0.
2073
+ 2. cost > 0 and media unit == 0.
2074
+ 3. cost_per_media_unit has outliers.
2075
+
2076
+ Returns:
2077
+ An EDAOutcome object with findings and result values.
2078
+ """
2079
+ if self._is_national_data:
2080
+ return self.check_national_cost_per_media_unit()
2081
+
2082
+ return self.check_geo_cost_per_media_unit()
2083
+
1824
2084
  def run_all_critical_checks(self) -> list[eda_outcome.EDAOutcome]:
1825
2085
  """Runs all critical EDA checks.
1826
2086
 
@@ -1836,7 +2096,9 @@ class EDAEngine:
1836
2096
  except Exception as e: # pylint: disable=broad-except
1837
2097
  error_finding = eda_outcome.EDAFinding(
1838
2098
  severity=eda_outcome.EDASeverity.ERROR,
1839
- explanation=f'An error occurred during check {check.__name__}: {e}',
2099
+ explanation=(
2100
+ f'An error occurred during running {check.__name__}: {e!r}'
2101
+ ),
1840
2102
  )
1841
2103
  outcomes.append(
1842
2104
  eda_outcome.EDAOutcome(
@@ -29,6 +29,7 @@ __all__ = [
29
29
  "StandardDeviationArtifact",
30
30
  "VIFArtifact",
31
31
  "KpiInvariabilityArtifact",
32
+ "CostPerMediaUnitArtifact",
32
33
  "EDACheckType",
33
34
  "ArtifactType",
34
35
  "EDAOutcome",
@@ -101,7 +102,8 @@ class PairwiseCorrArtifact(AnalysisArtifact):
101
102
  Attributes:
102
103
  corr_matrix: Pairwise correlation matrix.
103
104
  extreme_corr_var_pairs: DataFrame of variable pairs exceeding the
104
- correlation threshold.
105
+ correlation threshold. Includes 'correlation' and 'abs_correlation'
106
+ columns, and is sorted by 'abs_correlation' in descending order.
105
107
  extreme_corr_threshold: The threshold used to identify extreme correlation
106
108
  pairs.
107
109
  """
@@ -153,6 +155,23 @@ class KpiInvariabilityArtifact(AnalysisArtifact):
153
155
  kpi_stdev: xr.DataArray
154
156
 
155
157
 
158
+ @dataclasses.dataclass(frozen=True)
159
+ class CostPerMediaUnitArtifact(AnalysisArtifact):
160
+ """Encapsulates artifacts from a Cost per Media Unit analysis.
161
+
162
+ Attributes:
163
+ cost_per_media_unit_da: DataArray of cost per media unit.
164
+ cost_media_unit_inconsistency_df: DataFrame of time periods where cost and
165
+ media units are inconsistent (e.g., zero cost with positive media units,
166
+ or positive cost with zero media units).
167
+ outlier_df: DataFrame with outliers of cost per media unit.
168
+ """
169
+
170
+ cost_per_media_unit_da: xr.DataArray
171
+ cost_media_unit_inconsistency_df: pd.DataFrame
172
+ outlier_df: pd.DataFrame
173
+
174
+
156
175
  @enum.unique
157
176
  class EDACheckType(enum.Enum):
158
177
  """Enumeration for the type of an EDA check."""
@@ -161,6 +180,7 @@ class EDACheckType(enum.Enum):
161
180
  STANDARD_DEVIATION = enum.auto()
162
181
  MULTICOLLINEARITY = enum.auto()
163
182
  KPI_INVARIABILITY = enum.auto()
183
+ COST_PER_MEDIA_UNIT = enum.auto()
164
184
 
165
185
 
166
186
  ArtifactType = typing.TypeVar("ArtifactType", bound="AnalysisArtifact")
meridian/model/knots.py CHANGED
@@ -19,6 +19,7 @@ from collections.abc import Collection, Sequence
19
19
  import copy
20
20
  import dataclasses
21
21
  import math
22
+ import pprint
22
23
  from typing import Any
23
24
  from meridian import constants
24
25
  from meridian.data import input_data
@@ -289,6 +290,22 @@ class AKS:
289
290
  penalty = geo_scaling_factor * base_penalty
290
291
 
291
292
  aspline = self.aspline(x=x, y=y, knots=knots, penalty=penalty)
293
+ # Ensure defined knot range covers at least one of the available knot sets.
294
+ available_knots_lengths = np.unique(
295
+ np.fromiter(
296
+ (len(x) for x in aspline[constants.KNOTS_SELECTED]), dtype=int
297
+ )
298
+ ).tolist()
299
+ if not any(
300
+ min_internal_knots <= k <= max_internal_knots
301
+ for k in available_knots_lengths
302
+ ):
303
+ raise ValueError(
304
+ f'The range [{min_internal_knots}, {max_internal_knots}] does not'
305
+ ' contain any of the available knot lengths:'
306
+ f' {pprint.pformat(available_knots_lengths)}'
307
+ )
308
+
292
309
  n_knots = np.array([len(x) for x in aspline[constants.KNOTS_SELECTED]])
293
310
  feasible_idx = np.where(
294
311
  (n_knots >= min_internal_knots) & (n_knots <= max_internal_knots)
@@ -1,5 +1,5 @@
1
1
  {#
2
- Copyright 2024 Google LLC
2
+ Copyright 2025 Google LLC
3
3
 
4
4
  Licensed under the Apache License, Version 2.0 (the "License");
5
5
  you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
1
1
  {#
2
- Copyright 2024 Google LLC
2
+ Copyright 2025 Google LLC
3
3
 
4
4
  Licensed under the Apache License, Version 2.0 (the "License");
5
5
  you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
1
1
  {#
2
- Copyright 2024 Google LLC
2
+ Copyright 2025 Google LLC
3
3
 
4
4
  Licensed under the Apache License, Version 2.0 (the "License");
5
5
  you may not use this file except in compliance with the License.
@@ -88,7 +88,7 @@ AXIS_CONFIG = immutabledict.immutabledict({
88
88
 
89
89
 
90
90
  _template_loader = jinja2.FileSystemLoader(
91
- os.path.abspath(os.path.dirname(__file__)) + '/templates'
91
+ os.path.abspath(os.path.dirname(__file__))
92
92
  )
93
93
 
94
94
 
@@ -206,6 +206,17 @@ def create_template_env() -> jinja2.Environment:
206
206
  )
207
207
 
208
208
 
209
+ def create_summary_html(
210
+ template_env: jinja2.Environment,
211
+ title: str,
212
+ cards: Sequence[str],
213
+ ) -> str:
214
+ """Creates the HTML snippet for the summary page."""
215
+ return template_env.get_template('summary.html.jinja').render(
216
+ title=title, cards=cards
217
+ )
218
+
219
+
209
220
  def create_card_html(
210
221
  template_env: jinja2.Environment,
211
222
  card_spec: CardSpec,