google-meridian: google_meridian-1.2.0-py3-none-any.whl → google_meridian-1.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
  2. google_meridian-1.3.0.dist-info/RECORD +62 -0
  3. meridian/analysis/__init__.py +2 -0
  4. meridian/analysis/analyzer.py +280 -142
  5. meridian/analysis/formatter.py +2 -2
  6. meridian/analysis/optimizer.py +353 -169
  7. meridian/analysis/review/__init__.py +20 -0
  8. meridian/analysis/review/checks.py +721 -0
  9. meridian/analysis/review/configs.py +110 -0
  10. meridian/analysis/review/constants.py +40 -0
  11. meridian/analysis/review/results.py +544 -0
  12. meridian/analysis/review/reviewer.py +186 -0
  13. meridian/analysis/summarizer.py +14 -12
  14. meridian/analysis/templates/chips.html.jinja +12 -0
  15. meridian/analysis/test_utils.py +27 -5
  16. meridian/analysis/visualizer.py +45 -50
  17. meridian/backend/__init__.py +698 -55
  18. meridian/backend/config.py +75 -16
  19. meridian/backend/test_utils.py +127 -1
  20. meridian/constants.py +52 -11
  21. meridian/data/input_data.py +7 -2
  22. meridian/data/test_utils.py +5 -3
  23. meridian/mlflow/autolog.py +2 -2
  24. meridian/model/__init__.py +1 -0
  25. meridian/model/adstock_hill.py +10 -9
  26. meridian/model/eda/__init__.py +3 -0
  27. meridian/model/eda/constants.py +21 -0
  28. meridian/model/eda/eda_engine.py +1580 -84
  29. meridian/model/eda/eda_outcome.py +200 -0
  30. meridian/model/eda/eda_spec.py +84 -0
  31. meridian/model/eda/meridian_eda.py +220 -0
  32. meridian/model/knots.py +56 -50
  33. meridian/model/media.py +10 -8
  34. meridian/model/model.py +79 -16
  35. meridian/model/model_test_data.py +53 -9
  36. meridian/model/posterior_sampler.py +398 -391
  37. meridian/model/prior_distribution.py +114 -39
  38. meridian/model/prior_sampler.py +146 -90
  39. meridian/model/spec.py +7 -8
  40. meridian/model/transformers.py +16 -8
  41. meridian/version.py +1 -1
  42. google_meridian-1.2.0.dist-info/RECORD +0 -52
  43. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
  44. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
  45. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,186 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Implementation of the runner of the Model Quality Checks."""
16
+
17
+ import typing
18
+
19
+ import immutabledict
20
+ from meridian import constants
21
+ from meridian.analysis import analyzer as analyzer_module
22
+ from meridian.analysis.review import checks
23
+ from meridian.analysis.review import configs
24
+ from meridian.analysis.review import results
25
+ from meridian.model import prior_distribution
26
+
27
+
28
+ CheckType = typing.Type[checks.BaseCheck]
29
+ ConfigInstance = configs.BaseConfig
30
+ ChecksBattery = immutabledict.immutabledict[CheckType, ConfigInstance]
31
+
32
+ _DEFAULT_POST_CONVERGENCE_CHECKS: ChecksBattery = immutabledict.immutabledict({
33
+ checks.BaselineCheck: configs.BaselineConfig(),
34
+ checks.BayesianPPPCheck: configs.BayesianPPPConfig(),
35
+ checks.GoodnessOfFitCheck: configs.GoodnessOfFitConfig(),
36
+ checks.PriorPosteriorShiftCheck: configs.PriorPosteriorShiftConfig(),
37
+ checks.ROIConsistencyCheck: configs.ROIConsistencyConfig(),
38
+ })
39
+
40
+
41
+ class ModelReviewer:
42
+ """Executes a series of quality checks on a Meridian model.
43
+
44
+ The reviewer first runs a convergence check. If the model has converged, it
45
+ proceeds to run a battery of post-convergence checks.
46
+
47
+ The default battery of post-convergence checks includes:
48
+ - BaselineCheck
49
+ - BayesianPPPCheck
50
+ - GoodnessOfFitCheck
51
+ - PriorPosteriorShiftCheck
52
+ - ROIConsistencyCheck
53
+ Each with its default configuration.
54
+
55
+ This battery of checks can be customized by passing a dictionary to the
56
+ `post_convergence_checks` argument of the constructor, mapping check
57
+ classes to their configuration instances. For example, to run only the
58
+ BaselineCheck with a non-default configuration:
59
+
60
+ ```python
61
+ my_checks = {
62
+ checks.BaselineCheck: configs.BaselineConfig(
63
+ negative_baseline_prob_review_threshold=0.1,
64
+ negative_baseline_prob_fail_threshold=0.5,
65
+ )
66
+ }
67
+ reviewer = ModelReviewer(meridian_model, post_convergence_checks=my_checks)
68
+ ```
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ meridian,
74
+ post_convergence_checks: ChecksBattery = _DEFAULT_POST_CONVERGENCE_CHECKS,
75
+ ):
76
+ self._meridian = meridian
77
+ self._results: list[results.CheckResult] = []
78
+ self._analyzer = analyzer_module.Analyzer(meridian)
79
+ self._post_convergence_checks = post_convergence_checks
80
+
81
+ def _run_and_handle(self, check_class, config):
82
+ instance = check_class(self._meridian, self._analyzer, config) # pytype: disable=not-instantiable
83
+ self._results.append(instance.run())
84
+
85
+ def _uses_roi_priors(self):
86
+ """Checks if the model uses ROI priors."""
87
+ return (
88
+ self._meridian.n_media_channels > 0
89
+ and self._meridian.model_spec.effective_media_prior_type
90
+ == constants.TREATMENT_PRIOR_TYPE_ROI
91
+ ) or (
92
+ self._meridian.n_rf_channels > 0
93
+ and self._meridian.model_spec.effective_rf_prior_type
94
+ == constants.TREATMENT_PRIOR_TYPE_ROI
95
+ )
96
+
97
+ def _has_custom_roi_priors(self):
98
+ """Checks if the model uses custom ROI priors."""
99
+ default_distribution = prior_distribution.PriorDistribution()
100
+ if (
101
+ self._meridian.n_media_channels > 0
102
+ and self._meridian.model_spec.effective_media_prior_type
103
+ == constants.TREATMENT_PRIOR_TYPE_ROI
104
+ ):
105
+ if not prior_distribution.distributions_are_equal(
106
+ self._meridian.model_spec.prior.roi_m, default_distribution.roi_m
107
+ ):
108
+ return True
109
+ if (
110
+ self._meridian.n_rf_channels > 0
111
+ and self._meridian.model_spec.effective_rf_prior_type
112
+ == constants.TREATMENT_PRIOR_TYPE_ROI
113
+ ):
114
+ if not prior_distribution.distributions_are_equal(
115
+ self._meridian.model_spec.prior.roi_rf, default_distribution.roi_rf
116
+ ):
117
+ return True
118
+ return False
119
+
120
+ def run(self) -> results.ReviewSummary:
121
+ """Executes all checks and generates the final summary."""
122
+ self._results.clear()
123
+ self._run_and_handle(checks.ConvergenceCheck, configs.ConvergenceConfig())
124
+
125
+ # Stop if the model did not converge.
126
+ if (
127
+ self._results
128
+ and self._results[0].case is results.ConvergenceCases.NOT_CONVERGED
129
+ ):
130
+ return results.ReviewSummary(
131
+ overall_status=results.Status.FAIL,
132
+ summary_message=(
133
+ "Failed: Model did not converge. Other checks were skipped."
134
+ ),
135
+ results=self._results,
136
+ )
137
+
138
+ # Run all other checks in sequence.
139
+ for check_class, config in self._post_convergence_checks.items():
140
+ if (
141
+ check_class == checks.PriorPosteriorShiftCheck
142
+ and not self._uses_roi_priors()
143
+ ):
144
+ # Skip the Prior-Posterior Shift check if no ROI priors are used.
145
+ continue
146
+ if (
147
+ check_class == checks.ROIConsistencyCheck
148
+ and not self._has_custom_roi_priors()
149
+ ):
150
+ # Skip the ROI Consistency check if no custom ROI priors are provided.
151
+ continue
152
+ self._run_and_handle(check_class, config)
153
+
154
+ # Determine the final overall status.
155
+ has_failures = any(
156
+ res.case.status is results.Status.FAIL for res in self._results
157
+ )
158
+ has_reviews = any(
159
+ res.case.status is results.Status.REVIEW for res in self._results
160
+ )
161
+
162
+ if has_failures and has_reviews:
163
+ overall_status = results.Status.FAIL
164
+ summary_message = (
165
+ "Failed: Quality issues were detected in your model. Follow"
166
+ " recommendations to address any failed checks and review"
167
+ " results to determine if further action is needed."
168
+ )
169
+ elif has_failures:
170
+ overall_status = results.Status.FAIL
171
+ summary_message = (
172
+ "Failed: Quality issues were detected in your model. Address failed"
173
+ " checks before proceeding."
174
+ )
175
+ elif has_reviews:
176
+ overall_status = results.Status.PASS
177
+ summary_message = "Passed with reviews: Review is needed."
178
+ else:
179
+ overall_status = results.Status.PASS
180
+ summary_message = "Passed: No major quality issues were identified."
181
+
182
+ return results.ReviewSummary(
183
+ overall_status=overall_status,
184
+ summary_message=summary_message,
185
+ results=self._results,
186
+ )
@@ -20,6 +20,7 @@ import os
20
20
 
