google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -77,6 +77,7 @@ class Summarizer:
77
77
  filepath: str,
78
78
  start_date: tc.Date = None,
79
79
  end_date: tc.Date = None,
80
+ use_kpi: bool = False,
80
81
  ):
81
82
  """Generates and saves the HTML results summary output.
82
83
 
@@ -86,15 +87,18 @@ class Summarizer:
86
87
  start_date: Optional start date selector, *inclusive*, in _yyyy-mm-dd_
87
88
  format.
88
89
  end_date: Optional end date selector, *inclusive* in _yyyy-mm-dd_ format.
90
+ use_kpi: If `True`, calculate the incremental KPI. Otherwise, calculate
91
+ the incremental revenue using the revenue per KPI (if available).
89
92
  """
90
93
  os.makedirs(filepath, exist_ok=True)
91
94
  with open(os.path.join(filepath, filename), 'w') as f:
92
- f.write(self._gen_model_results_summary(start_date, end_date))
95
+ f.write(self._gen_model_results_summary(start_date, end_date, use_kpi))
93
96
 
94
97
  def _gen_model_results_summary(
95
98
  self,
96
99
  start_date: tc.Date = None,
97
100
  end_date: tc.Date = None,
101
+ use_kpi: bool = False,
98
102
  ) -> str:
99
103
  """Generate HTML results summary output (as sanitized content str)."""
100
104
  all_dates = self._meridian.input_data.time_coordinates.all_dates
@@ -140,6 +144,7 @@ class Summarizer:
140
144
  cards_htmls = self._create_cards_htmls(
141
145
  template_env,
142
146
  selected_times=selected_times,
147
+ use_kpi=use_kpi,
143
148
  )
144
149
 
145
150
  return html_template.render(
@@ -150,6 +155,7 @@ class Summarizer:
150
155
  self,
151
156
  template_env: jinja2.Environment,
152
157
  selected_times: Sequence[str] | None,
158
+ use_kpi: bool,
153
159
  ) -> Sequence[str]:
154
160
  """Creates the HTML snippets for cards in the summary page."""
155
161
  media_summary = visualizer.MediaSummary(
@@ -165,10 +171,13 @@ class Summarizer:
165
171
  )
166
172
  cards = [
167
173
  self._create_model_fit_card_html(
168
- template_env, selected_times=selected_times
174
+ template_env, selected_times=selected_times, use_kpi=use_kpi
169
175
  ),
170
176
  self._create_outcome_contrib_card_html(
171
- template_env, media_summary, selected_times=selected_times
177
+ template_env,
178
+ media_summary,
179
+ selected_times=selected_times,
180
+ use_kpi=use_kpi,
172
181
  ),
173
182
  self._create_performance_breakdown_card_html(
174
183
  template_env, media_summary
@@ -179,16 +188,17 @@ class Summarizer:
179
188
  media_summary=media_summary,
180
189
  media_effects=media_effects,
181
190
  reach_frequency=reach_frequency,
191
+ use_kpi=use_kpi,
182
192
  ),
183
193
  ]
184
194
  return cards
185
195
 
186
196
  def _create_model_fit_card_html(
187
- self, template_env: jinja2.Environment, **kwargs
197
+ self, template_env: jinja2.Environment, use_kpi: bool, **kwargs
188
198
  ) -> str:
189
199
  """Creates the HTML snippet for the Model Fit card."""
190
200
  model_fit = self._model_fit
191
- outcome = self._kpi_or_revenue()
201
+ outcome = self._kpi_or_revenue(use_kpi)
192
202
  expected_actual_outcome_chart = formatter.ChartSpec(
193
203
  id=summary_text.EXPECTED_ACTUAL_OUTCOME_CHART_ID,
194
204
  description=summary_text.EXPECTED_ACTUAL_OUTCOME_CHART_DESCRIPTION_FORMAT.format(
@@ -197,7 +207,9 @@ class Summarizer:
197
207
  chart_json=model_fit.plot_model_fit(**kwargs).to_json(),
198
208
  )
199
209
 
200
- predictive_accuracy_table = self._predictive_accuracy_table_spec(**kwargs)
210
+ predictive_accuracy_table = self._predictive_accuracy_table_spec(
211
+ use_kpi=use_kpi, **kwargs
212
+ )
201
213
  insights = summary_text.MODEL_FIT_INSIGHTS_FORMAT
202
214
 
203
215
  return formatter.create_card_html(
@@ -207,9 +219,11 @@ class Summarizer:
207
219
  [expected_actual_outcome_chart, predictive_accuracy_table],
208
220
  )
209
221
 
210
- def _predictive_accuracy_table_spec(self, **kwargs) -> formatter.TableSpec:
222
+ def _predictive_accuracy_table_spec(
223
+ self, use_kpi: bool, **kwargs
224
+ ) -> formatter.TableSpec:
211
225
  """Creates the HTML snippet for the predictive accuracy table."""
212
- outcome = self._kpi_or_revenue()
226
+ outcome = self._kpi_or_revenue(use_kpi)
213
227
  model_diag = self._model_diagnostics
214
228
  table = model_diag.predictive_accuracy_table(column_var=c.METRIC, **kwargs)
215
229
 
@@ -270,9 +284,10 @@ class Summarizer:
270
284
  template_env: jinja2.Environment,
271
285
  media_summary: visualizer.MediaSummary,
272
286
  selected_times: Sequence[str] | None,
287
+ use_kpi: bool,
273
288
  ) -> str:
274
289
  """Creates the HTML snippet for the Outcome Contrib card."""
275
- outcome = self._kpi_or_revenue()
290
+ outcome = self._kpi_or_revenue(use_kpi)
276
291
 
277
292
  num_selected_times = (
278
293
  self._meridian.n_times
@@ -442,9 +457,10 @@ class Summarizer:
442
457
  media_summary: visualizer.MediaSummary,
443
458
  media_effects: visualizer.MediaEffects,
444
459
  reach_frequency: visualizer.ReachAndFrequency | None,
460
+ use_kpi: bool,
445
461
  ) -> str:
446
462
  """Creates the HTML snippet for the Optimal Analyst card."""
447
- outcome = self._kpi_or_revenue()
463
+ outcome = self._kpi_or_revenue(use_kpi)
448
464
  charts = []