21
21
  import jinja2
22
22
  from meridian import constants as c
23
+ from meridian.analysis import analyzer
23
24
  from meridian.analysis import formatter
24
25
  from meridian.analysis import summary_text
25
26
  from meridian.analysis import visualizer
@@ -59,17 +60,18 @@ RESPONSE_CURVES_CARD_SPEC = formatter.CardSpec(
59
60
  class Summarizer:
60
61
  """Generates HTML summary visualizations from the model fitting."""
61
62
 
62
- def __init__(self, meridian: model.Meridian):
63
+ def __init__(self, meridian: model.Meridian, use_kpi: bool = False):
63
64
  """Initialize the visualizer classes that are not time-dependent."""
64
65
  self._meridian = meridian
66
+ self._use_kpi = analyzer.Analyzer(meridian)._use_kpi(use_kpi)
65
67
 
66
68
  @functools.cached_property
67
69
  def _model_fit(self):
68
- return visualizer.ModelFit(self._meridian)
70
+ return visualizer.ModelFit(self._meridian, use_kpi=self._use_kpi)
69
71
 
70
72
  @functools.cached_property
71
73
  def _model_diagnostics(self):
72
- return visualizer.ModelDiagnostics(self._meridian)
74
+ return visualizer.ModelDiagnostics(self._meridian, use_kpi=self._use_kpi)
73
75
 
74
76
  def output_model_results_summary(
75
77
  self,
@@ -153,12 +155,14 @@ class Summarizer:
153
155
  ) -> Sequence[str]:
154
156
  """Creates the HTML snippets for cards in the summary page."""
155
157
  media_summary = visualizer.MediaSummary(
156
- self._meridian, selected_times=selected_times
158
+ self._meridian, selected_times=selected_times, use_kpi=self._use_kpi
159
+ )
160
+ media_effects = visualizer.MediaEffects(
161
+ self._meridian, use_kpi=self._use_kpi
157
162
  )
158
- media_effects = visualizer.MediaEffects(self._meridian)
159
163
  reach_frequency = (
160
164
  visualizer.ReachAndFrequency(
161
- self._meridian, selected_times=selected_times
165
+ self._meridian, selected_times=selected_times, use_kpi=self._use_kpi
162
166
  )
163
167
  if self._meridian.n_rf_channels > 0
164
168
  else None
@@ -168,7 +172,9 @@ class Summarizer:
168
172
  template_env, selected_times=selected_times
169
173
  ),
170
174
  self._create_outcome_contrib_card_html(
171
- template_env, media_summary, selected_times=selected_times
175
+ template_env,
176
+ media_summary,
177
+ selected_times=selected_times,
172
178
  ),
173
179
  self._create_performance_breakdown_card_html(
174
180
  template_env, media_summary
@@ -525,8 +531,4 @@ class Summarizer:
525
531
  ).optimal_frequency
526
532
 
527
533
  def _kpi_or_revenue(self) -> str:
528
- if self._meridian.input_data.revenue_per_kpi is not None:
529
- outcome_str = c.REVENUE
530
- else:
531
- outcome_str = c.KPI.upper()
532
- return outcome_str
534
+ return c.KPI.upper() if self._use_kpi else c.REVENUE
@@ -18,4 +18,16 @@ limitations under the License.
18
18
  <chip>
19
19
  Time period: {{ start_date }} - {{ end_date }}
20
20
  </chip>
21
+ {% if selected_geos %}
22
+ {%- set geo_count = selected_geos | length %}
23
+ {%- set full_list = selected_geos | join(', ') %}
24
+ <chip title="Selected: {{ full_list }}">
25
+ Selected geos:
26
+ {%- if geo_count > 5 %}
27
+ {{ selected_geos[:5] | join(', ') }}, + {{ geo_count - 5 }} more
28
+ {%- else %}
29
+ {{ full_list }}
30
+ {%- endif %}
31
+ </chip>
32
+ {% endif %}
21
33
  </chips>
@@ -726,8 +726,6 @@ INC_OUTCOME_NON_PAID_USE_PRIOR = np.array([[
726
726
  2.224e05,
727
727
  ],
728
728
  ]])
729
-
730
-
731
729
  INC_OUTCOME_NON_PAID_USE_POSTERIOR = np.array([
732
730
  [
733
731
  [
@@ -1055,7 +1053,6 @@ INC_OUTCOME_NON_PAID_USE_POSTERIOR = np.array([
1055
1053
  ],
1056
1054
  ])
1057
1055
 
1058
-
1059
1056
  INC_OUTCOME_NON_MEDIA_MAX = np.array([
1060
1057
  [
1061
1058
  [
@@ -1282,7 +1279,6 @@ INC_OUTCOME_NON_MEDIA_MAX = np.array([
1282
1279
  ],
1283
1280
  ],
1284
1281
  ])
1285
-
1286
1282
  INC_OUTCOME_NON_MEDIA_MIX = np.array([
1287
1283
  [
1288
1284
  [
@@ -1509,7 +1505,6 @@ INC_OUTCOME_NON_MEDIA_MIX = np.array([
1509
1505
  ],
1510
1506
  ],
1511
1507
  ])
1512
-
1513
1508
  INC_OUTCOME_NON_MEDIA_FIXED = np.array([
1514
1509
  [
1515
1510
  [
@@ -1737,6 +1732,33 @@ INC_OUTCOME_NON_MEDIA_FIXED = np.array([
1737
1732
  ],
1738
1733
  ])
1739
1734
 
1735
+ EXP_OUTCOME_MEDIA_AND_RF = np.array([
1736
+ [
1737
+ 32254.154,
1738
+ 32296.66,
1739
+ 32331.918,
1740
+ 32391.336,
1741
+ 32797.367,
1742
+ 33077.375,
1743
+ 33469.82,
1744
+ 33316.707,
1745
+ 33251.33,
1746
+ 33261.21,
1747
+ ],
1748
+ [
1749
+ 53146.668,
1750
+ 53151.16,
1751
+ 53404.375,
1752
+ 53814.035,
1753
+ 54136.48,
1754
+ 54141.82,
1755
+ 54581.93,
1756
+ 54777.164,
1757
+ 54912.758,
1758
+ 54915.656,
1759
+ ],
1760
+ ])
1761
+
1740
1762
  MROI_MEDIA_AND_RF_USE_PRIOR = np.array([[
1741
1763
  [0.3399, 1.7045, 3.1300, 2.7845, 0.3523],
1742
1764
  [0.4540, 1.3445, 0.9966, 0.2985, 0.4917],
@@ -46,9 +46,10 @@ alt.data_transformers.disable_max_rows()
46
46
  class ModelDiagnostics:
47
47
  """Generates model diagnostics plots from the Meridian model fitting."""
48
48
 
49
- def __init__(self, meridian: model.Meridian):
49
+ def __init__(self, meridian: model.Meridian, use_kpi: bool = False):
50
50
  self._meridian = meridian
51
51
  self._analyzer = analyzer.Analyzer(meridian)
52
+ self._use_kpi = self._analyzer._use_kpi(use_kpi)
52
53
 
53
54
  @functools.lru_cache(maxsize=128)
54
55
  def _predictive_accuracy_dataset(
@@ -82,6 +83,7 @@ class ModelDiagnostics:
82
83
  return self._analyzer.predictive_accuracy(
83
84
  selected_geos=selected_geos_list,
84
85
  selected_times=selected_times_list,
86
+ use_kpi=self._use_kpi,
85
87
  batch_size=batch_size,
86
88
  )
87
89
 
@@ -366,19 +368,23 @@ class ModelFit:
366
368
  def __init__(
367
369
  self,
368
370
  meridian: model.Meridian,
371
+ use_kpi: bool = False,
369
372
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL,
370
373
  ):
371
374
  """Initializes the dataset based on the model and confidence level.
372
375
 
373
376
  Args:
374
377
  meridian: Media mix model with the raw data from the model fitting.
378
+ use_kpi: If `True`, plots the incremental KPI. Otherwise, plots the
379
+ incremental revenue using the revenue per KPI (if available).
375
380
  confidence_level: Confidence level for expected outcome credible intervals
376
381
  represented as a value between zero and one. Default is `0.9`.
377
382
  """
378
383
  self._meridian = meridian
379
384
  self._analyzer = analyzer.Analyzer(meridian)
385
+ self._use_kpi = self._analyzer._use_kpi(use_kpi)
380
386
  self._model_fit_data = self._analyzer.expected_vs_actual_data(
381
- confidence_level=confidence_level
387
+ use_kpi=self._use_kpi, confidence_level=confidence_level
382
388
  )
383
389
 
384
390
  @property
@@ -430,11 +436,7 @@ class ModelFit:
430
436
  Returns:
431
437
  An Altair plot showing the model fit.
432
438
  """
433
- outcome = (
434
- c.REVENUE
435
- if self._meridian.input_data.revenue_per_kpi is not None
436
- else c.KPI.upper()
437
- )
439
+ outcome = c.KPI.upper() if self._use_kpi else c.REVENUE
438
440
  self._validate_times_to_plot(selected_times)
439
441
  self._validate_geos_to_plot(
440
442
  selected_geos, n_top_largest_geos, show_geo_level
@@ -459,10 +461,10 @@ class ModelFit:
459
461
  title = summary_text.EXPECTED_ACTUAL_OUTCOME_CHART_TITLE.format(
460
462
  outcome=outcome
461
463
  )
462
- if self._meridian.input_data.revenue_per_kpi is not None:
463
- y_axis_label = summary_text.REVENUE_LABEL
464
- else:
464
+ if self._use_kpi:
465
465
  y_axis_label = summary_text.KPI_LABEL
466
+ else:
467
+ y_axis_label = summary_text.REVENUE_LABEL
466
468
  plot = (
467
469
  alt.Chart(model_fit_df, width=c.VEGALITE_FACET_EXTRA_LARGE_WIDTH)
468
470
  .mark_line()
@@ -638,7 +640,7 @@ class ReachAndFrequency:
638
640
  self,
639
641
  meridian: model.Meridian,
640
642
  selected_times: Sequence[str] | None = None,
641
- use_kpi: bool | None = None,
643
+ use_kpi: bool = False,
642
644
  ):
643
645
  """Initializes the reach and frequency dataset for the model data.
644
646
 
@@ -651,15 +653,7 @@ class ReachAndFrequency:
651
653
  self._meridian = meridian
652
654
  self._analyzer = analyzer.Analyzer(meridian)
653
655
  self._selected_times = selected_times
654
- # TODO Adapt the mechanisms to choose between KPI and REVENUE
655
- # from Analyzer.
656
- if use_kpi is None:
657
- self._use_kpi = (
658
- meridian.input_data.kpi_type == c.NON_REVENUE
659
- and meridian.input_data.revenue_per_kpi is None
660
- )
661
- else:
662
- self._use_kpi = use_kpi
656
+ self._use_kpi = self._analyzer._use_kpi(use_kpi)
663
657
  self._optimal_frequency_data = self._analyzer.optimal_freq(
664
658
  selected_times=selected_times,
665
659
  use_kpi=self._use_kpi,
@@ -844,6 +838,7 @@ class MediaEffects:
844
838
  self,
845
839
  meridian: model.Meridian,
846
840
  by_reach: bool = True,
841
+ use_kpi: bool = False,
847
842
  ):
848
843
  """Initializes the Media Effects based on the model data and params.
849
844
 
@@ -852,10 +847,13 @@ class MediaEffects:
852
847
  by_reach: For the channel w/ reach and frequency, return the response
853
848
  curves by reach given fixed frequency if true; return the response
854
849
  curves by frequency given fixed reach if false.
850
+ use_kpi: If `True`, calculate the incremental KPI. Otherwise, calculate
851
+ the incremental revenue using the revenue per KPI (if available).
855
852
  """
856
853
  self._meridian = meridian
857
854
  self._analyzer = analyzer.Analyzer(meridian)
858
855
  self._by_reach = by_reach
856
+ self._use_kpi = self._analyzer._use_kpi(use_kpi)
859
857
 
860
858
  @functools.lru_cache(maxsize=128)
861
859
  def response_curves_data(
@@ -891,13 +889,12 @@ class MediaEffects:
891
889
  A Dataset displaying the response curves data.
892
890
  """
893
891
  selected_times_list = list(selected_times) if selected_times else None
894
- use_kpi = self._meridian.input_data.revenue_per_kpi is None
895
892
  return self._analyzer.response_curves(
896
893
  spend_multipliers=list(np.arange(0, 2.2, c.RESPONSE_CURVE_STEP_SIZE)),
897
894
  confidence_level=confidence_level,
898
895
  selected_times=selected_times_list,
899
896
  by_reach=by_reach,
900
- use_kpi=use_kpi,
897
+ use_kpi=self._use_kpi,
901
898
  )
902
899
 
903
900
  @functools.lru_cache(maxsize=128)
@@ -1022,10 +1019,11 @@ class MediaEffects:
1022
1019
  selected_times=selected_times,
1023
1020
  by_reach=by_reach,
1024
1021
  )
1025
- if self._meridian.input_data.revenue_per_kpi is not None:
1026
- y_axis_label = summary_text.INC_OUTCOME_LABEL
1027
- else:
1028
- y_axis_label = summary_text.INC_KPI_LABEL
1022
+ y_axis_label = (
1023
+ summary_text.INC_KPI_LABEL
1024
+ if self._use_kpi
1025
+ else summary_text.INC_OUTCOME_LABEL
1026
+ )
1029
1027
  base = (
1030
1028
  alt.Chart(response_curves_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
1031
1029
  .transform_calculate(
@@ -1404,6 +1402,7 @@ class MediaSummary:
1404
1402
  selected_times: Sequence[str] | None = None,
1405
1403
  marginal_roi_by_reach: bool = True,
1406
1404
  non_media_baseline_values: Sequence[float] | None = None,
1405
+ use_kpi: bool = False,
1407
1406
  ):
1408
1407
  """Initializes the media summary metrics based on the model data and params.
1409
1408
 
@@ -1423,6 +1422,7 @@ class MediaSummary:
1423
1422
  value which will be used as baseline for the given channel. If `None`,
1424
1423
  the values defined with `ModelSpec.non_media_baseline_values` will be
1425
1424
  used.
1425
+ use_kpi: If `True`, use KPI instead of revenue.
1426
1426
  """
1427
1427
  self._meridian = meridian
1428
1428
  self._analyzer = analyzer.Analyzer(meridian)
@@ -1430,6 +1430,7 @@ class MediaSummary:
1430
1430
  self._selected_times = selected_times
1431
1431
  self._marginal_roi_by_reach = marginal_roi_by_reach
1432
1432
  self._non_media_baseline_values = non_media_baseline_values
1433
+ self._use_kpi = self._analyzer._use_kpi(use_kpi)
1433
1434
 
1434
1435
  @property
1435
1436
  def paid_summary_metrics(self):
@@ -1464,7 +1465,7 @@ class MediaSummary:
1464
1465
  return self._analyzer.summary_metrics(
1465
1466
  selected_times=self._selected_times,
1466
1467
  marginal_roi_by_reach=self._marginal_roi_by_reach,
1467
- use_kpi=self._meridian.input_data.revenue_per_kpi is None,
1468
+ use_kpi=self._use_kpi,
1468
1469
  confidence_level=self._confidence_level,
1469
1470
  include_non_paid_channels=False,
1470
1471
  aggregate_times=aggregate_times,
@@ -1497,7 +1498,7 @@ class MediaSummary:
1497
1498
  """
1498
1499
  return self._analyzer.summary_metrics(
1499
1500
  selected_times=self._selected_times,
1500
- use_kpi=self._meridian.input_data.revenue_per_kpi is None,
1501
+ use_kpi=self._use_kpi,
1501
1502
  confidence_level=self._confidence_level,
1502
1503
  include_non_paid_channels=True,
1503
1504
  non_media_baseline_values=self._non_media_baseline_values,
@@ -1509,6 +1510,7 @@ class MediaSummary:
1509
1510
  include_prior: bool = True,
1510
1511
  include_posterior: bool = True,
1511
1512
  include_non_paid_channels: bool = False,
1513
+ currency: str = c.DEFAULT_CURRENCY,
1512
1514
  ) -> pd.DataFrame:
1513
1515
  """Returns a formatted dataframe table of the summary metrics.
1514
1516
 
@@ -1525,6 +1527,7 @@ class MediaSummary:
1525
1527
  reported. If `False`, only the paid channels (media, reach and
1526
1528
  frequency) are included but the summary contains also the metrics
1527
1529
  dependent on spend. Default: `False`.
1530
+ currency: The currency to use for the monetary values. Default: `'$'`.
1528
1531
 
1529
1532
  Returns:
1530
1533
  pandas.DataFrame of formatted summary metrics.
@@ -1534,7 +1537,7 @@ class MediaSummary:
1534
1537
  'At least one of `include_posterior` or `include_prior` must be True.'
1535
1538
  )
1536
1539
 
1537
- use_revenue = self._meridian.input_data.revenue_per_kpi is not None
1540
+ use_revenue = not self._use_kpi
1538
1541
  distribution = [c.PRIOR] * include_prior + [c.POSTERIOR] * include_posterior
1539
1542
 
1540
1543
  percentage_metrics = [
@@ -1607,7 +1610,7 @@ class MediaSummary:
1607
1610
  # Format monetary values.
1608
1611
  for k in monetary_metrics:
1609
1612
  if k in df.columns:
1610
- df[k] = '$' + df[k].astype(str)
1613
+ df[k] = currency + df[k].astype(str)
1611
1614
 
1612
1615
  # Format the model result data variables as central_tendency (ci_lo, ci_hi).
1613
1616
  index_vars = [c.CHANNEL, c.DISTRIBUTION]
@@ -1720,11 +1723,7 @@ class MediaSummary:
1720
1723
  ),
1721
1724
  y=alt.Y(
1722
1725
  f'{c.INCREMENTAL_OUTCOME}:Q',
1723
- title=(
1724
- c.REVENUE.title()
1725
- if self._meridian.input_data.revenue_per_kpi is not None
1726
- else c.KPI.upper()
1727
- ),
1726
+ title=(c.KPI.upper() if self._use_kpi else c.REVENUE.title()),
1728
1727
  axis=alt.Axis(
1729
1728
  ticks=False,
1730
1729
  domain=False,
@@ -1890,11 +1889,7 @@ class MediaSummary:
1890
1889
  Returns:
1891
1890
  An Altair plot showing the contributions per channel.
1892
1891
  """
1893
- outcome = (
1894
- c.REVENUE.title()
1895
- if self._meridian.input_data.revenue_per_kpi is not None
1896
- else c.KPI.upper()
1897
- )
1892
+ outcome = c.KPI.upper() if self._use_kpi else c.REVENUE.title()
1898
1893
  outcome_df = self.contribution_metrics(include_non_paid=True)
1899
1894
  pct = c.PCT_OF_CONTRIBUTION
1900
1895
  value = c.INCREMENTAL_OUTCOME
@@ -1907,7 +1902,7 @@ class MediaSummary:
1907
1902
  num_channels = len(outcome_df[c.CHANNEL])
1908
1903
 
1909
1904
  base = (
1910
- alt.Chart(outcome_df, width=c.VEGALITE_FACET_LARGE_WIDTH)
1905
+ alt.Chart(outcome_df)
1911
1906
  .transform_window(
1912
1907
  sum_outcome=f'sum({c.PCT_OF_CONTRIBUTION})',
1913
1908
  kwargs=f'lead({c.CHANNEL})',
@@ -1923,7 +1918,10 @@ class MediaSummary:
1923
1918
  y=alt.Y(
1924
1919
  f'{c.CHANNEL}:N',
1925
1920
  axis=alt.Axis(
1926
- ticks=False, labelPadding=c.PADDING_10, domain=False
1921
+ ticks=False,
1922
+ labelPadding=c.PADDING_10,
1923
+ domain=False,
1924
+ labelLimit=0,
1927
1925
  ),
1928
1926
  title=None,
1929
1927
  sort=None,
@@ -1966,6 +1964,7 @@ class MediaSummary:
1966
1964
  title=formatter.custom_title_params(
1967
1965
  summary_text.CHANNEL_DRIVERS_CHART_TITLE
1968
1966
  ),
1967
+ width=c.VEGALITE_FACET_LARGE_WIDTH,
1969
1968
  height=c.BAR_SIZE * num_channels
1970
1969
  + c.BAR_SIZE * 2 * c.SCALED_PADDING,
1971
1970
  )
@@ -2028,11 +2027,7 @@ class MediaSummary:
2028
2027
  Returns:
2029
2028
  An Altair plot showing the spend versus outcome percentages per channel.
2030
2029
  """
2031
- outcome = (
2032
- c.REVENUE
2033
- if self._meridian.input_data.revenue_per_kpi is not None
2034
- else c.KPI.upper()
2035
- )
2030
+ outcome = c.KPI.upper() if self._use_kpi else c.REVENUE
2036
2031
  df = self._transform_contribution_spend_metrics()
2037
2032
  domain = [
2038
2033
  f'% {outcome.title() if outcome == c.REVENUE else outcome}',
@@ -2556,10 +2551,10 @@ class MediaSummary:
2556
2551
  A dataframe of spend and outcome percentages and ROI per channel.
2557
2552
  """
2558
2553
  paid_summary_metrics = self.get_paid_summary_metrics()
2559
- if self._meridian.input_data.revenue_per_kpi is not None:
2560
- outcome = summary_text.REVENUE_LABEL
2561
- else:
2554
+ if self._use_kpi:
2562
2555
  outcome = summary_text.KPI_LABEL
2556
+ else:
2557
+ outcome = summary_text.REVENUE_LABEL
2563
2558
  total_media_outcome = (
2564
2559
  paid_summary_metrics[c.INCREMENTAL_OUTCOME]
2565
2560
  .sel(