449
465
  charts.append(
450
466
  formatter.ChartSpec(
@@ -457,6 +473,7 @@ class Summarizer:
457
473
  selected_times=(
458
474
  frozenset(selected_times) if selected_times else None
459
475
  ),
476
+ use_kpi=use_kpi,
460
477
  plot_separately=False,
461
478
  include_ci=False,
462
479
  num_channels_displayed=7,
@@ -524,9 +541,7 @@ class Summarizer:
524
541
  rf_channel=most_spend_rf_channel
525
542
  ).optimal_frequency
526
543
 
527
- def _kpi_or_revenue(self) -> str:
528
- if self._meridian.input_data.revenue_per_kpi is not None:
529
- outcome_str = c.REVENUE
530
- else:
531
- outcome_str = c.KPI.upper()
532
- return outcome_str
544
+ def _kpi_or_revenue(self, use_kpi: bool) -> str:
545
+ if use_kpi or self._meridian.input_data.revenue_per_kpi is None:
546
+ return c.KPI.upper()
547
+ return c.REVENUE
@@ -2571,6 +2571,9 @@ ADSTOCK_DECAY_MEAN = np.array([1.0, 1.0, 0.8493, 0.8630, 0.7215])
2571
2571
  ORGANIC_ADSTOCK_DECAY_CI_HI = np.array([1.0, 0.9636, 0.9291, 0.8962, 0.8650])
2572
2572
  ORGANIC_ADSTOCK_DECAY_CI_LO = np.array([1.0, 0.6623, 0.4394, 0.2920, 0.1944])
2573
2573
  ORGANIC_ADSTOCK_DECAY_MEAN = np.array([1.0, 0.8076, 0.6633, 0.5537, 0.4693])
2574
+ ORGANIC_RF_ADSTOCK_DECAY_CI_HI = np.array([1.0, 0.9208, 0.8482, 0.781, 0.7202])
2575
+ ORGANIC_RF_ADSTOCK_DECAY_CI_LO = np.array([1.0, 0.6674, 0.4460, 0.2985, 0.2001])
2576
+ ORGANIC_RF_ADSTOCK_DECAY_MEAN = np.array([1.0, 0.8344, 0.7042, 0.6001, 0.5155])
2574
2577
  HILL_CURVES_CI_HI = np.array([0.0, 0.0, 0.00098, 0.00895, 0.00195])
2575
2578
  HILL_CURVES_CI_LO = np.array([0.0, 0.0, 0.00085, 0.00322, 0.00169])
2576
2579
  HILL_CURVES_MEAN = np.array([0.0, 0.0, 0.00091, 0.00606, 0.00183])
@@ -2600,110 +2603,106 @@ PREDICTIVE_ACCURACY_NO_HOLDOUT_ID_TIMES_AND_GEOS = np.array(
2600
2603
  [-13.597, -7.360, 1.634, 0.887, 0.990, 0.757]
2601
2604
  )
2602
2605
  PREDICTIVE_ACCURACY_HOLDOUT_ID_NO_GEOS_OR_TIMES = np.array([
2603
- -2.907,
2604
- -2.356,
2605
- -2.784,
2606
- -3.267,
2607
- -1.431,
2608
- -5.836,
2609
- 2.481,
2610
- 46.381,
2611
- 10.724,
2612
- 0.729,
2613
- 63.022,
2614
- 0.696,
2615
- 0.994,
2616
- 1.038,
2617
- 1.001,
2618
- 0.633,
2619
- 0.906,
2620
- 0.596,
2606
+ -2.690704,
2607
+ -3.231603,
2608
+ -2.784759,
2609
+ -2.866354,
2610
+ -1.595214,
2611
+ -5.836171,
2612
+ 12.76909,
2613
+ 1.634914,
2614
+ 10.724035,
2615
+ 0.720384,
2616
+ 1.306767,
2617
+ 0.696319,
2618
+ 0.993516,
2619
+ 1.036484,
2620
+ 1.00181,
2621
+ 0.595755,
2622
+ 0.945436,
2623
+ 0.596676,
2621
2624
  ])
2622
-
2623
2625
  PREDICTIVE_ACCURACY_HOLDOUT_ID_GEOS_NO_TIMES = np.array([
2624
- -4.6241765,
2625
- -3.13614225,
2626
- -4.29795837,
2627
- -4.48636341,
2628
- -1.84796333,
2629
- -4.71139717,
2630
- 2.49890709,
2631
- 8.42507458,
2632
- 3.6478579,
2633
- 1.22353196,
2634
- 4.74503851,
2635
- 1.15957963,
2636
- 1.09398592,
2637
- 1.03744256,
2638
- 1.08360898,
2639
- 0.91437268,
2640
- 1.00845361,
2641
- 0.82047617,
2626
+ -5.167992,
2627
+ -2.34246,
2628
+ -4.297958,
2629
+ -3.945161,
2630
+ -2.930509,
2631
+ -4.711397,
2632
+ 4.080629,
2633
+ 1.58583,
2634
+ 3.647858,
2635
+ 1.457559,
2636
+ 1.268536,
2637
+ 1.15958,
2638
+ 1.123409,
2639
+ 0.932309,
2640
+ 1.083609,
2641
+ 0.8971,
2642
+ 0.932309,
2643
+ 0.820476,
2642
2644
  ])
2643
-
2644
2645
  PREDICTIVE_ACCURACY_HOLDOUT_ID_TIMES_NO_GEOS = np.array([
2645
- -1.23524213,
2646
- -9.06220913,
2647
- -1.39263272,
2648
- 0.3634333,
2649
- -8.32915783,
2650
- -0.81341398,
2651
- 1.35193038,
2652
- 4.68957043,
2653
- 2.24196768,
2654
- 0.75223929,
2655
- 3.71988177,
2656
- 0.99283481,
2657
- 1.01492858,
2658
- 2.17278171,
2659
- 1.17750895,
2660
- 0.45635709,
2661
- 2.17278171,
2662
- 0.60861236,
2646
+ -1.398977,
2647
+ 0.791522,
2648
+ -1.392633,
2649
+ -0.294972,
2650
+ 0.791522,
2651
+ -0.813414,
2652
+ 2.577664,
2653
+ 0.059942,
2654
+ 2.241968,
2655
+ 1.445928,
2656
+ 0.059942,
2657
+ 0.992835,
2658
+ 1.349253,
2659
+ 0.051477,
2660
+ 1.177509,
2661
+ 0.693587,
2662
+ 0.051477,
2663
+ 0.608612,
2663
2664
  ])
2664
2665
  PREDICTIVE_ACCURACY_HOLDOUT_ID_TIMES_AND_GEO = np.array([
2665
- -38.25726318,
2666
- float("-inf"),
2667
- -13.59724903,
2668
- -1.61034203,
2669
- float("-inf"),
2670
- -7.37024498,
2671
- 0.80568475,
2672
- 5.78146744,
2673
- 1.63498175,
2674
- 0.43890095,
2675
- 5.78146744,
2676
- 0.88831377,
2677
- 0.7162329,
2678
- 5.78146744,
2679
- 0.99158531,
2680
- 0.46946037,
2681
- 5.78146744,
2682
- 0.75822759,
2666
+ -20.432268,
2667
+ 0.791522,
2668
+ -13.597249,
2669
+ -4.614312,
2670
+ 0.791522,
2671
+ -7.370245,
2672
+ 2.422502,
2673
+ 0.059942,
2674
+ 1.634982,
2675
+ 2.440133,
2676
+ 0.059942,
2677
+ 0.888314,
2678
+ 1.646492,
2679
+ 0.051477,
2680
+ 0.991585,
2681
+ 1.25399,
2682
+ 0.051477,
2683
+ 0.758228,
2683
2684
  ])
2684
-
2685
2685
  PREDICTIVE_ACCURACY_HOLDOUT_ID_NATIONAL_NO_TIMES = np.array([
2686
- 0.42883771657943726,
2687
- 0.4715208411216736,
2688
- 0.45594334602355957,
2689
- 0.8378637433052063,
2690
- 13.80582332611084,
2691
- 2.9550814628601074,
2692
- 0.34947845339775085,
2693
- 0.4262354075908661,
2694
- 0.3586611747741699,
2686
+ -15.619549,
2687
+ -28.130356,
2688
+ -17.316074,
2689
+ 16.30377,
2690
+ 10.817584,
2691
+ 15.296103,
2692
+ 2.40538,
2693
+ 2.640707,
2694
+ 2.449049,
2695
2695
  ])
2696
-
2697
2696
  PREDICTIVE_ACCURACY_HOLDOUT_ID_NATIONAL_TIMES = np.array([
2698
- -0.30289185,
2699
- float("-inf"),
2700
- 0.15624052,
2701
- 0.86708003,
2702
- 107.46259308,
2703
- 36.39891815,
2704
- 0.61977416,
2705
- 107.46259308,
2706
- 0.88497865,
2697
+ -22.270792,
2698
+ np.nan,
2699
+ -22.270792,
2700
+ 161.652573,
2701
+ np.nan,
2702
+ 161.652573,
2703
+ 4.597788,
2704
+ np.nan,
2705
+ 4.597788,
2707
2706
  ])
2708
2707
 
2709
2708
  SAMPLE_IMPRESSIONS = np.array([
@@ -3142,6 +3141,7 @@ def generate_hill_curves_dataframe() -> pd.DataFrame:
3142
3141
  [f"ch_{i}" for i in range(3)]
3143
3142
  + [f"rf_ch_{i}" for i in range(2)]
3144
3143
  + [f"organic_ch_{i}" for i in range(2)]
3144
+ + [f"organic_rf_ch_{i}" for i in range(1)]
3145
3145
  )
3146
3146
  channel_array = []
3147
3147
  channel_type_array = []
@@ -3154,6 +3154,8 @@ def generate_hill_curves_dataframe() -> pd.DataFrame:
3154
3154
  channel_type_array.append(c.RF)
3155
3155
  elif channel_name.startswith("organic_ch_"):
3156
3156
  channel_type_array.append(c.ORGANIC_MEDIA)
3157
+ elif channel_name.startswith("organic_rf_ch_"):
3158
+ channel_type_array.append(c.ORGANIC_RF)
3157
3159
 
3158
3160
  np.random.seed(0)
3159
3161
  media_units_array = [
@@ -19,6 +19,7 @@ import functools
19
19
  from typing import Mapping
20
20
  import warnings
21
21
  import altair as alt
22
+ from meridian import backend
22
23
  from meridian import constants as c
23
24
  from meridian.analysis import analyzer
24
25
  from meridian.analysis import formatter
@@ -26,8 +27,6 @@ from meridian.analysis import summary_text
26
27
  from meridian.model import model
27
28
  import numpy as np
28
29
  import pandas as pd
29
- import tensorflow as tf
30
- import tensorflow_probability as tfp
31
30
  import xarray as xr
32
31
 
33
32
 
@@ -312,10 +311,10 @@ class ModelDiagnostics:
312
311
  k: v.values
313
312
  for k, v in self._meridian.inference_data.posterior.data_vars.items()
314
313
  }
315
- for k, v in tfp.mcmc.potential_scale_reduction(
316
- {k: tf.einsum('ij...->ji...', v) for k, v in mcmc_states.items()}
314
+ for k, v in backend.mcmc.potential_scale_reduction(
315
+ {k: backend.einsum('ij...->ji...', v) for k, v in mcmc_states.items()}
317
316
  ).items():
318
- rhat_temp = v.numpy().flatten()
317
+ rhat_temp = np.asarray(v).flatten()
319
318
  rhat = pd.concat([
320
319
  rhat,
321
320
  pd.DataFrame({
@@ -864,6 +863,7 @@ class MediaEffects:
864
863
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL,
865
864
  selected_times: frozenset[str] | None = None,
866
865
  by_reach: bool = True,
866
+ use_kpi: bool = False,
867
867
  ) -> xr.Dataset:
868
868
  """Dataset holding the calculated response curves data.
869
869
 
@@ -887,12 +887,14 @@ class MediaEffects:
887
887
  by_reach: For the channel w/ reach and frequency, return the response
888
888
  curves by reach given fixed frequency if true; return the response
889
889
  curves by frequency given fixed reach if false.
890
+ use_kpi: If `True`, calculate the incremental KPI. Otherwise, calculate
891
+ the incremental revenue using the revenue per KPI (if available).
890
892
 
891
893
  Returns:
892
894
  A Dataset displaying the response curves data.
893
895
  """
894
896
  selected_times_list = list(selected_times) if selected_times else None
895
- use_kpi = self._meridian.input_data.revenue_per_kpi is None
897
+ use_kpi = use_kpi or self._meridian.input_data.revenue_per_kpi is None
896
898
  return self._analyzer.response_curves(
897
899
  spend_multipliers=list(np.arange(0, 2.2, c.RESPONSE_CURVE_STEP_SIZE)),
898
900
  confidence_level=confidence_level,
@@ -962,6 +964,7 @@ class MediaEffects:
962
964
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL,
963
965
  selected_times: frozenset[str] | None = None,
964
966
  by_reach: bool = True,
967
+ use_kpi: bool = False,
965
968
  plot_separately: bool = True,
966
969
  include_ci: bool = True,
967
970
  num_channels_displayed: int | None = None,
@@ -987,6 +990,8 @@ class MediaEffects:
987
990
  by_reach: For the channel w/ reach and frequency, return the response
988
991
  curves by reach given fixed frequency if true; return the response
989
992
  curves by frequency given fixed reach if false.
993
+ use_kpi: If `True`, calculate the incremental KPI. Otherwise, calculate
994
+ the incremental revenue using the revenue per KPI (if available).
990
995
  plot_separately: If `True`, the plots are faceted. If `False`, the plots
991
996
  are layered to create one plot with all of the channels.
992
997
  include_ci: If `True`, plots the credible interval. Defaults to `True`.
@@ -1022,11 +1027,13 @@ class MediaEffects:
1022
1027
  confidence_level=confidence_level,
1023
1028
  selected_times=selected_times,
1024
1029
  by_reach=by_reach,
1030
+ use_kpi=use_kpi,
1031
+ )
1032
+ y_axis_label = (
1033
+ summary_text.INC_KPI_LABEL
1034
+ if use_kpi or self._meridian.input_data.revenue_per_kpi is None
1035
+ else summary_text.INC_OUTCOME_LABEL
1025
1036
  )
1026
- if self._meridian.input_data.revenue_per_kpi is not None:
1027
- y_axis_label = summary_text.INC_OUTCOME_LABEL
1028
- else:
1029
- y_axis_label = summary_text.INC_KPI_LABEL
1030
1037
  base = (
1031
1038
  alt.Chart(response_curves_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
1032
1039
  .transform_calculate(
@@ -1197,41 +1204,32 @@ class MediaEffects:
1197
1204
  include_ci: If `True`, plots the credible interval. Defaults to `True`.
1198
1205
 
1199
1206
  Returns:
1200
- A dictionary mapping channel type constants (`media`, `rf`, and
1201
- `organic_media`) to their respective Altair chart objects. Keys are only
1202
- present if charts for that type were generated (i.e., if the
1203
- corresponding channels exist in the data). Returns an empty dictionary if
1204
- no relevant channels are found.
1207
+ A dictionary mapping channel type constants (`media`, `rf`,
1208
+ `organic_media`, and `organic_rf`) to their respective Altair chart
1209
+ objects. Keys are only present if charts for that type were generated
1210
+ (i.e., if the corresponding channels exist in the data). Returns an empty
1211
+ dictionary if no relevant channels are found.
1205
1212
  """
1206
1213
  hill_curves_dataframe = self.hill_curves_dataframe(
1207
1214
  confidence_level=confidence_level
1208
1215
  )
1209
- channel_types = list(set(hill_curves_dataframe[c.CHANNEL_TYPE]))
1216
+ all_channel_types = set(hill_curves_dataframe[c.CHANNEL_TYPE])
1210
1217
  plots: dict[str, alt.Chart] = {}
1211
1218
 
1212
- if c.MEDIA in channel_types:
1213
- media_df = hill_curves_dataframe[
1214
- hill_curves_dataframe[c.CHANNEL_TYPE] == c.MEDIA
1215
- ]
1216
- plots[c.MEDIA] = self._plot_hill_curves_helper(
1217
- media_df, include_prior, include_ci
1218
- )
1219
-
1220
- if c.RF in channel_types:
1221
- rf_df = hill_curves_dataframe[
1222
- hill_curves_dataframe[c.CHANNEL_TYPE] == c.RF
1223
- ]
1224
- plots[c.RF] = self._plot_hill_curves_helper(
1225
- rf_df, include_prior, include_ci
1226
- )
1227
-
1228
- if c.ORGANIC_MEDIA in channel_types:
1229
- organic_media_df = hill_curves_dataframe[
1230
- hill_curves_dataframe[c.CHANNEL_TYPE] == c.ORGANIC_MEDIA
1231
- ]
1232
- plots[c.ORGANIC_MEDIA] = self._plot_hill_curves_helper(
1233
- organic_media_df, include_prior, include_ci
1234
- )
1219
+ supported_channel_types = [
1220
+ c.MEDIA,
1221
+ c.RF,
1222
+ c.ORGANIC_MEDIA,
1223
+ c.ORGANIC_RF,
1224
+ ]
1225
+ for channel_type in supported_channel_types:
1226
+ if channel_type in all_channel_types:
1227
+ df_for_type = hill_curves_dataframe[
1228
+ hill_curves_dataframe[c.CHANNEL_TYPE] == channel_type
1229
+ ]
1230
+ plots[channel_type] = self._plot_hill_curves_helper(
1231
+ df_for_type, include_prior, include_ci
1232
+ )
1235
1233
 
1236
1234
  return plots
1237
1235
 
@@ -1259,19 +1257,17 @@ class MediaEffects:
1259
1257
  column, or contains an unsupported channel type.
1260
1258
  """
1261
1259
  channel_type = df_channel_type[c.CHANNEL_TYPE].iloc[0]
1262
- if channel_type == c.MEDIA:
1260
+ if channel_type in [c.MEDIA, c.ORGANIC_MEDIA]:
1263
1261
  x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1264
1262
  shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1265
- elif channel_type == c.RF:
1263
+ elif channel_type in [c.RF, c.ORGANIC_RF]:
1266
1264
  x_axis_title = summary_text.HILL_X_AXIS_RF_LABEL
1267
1265
  shaded_area_title = summary_text.HILL_SHADED_REGION_RF_LABEL
1268
- elif channel_type == c.ORGANIC_MEDIA:
1269
- x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1270
- shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1271
1266
  else:
1272
1267
  raise ValueError(
1273
1268
  f"Unsupported channel type '{channel_type}' found in Hill curve data."
1274
- ' Expected one of: {c.MEDIA}, {c.RF}, {c.ORGANIC_MEDIA}.'
1269
+ ' Expected one of: {c.MEDIA}, {c.RF}, {c.ORGANIC_MEDIA},'
1270
+ ' {c.ORGANIC_RF}.'
1275
1271
  )
1276
1272
  domain_list = [
1277
1273
  c.POSTERIOR,
@@ -1345,6 +1341,7 @@ class MediaEffects:
1345
1341
  selected_times: frozenset[str] | None = None,
1346
1342
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL,
1347
1343
  by_reach: bool = True,
1344
+ use_kpi: bool = False,
1348
1345
  ) -> pd.DataFrame:
1349
1346
  """Returns DataFrame with top channels by spend for the layered plot.
1350
1347
 
@@ -1359,6 +1356,7 @@ class MediaEffects:
1359
1356
  by_reach: For the channel w/ reach and frequency, return the response
1360
1357
  curves by reach given fixed frequency if true; return the response
1361
1358
  curves by frequency given fixed reach if false.
1359
+ use_kpi: If `True`, use KPI instead of revenue.
1362
1360
 
1363
1361
  Returns:
1364
1362
  A DataFrame containing the top chosen channels
@@ -1369,6 +1367,7 @@ class MediaEffects:
1369
1367
  confidence_level=confidence_level,
1370
1368
  selected_times=selected_times,
1371
1369
  by_reach=by_reach,
1370
+ use_kpi=use_kpi,
1372
1371
  )
1373
1372
  list_sorted_channels_cost = list(
1374
1373
  data.sel(spend_multiplier=1)
@@ -1433,8 +1432,8 @@ class MediaSummary:
1433
1432
  non_media_baseline_values: Optional list of shape
1434
1433
  `(n_non_media_channels,)`. Each element is a float denoting the fixed
1435
1434
  value which will be used as baseline for the given channel. If `None`,
1436
- the values defined with `ModelSpec.non_media_baseline_values`
1437
- will be used.
1435
+ the values defined with `ModelSpec.non_media_baseline_values` will be
1436
+ used.
1438
1437
  """
1439
1438
  self._meridian = meridian
1440
1439
  self._analyzer = analyzer.Analyzer(meridian)
@@ -1654,8 +1653,8 @@ class MediaSummary:
1654
1653
  non_media_baseline_values: Optional list of shape
1655
1654
  `(n_non_media_channels,)`. Each element is a float denoting the fixed
1656
1655
  value which will be used as baseline for the given channel. If `None`,
1657
- the values defined with `ModelSpec.non_media_baseline_values`
1658
- will be used.
1656
+ the values defined with `ModelSpec.non_media_baseline_values` will be
1657
+ used.
1659
1658
  """
1660
1659
  self._confidence_level = confidence_level or self._confidence_level
1661
1660
  self._selected_times = selected_times
@@ -1686,7 +1685,7 @@ class MediaSummary:
1686
1685
  c.DATE_FORMAT if time_granularity == c.WEEKLY else c.QUARTER_FORMAT
1687
1686
  )
1688
1687
 
1689
- outcome_df = self._transform_contribution_metrics(
1688
+ outcome_df = self.contribution_metrics(
1690
1689
  include_non_paid=True, aggregate_times=False
1691
1690
  )
1692
1691
 
@@ -1800,7 +1799,7 @@ class MediaSummary:
1800
1799
  f'time_granularity must be one of {c.TIME_GRANULARITIES}'
1801
1800
  )
1802
1801
 
1803
- outcome_df = self._transform_contribution_metrics(
1802
+ outcome_df = self.contribution_metrics(
1804
1803
  include_non_paid=True, aggregate_times=False
1805
1804
  )
1806
1805
  outcome_df[c.TIME] = pd.to_datetime(outcome_df[c.TIME])
@@ -1907,7 +1906,7 @@ class MediaSummary:
1907
1906
  if self._meridian.input_data.revenue_per_kpi is not None
1908
1907
  else c.KPI.upper()
1909
1908
  )
1910
- outcome_df = self._transform_contribution_metrics(include_non_paid=True)
1909
+ outcome_df = self.contribution_metrics(include_non_paid=True)
1911
1910
  pct = c.PCT_OF_CONTRIBUTION
1912
1911
  value = c.INCREMENTAL_OUTCOME
1913
1912
  outcome_df['outcome_text'] = outcome_df.apply(
@@ -1991,7 +1990,7 @@ class MediaSummary:
1991
1990
  Returns:
1992
1991
  An Altair plot showing the contributions for all channels.
1993
1992
  """
1994
- outcome_df = self._transform_contribution_metrics(
1993
+ outcome_df = self.contribution_metrics(
1995
1994
  [c.ALL_CHANNELS], include_non_paid=True
1996
1995
  )
1997
1996
 
@@ -2411,7 +2410,7 @@ class MediaSummary:
2411
2410
  spend_df = paid_summary_metrics[c.SPEND].to_dataframe().reset_index()
2412
2411
  return metrics_df.merge(spend_df, on=c.CHANNEL)
2413
2412
 
2414
- def _transform_contribution_metrics(
2413
+ def contribution_metrics(
2415
2414
  self,
2416
2415
  selected_channels: Sequence[str] | None = None,
2417
2416
  include_non_paid: bool = False,