google-meridian 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,8 @@
16
16
 
17
17
  from collections.abc import Sequence
18
18
  import functools
19
+ from typing import Mapping
20
+ import warnings
19
21
  import altair as alt
20
22
  from meridian import constants as c
21
23
  from meridian.analysis import analyzer
@@ -32,7 +34,9 @@ import xarray as xr
32
34
  __all__ = [
33
35
  'ModelDiagnostics',
34
36
  'ModelFit',
37
+ 'MediaSummary',
35
38
  'ReachAndFrequency',
39
+ 'MediaEffects',
36
40
  ]
37
41
 
38
42
 
@@ -246,7 +250,7 @@ class ModelDiagnostics:
246
250
  .mark_area(opacity=0.7)
247
251
  .encode(
248
252
  x=f'{parameter}:Q',
249
- y='density:Q',
253
+ y=alt.Y(shorthand='density:Q', stack=False),
250
254
  color=f'{c.DISTRIBUTION}:N',
251
255
  )
252
256
  )
@@ -461,14 +465,14 @@ class ModelFit:
461
465
  else:
462
466
  y_axis_label = summary_text.KPI_LABEL
463
467
  plot = (
464
- alt.Chart(model_fit_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
468
+ alt.Chart(model_fit_df, width=c.VEGALITE_FACET_EXTRA_LARGE_WIDTH)
465
469
  .mark_line()
466
470
  .encode(
467
471
  x=alt.X(
468
472
  f'{c.TIME}:T',
469
473
  title='Time period',
470
474
  axis=alt.Axis(
471
- format='%Y %b',
475
+ format=c.QUARTER_FORMAT,
472
476
  grid=False,
473
477
  tickCount=8,
474
478
  domainColor=c.GREY_300,
@@ -883,11 +887,13 @@ class MediaEffects:
883
887
  A Dataset displaying the response curves data.
884
888
  """
885
889
  selected_times_list = list(selected_times) if selected_times else None
890
+ use_kpi = self._meridian.input_data.revenue_per_kpi is None
886
891
  return self._analyzer.response_curves(
887
892
  spend_multipliers=list(np.arange(0, 2.2, c.RESPONSE_CURVE_STEP_SIZE)),
888
893
  confidence_level=confidence_level,
889
894
  selected_times=selected_times_list,
890
895
  by_reach=by_reach,
896
+ use_kpi=use_kpi,
891
897
  )
892
898
 
893
899
  @functools.lru_cache(maxsize=128)
@@ -1170,7 +1176,7 @@ class MediaEffects:
1170
1176
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL,
1171
1177
  include_prior: bool = True,
1172
1178
  include_ci: bool = True,
1173
- ) -> alt.Chart | list[alt.Chart]:
1179
+ ) -> Mapping[str, alt.Chart]:
1174
1180
  """Plots the Hill curves for each channel.
1175
1181
 
1176
1182
  Args:
@@ -1181,22 +1187,23 @@ class MediaEffects:
1181
1187
  include_ci: If `True`, plots the credible interval. Defaults to `True`.
1182
1188
 
1183
1189
  Returns:
1184
- A faceted Altair plot showing the histogram, prior and posterior lines,
1185
- and bands for the Hill saturation curves. When there are both media and
1186
- RF channels, a list of 2 faceted Altair plots are returned: one
1187
- for the media channels and another for the RF channels.
1190
+ A dictionary mapping channel type constants (`media`, `rf`, and
1191
+ `organic_media`) to their respective Altair chart objects. Keys are only
1192
+ present if charts for that type were generated (i.e., if the
1193
+ corresponding channels exist in the data). Returns an empty dictionary if
1194
+ no relevant channels are found.
1188
1195
  """
1189
1196
  hill_curves_dataframe = self.hill_curves_dataframe(
1190
1197
  confidence_level=confidence_level
1191
1198
  )
1192
1199
  channel_types = list(set(hill_curves_dataframe[c.CHANNEL_TYPE]))
1193
- plot_media, plot_rf = None, None
1200
+ plots: dict[str, alt.Chart] = {}
1194
1201
 
1195
1202
  if c.MEDIA in channel_types:
1196
1203
  media_df = hill_curves_dataframe[
1197
1204
  hill_curves_dataframe[c.CHANNEL_TYPE] == c.MEDIA
1198
1205
  ]
1199
- plot_media = self._plot_hill_curves_helper(
1206
+ plots[c.MEDIA] = self._plot_hill_curves_helper(
1200
1207
  media_df, include_prior, include_ci
1201
1208
  )
1202
1209
 
@@ -1204,14 +1211,19 @@ class MediaEffects:
1204
1211
  rf_df = hill_curves_dataframe[
1205
1212
  hill_curves_dataframe[c.CHANNEL_TYPE] == c.RF
1206
1213
  ]
1207
- plot_rf = self._plot_hill_curves_helper(rf_df, include_prior, include_ci)
1214
+ plots[c.RF] = self._plot_hill_curves_helper(
1215
+ rf_df, include_prior, include_ci
1216
+ )
1208
1217
 
1209
- if plot_media and plot_rf:
1210
- return [plot_media, plot_rf]
1211
- elif plot_media:
1212
- return plot_media
1213
- else:
1214
- return plot_rf
1218
+ if c.ORGANIC_MEDIA in channel_types:
1219
+ organic_media_df = hill_curves_dataframe[
1220
+ hill_curves_dataframe[c.CHANNEL_TYPE] == c.ORGANIC_MEDIA
1221
+ ]
1222
+ plots[c.ORGANIC_MEDIA] = self._plot_hill_curves_helper(
1223
+ organic_media_df, include_prior, include_ci
1224
+ )
1225
+
1226
+ return plots
1215
1227
 
1216
1228
  def _plot_hill_curves_helper(
1217
1229
  self,
@@ -1231,13 +1243,26 @@ class MediaEffects:
1231
1243
  Returns:
1232
1244
  A faceted Altair plot showing the histogram and prior+posterior lines and
1233
1245
  bands for the Hill curves.
1246
+
1247
+ Raises:
1248
+ ValueError: If the channel type in the first row is unsupported. (An
1249
+ empty frame or a missing column raises IndexError/KeyError instead.)
1234
1250
  """
1235
- if c.MEDIA in list(df_channel_type[c.CHANNEL_TYPE]):
1251
+ channel_type = df_channel_type[c.CHANNEL_TYPE].iloc[0]
1252
+ if channel_type == c.MEDIA:
1236
1253
  x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1237
1254
  shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1238
- else:
1255
+ elif channel_type == c.RF:
1239
1256
  x_axis_title = summary_text.HILL_X_AXIS_RF_LABEL
1240
1257
  shaded_area_title = summary_text.HILL_SHADED_REGION_RF_LABEL
1258
+ elif channel_type == c.ORGANIC_MEDIA:
1259
+ x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1260
+ shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1261
+ else:
1262
+ raise ValueError(
1263
+ f"Unsupported channel type '{channel_type}' found in Hill curve data."
1264
+ f' Expected one of: {c.MEDIA}, {c.RF}, {c.ORGANIC_MEDIA}.'
1265
+ )
1241
1266
  domain_list = [
1242
1267
  c.POSTERIOR,
1243
1268
  c.PRIOR,
@@ -1410,18 +1435,35 @@ class MediaSummary:
1410
1435
  self._marginal_roi_by_reach = marginal_roi_by_reach
1411
1436
  self._non_media_baseline_values = non_media_baseline_values
1412
1437
 
1413
- @functools.cached_property
1414
- def paid_summary_metrics(self) -> xr.Dataset:
1438
+ @property
1439
+ def paid_summary_metrics(self):
1440
+ warnings.warn(
1441
+ 'The `paid_summary_metrics` property is deprecated. Use the'
1442
+ ' `get_paid_summary_metrics()` method instead.',
1443
+ DeprecationWarning,
1444
+ stacklevel=2,
1445
+ )
1446
+ return self.get_paid_summary_metrics()
1447
+
1448
+ @functools.lru_cache(maxsize=128)
1449
+ def get_paid_summary_metrics(
1450
+ self, aggregate_times: bool = True
1451
+ ) -> xr.Dataset:
1415
1452
  """Dataset holding the calculated summary metrics for the paid channels.
1416
1453
 
1417
- The dataset contains the following:
1454
+ Args:
1455
+ aggregate_times: If `True`, aggregates the metrics across all time
1456
+ periods. If `False`, returns time-varying metrics.
1418
1457
 
1419
- - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`, `ci_hi`),
1420
- `distribution` (`prior`, `posterior`)
1421
- - **Data variables:** `impressions`, `pct_of_impressions`, `spend`,
1422
- `pct_of_spend`, `CPM`, `incremental_outcome`, `pct_of_contribution`,
1423
- `roi`,
1424
- `effectiveness`, `mroi`.
1458
+ Returns:
1459
+ An `xarray.Dataset` containing the following:
1460
+ - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
1461
+ `ci_hi`),
1462
+ `distribution` (`prior`, `posterior`)
1463
+ - **Data variables:** `impressions`, `pct_of_impressions`, `spend`,
1464
+ `pct_of_spend`, `CPM`, `incremental_outcome`, `pct_of_contribution`,
1465
+ `roi`,
1466
+ `effectiveness`, `mroi`.
1425
1467
  """
1426
1468
  return self._analyzer.summary_metrics(
1427
1469
  selected_times=self._selected_times,
@@ -1429,18 +1471,34 @@ class MediaSummary:
1429
1471
  use_kpi=self._meridian.input_data.revenue_per_kpi is None,
1430
1472
  confidence_level=self._confidence_level,
1431
1473
  include_non_paid_channels=False,
1474
+ aggregate_times=aggregate_times,
1475
+ )
1476
+
1477
+ @property
1478
+ def all_summary_metrics(self):
1479
+ warnings.warn(
1480
+ 'The `all_summary_metrics` property is deprecated. Use the'
1481
+ ' `get_all_summary_metrics()` method instead.',
1482
+ DeprecationWarning,
1483
+ stacklevel=2,
1432
1484
  )
1485
+ return self.get_all_summary_metrics()
1433
1486
 
1434
- @functools.cached_property
1435
- def all_summary_metrics(self) -> xr.Dataset:
1487
+ @functools.lru_cache(maxsize=128)
1488
+ def get_all_summary_metrics(self, aggregate_times: bool = True) -> xr.Dataset:
1436
1489
  """Dataset holding the calculated summary metrics for all channels.
1437
1490
 
1438
- The dataset contains the following:
1491
+ Args:
1492
+ aggregate_times: If `True`, aggregates the metrics across all time
1493
+ periods. If `False`, returns time-varying metrics.
1439
1494
 
1440
- - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`, `ci_hi`),
1441
- `distribution` (`prior`, `posterior`)
1442
- - **Data variables:** `incremental_outcome`, `pct_of_contribution`,
1443
- `effectiveness`.
1495
+ Returns:
1496
+ An `xarray.Dataset` containing the following:
1497
+ - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
1498
+ `ci_hi`),
1499
+ `distribution` (`prior`, `posterior`)
1500
+ - **Data variables:** `incremental_outcome`, `pct_of_contribution`,
1501
+ `effectiveness`.
1444
1502
  """
1445
1503
  return self._analyzer.summary_metrics(
1446
1504
  selected_times=self._selected_times,
@@ -1448,6 +1506,7 @@ class MediaSummary:
1448
1506
  confidence_level=self._confidence_level,
1449
1507
  include_non_paid_channels=True,
1450
1508
  non_media_baseline_values=self._non_media_baseline_values,
1509
+ aggregate_times=aggregate_times,
1451
1510
  )
1452
1511
 
1453
1512
  def summary_table(
@@ -1488,7 +1547,7 @@ class MediaSummary:
1488
1547
  ]
1489
1548
  if include_non_paid_channels:
1490
1549
  monetary_metrics = [c.INCREMENTAL_OUTCOME] * use_revenue
1491
- summary_metrics = self.all_summary_metrics
1550
+ summary_metrics = self.get_all_summary_metrics()
1492
1551
  columns_rename_dict = {
1493
1552
  c.PCT_OF_CONTRIBUTION: summary_text.PCT_CONTRIBUTION_COL,
1494
1553
  c.INCREMENTAL_OUTCOME: (
@@ -1512,7 +1571,7 @@ class MediaSummary:
1512
1571
  c.SPEND,
1513
1572
  c.INCREMENTAL_OUTCOME,
1514
1573
  ] * use_revenue
1515
- summary_metrics = self.paid_summary_metrics
1574
+ summary_metrics = self.get_paid_summary_metrics()
1516
1575
  columns_rename_dict = {
1517
1576
  c.PCT_OF_IMPRESSIONS: summary_text.PCT_IMPRESSIONS_COL,
1518
1577
  c.PCT_OF_SPEND: summary_text.PCT_SPEND_COL,
@@ -1598,6 +1657,240 @@ class MediaSummary:
1598
1657
  self._marginal_roi_by_reach = marginal_roi_by_reach
1599
1658
  self._non_media_baseline_values = non_media_baseline_values
1600
1659
 
1660
+ def plot_channel_contribution_area_chart(
1661
+ self, time_granularity: str = c.QUARTERLY
1662
+ ) -> alt.Chart:
1663
+ """Plots a stacked area chart of the contribution share per channel by time.
1664
+
1665
+ Args:
1666
+ time_granularity: The granularity for the time axis. Options are `weekly`
1667
+ or `quarterly`. Defaults to `quarterly`.
1668
+
1669
+ Returns:
1670
+ An Altair plot showing the contribution share per channel by time.
1671
+
1672
+ Raises:
1673
+ ValueError: If time_granularity is not one of the allowed constants.
1674
+ """
1675
+ if time_granularity not in c.TIME_GRANULARITIES:
1676
+ raise ValueError(
1677
+ f'time_granularity must be one of {c.TIME_GRANULARITIES}'
1678
+ )
1679
+
1680
+ x_axis_format = (
1681
+ c.DATE_FORMAT if time_granularity == c.WEEKLY else c.QUARTER_FORMAT
1682
+ )
1683
+
1684
+ outcome_df = self._transform_contribution_metrics(
1685
+ include_non_paid=True, aggregate_times=False
1686
+ )
1687
+
1688
+ # Ensure proper ordering for the stacked area chart. Baseline should be at
1689
+ # the bottom. Separate the *stacking* order from the *legend* order.
1690
+ stack_order = sorted([
1691
+ channel
1692
+ for channel in outcome_df[c.CHANNEL].unique()
1693
+ if channel != c.BASELINE
1694
+ ]) + [c.BASELINE]
1695
+
1696
+ legend_order = [c.BASELINE] + sorted([
1697
+ channel
1698
+ for channel in outcome_df[c.CHANNEL].unique()
1699
+ if channel != c.BASELINE
1700
+ ])
1701
+
1702
+ # Get the minimum incremental outcome for baseline across all time periods
1703
+ # as the lower bound for the stacked area chart.
1704
+ min_y = (
1705
+ outcome_df[outcome_df[c.CHANNEL] == c.BASELINE]
1706
+ .groupby(c.TIME)[c.INCREMENTAL_OUTCOME]
1707
+ .min()
1708
+ .min()
1709
+ )
1710
+
1711
+ plot = (
1712
+ alt.Chart(outcome_df, width=c.VEGALITE_FACET_EXTRA_LARGE_WIDTH)
1713
+ .mark_area()
1714
+ .transform_calculate(
1715
+ sort_channel=f'indexof({stack_order}, datum.channel)'
1716
+ )
1717
+ .encode(
1718
+ x=alt.X(
1719
+ f'{c.TIME}:T',
1720
+ title='Time period',
1721
+ axis=alt.Axis(
1722
+ format=x_axis_format,
1723
+ grid=False,
1724
+ tickCount=8,
1725
+ domainColor=c.GREY_300,
1726
+ ),
1727
+ ),
1728
+ y=alt.Y(
1729
+ f'{c.INCREMENTAL_OUTCOME}:Q',
1730
+ title=(
1731
+ c.REVENUE.title()
1732
+ if self._meridian.input_data.revenue_per_kpi is not None
1733
+ else c.KPI.upper()
1734
+ ),
1735
+ axis=alt.Axis(
1736
+ ticks=False,
1737
+ domain=False,
1738
+ tickCount=5,
1739
+ labelPadding=c.PADDING_10,
1740
+ labelExpr=formatter.compact_number_expr(),
1741
+ **formatter.Y_AXIS_TITLE_CONFIG,
1742
+ ),
1743
+ scale=alt.Scale(domainMin=min_y, clamp=True),
1744
+ ),
1745
+ color=alt.Color(
1746
+ f'{c.CHANNEL}:N',
1747
+ legend=alt.Legend(
1748
+ labelFontSize=c.AXIS_FONT_SIZE,
1749
+ labelFont=c.FONT_ROBOTO,
1750
+ title=None,
1751
+ orient='bottom',
1752
+ ),
1753
+ scale=alt.Scale(domain=legend_order),
1754
+ sort=legend_order,
1755
+ ),
1756
+ tooltip=[
1757
+ alt.Tooltip(f'{c.TIME}:T', format=c.DATE_FORMAT),
1758
+ c.CHANNEL,
1759
+ alt.Tooltip(f'{c.INCREMENTAL_OUTCOME}:Q', format=',.2f'),
1760
+ ],
1761
+ order=alt.Order('sort_channel:N', sort='descending'),
1762
+ )
1763
+ .properties(
1764
+ title=formatter.custom_title_params(
1765
+ summary_text.CHANNEL_CONTRIB_BY_TIME_CHART_TITLE
1766
+ ),
1767
+ )
1768
+ .configure_axis(titlePadding=c.PADDING_10, **formatter.TEXT_CONFIG)
1769
+ .configure_view(strokeOpacity=0)
1770
+ )
1771
+ return plot
1772
+
1773
+ def plot_channel_contribution_bump_chart(
1774
+ self, time_granularity: str = c.QUARTERLY
1775
+ ) -> alt.Chart:
1776
+ """Plots a bump chart of channel contribution rank over time.
1777
+
1778
+ This chart shows the relative rank of each channel's contribution,
1779
+ including the baseline, based on incremental outcome. Depending on the
1780
+ time_granularity, ranks are shown either weekly or at the end of each
1781
+ quarter. Rank 1 represents the highest contribution.
1782
+
1783
+ Args:
1784
+ time_granularity: The granularity for the time axis. Options are `weekly`
1785
+ or `quarterly`. Defaults to `quarterly`.
1786
+
1787
+ Returns:
1788
+ An Altair plot showing the contribution rank per channel by time.
1789
+
1790
+ Raises:
1791
+ ValueError: If time_granularity is not one of the allowed constants.
1792
+ """
1793
+ if time_granularity not in c.TIME_GRANULARITIES:
1794
+ raise ValueError(
1795
+ f'time_granularity must be one of {c.TIME_GRANULARITIES}'
1796
+ )
1797
+
1798
+ outcome_df = self._transform_contribution_metrics(
1799
+ include_non_paid=True, aggregate_times=False
1800
+ )
1801
+ outcome_df[c.TIME] = pd.to_datetime(outcome_df[c.TIME])
1802
+
1803
+ outcome_df['rank'] = outcome_df.groupby(c.TIME)[c.INCREMENTAL_OUTCOME].rank(
1804
+ method='first', ascending=False
1805
+ )
1806
+
1807
+ if time_granularity == c.QUARTERLY:
1808
+ # Filter data to keep only the last available date within each quarter
1809
+ # for a quarterly view of ranking changes.
1810
+ unique_times = pd.Series(outcome_df[c.TIME].unique()).sort_values()
1811
+ quarters = unique_times.dt.to_period('Q')
1812
+ quarterly_dates = unique_times[~quarters.duplicated(keep='last')]
1813
+ plot_df = outcome_df[outcome_df[c.TIME].isin(quarterly_dates)].copy()
1814
+ x_axis_format = c.QUARTER_FORMAT
1815
+ tooltip_time_format = c.QUARTER_FORMAT
1816
+ tooltip_time_title = 'Quarter'
1817
+ else:
1818
+ plot_df = outcome_df.copy()
1819
+ x_axis_format = c.DATE_FORMAT
1820
+ tooltip_time_format = c.DATE_FORMAT
1821
+ tooltip_time_title = 'Week'
1822
+
1823
+ legend_order = [c.BASELINE] + sorted([
1824
+ channel
1825
+ for channel in plot_df[c.CHANNEL].unique()
1826
+ if channel != c.BASELINE
1827
+ ])
1828
+
1829
+ plot = (
1830
+ alt.Chart(plot_df, width=c.VEGALITE_FACET_EXTRA_LARGE_WIDTH)
1831
+ .mark_line(point=True)
1832
+ .encode(
1833
+ x=alt.X(
1834
+ f'{c.TIME}:T',
1835
+ title='Time period',
1836
+ axis=alt.Axis(
1837
+ format=x_axis_format,
1838
+ grid=False,
1839
+ domainColor=c.GREY_300,
1840
+ ),
1841
+ ),
1842
+ y=alt.Y(
1843
+ 'rank:Q',
1844
+ title='Contribution Rank',
1845
+ axis=alt.Axis(
1846
+ ticks=False,
1847
+ domain=False,
1848
+ labelPadding=c.PADDING_10,
1849
+ tickMinStep=1,
1850
+ format='d',
1851
+ ),
1852
+ scale=alt.Scale(
1853
+ zero=False,
1854
+ reverse=True,
1855
+ ),
1856
+ ),
1857
+ color=alt.Color(
1858
+ f'{c.CHANNEL}:N',
1859
+ legend=alt.Legend(
1860
+ labelFontSize=c.AXIS_FONT_SIZE,
1861
+ labelFont=c.FONT_ROBOTO,
1862
+ title=None,
1863
+ orient='bottom',
1864
+ ),
1865
+ scale=alt.Scale(domain=legend_order),
1866
+ sort=legend_order,
1867
+ ),
1868
+ tooltip=[
1869
+ alt.Tooltip(
1870
+ f'{c.TIME}:T',
1871
+ format=tooltip_time_format,
1872
+ title=tooltip_time_title,
1873
+ ),
1874
+ alt.Tooltip(f'{c.CHANNEL}:N', title='Channel'),
1875
+ alt.Tooltip('rank:O', title='Rank'),
1876
+ alt.Tooltip(
1877
+ f'{c.INCREMENTAL_OUTCOME}:Q',
1878
+ format=',.0f',
1879
+ title='Incremental Outcome',
1880
+ ),
1881
+ ],
1882
+ )
1883
+ .properties(
1884
+ title=formatter.custom_title_params(
1885
+ summary_text.CHANNEL_CONTRIB_RANK_CHART_TITLE
1886
+ )
1887
+ )
1888
+ .configure_axis(titlePadding=c.PADDING_10, **formatter.TEXT_CONFIG)
1889
+ .configure_view(strokeOpacity=0)
1890
+ )
1891
+
1892
+ return plot
1893
+
1601
1894
  def plot_contribution_waterfall_chart(self) -> alt.Chart:
1602
1895
  """Plots a waterfall chart of the contribution share per channel.
1603
1896
 
@@ -1621,7 +1914,7 @@ class MediaSummary:
1621
1914
  num_channels = len(outcome_df[c.CHANNEL])
1622
1915
 
1623
1916
  base = (
1624
- alt.Chart(outcome_df)
1917
+ alt.Chart(outcome_df, width=c.VEGALITE_FACET_LARGE_WIDTH)
1625
1918
  .transform_window(
1626
1919
  sum_outcome=f'sum({c.PCT_OF_CONTRIBUTION})',
1627
1920
  kwargs=f'lead({c.CHANNEL})',
@@ -1682,7 +1975,6 @@ class MediaSummary:
1682
1975
  ),
1683
1976
  height=c.BAR_SIZE * num_channels
1684
1977
  + c.BAR_SIZE * 2 * c.SCALED_PADDING,
1685
- width=500,
1686
1978
  )
1687
1979
  .configure_axis(titlePadding=c.PADDING_10, **formatter.TEXT_CONFIG)
1688
1980
  .configure_view(strokeOpacity=0)
@@ -1968,7 +2260,7 @@ class MediaSummary:
1968
2260
  An Altair bubble plot showing the ROI, spend, and another metric.
1969
2261
  """
1970
2262
  if selected_channels:
1971
- channels = self.paid_summary_metrics.channel
2263
+ channels = self.get_paid_summary_metrics().channel
1972
2264
  if any(channel not in channels for channel in selected_channels):
1973
2265
  raise ValueError(
1974
2266
  '`selected_channels` should match the channel dimension names from '
@@ -2105,16 +2397,20 @@ class MediaSummary:
2105
2397
  Returns:
2106
2398
  A dataframe filtered based on the specifications.
2107
2399
  """
2400
+ paid_summary_metrics = self.get_paid_summary_metrics()
2108
2401
  metrics_df = self._summary_metrics_to_mean_df(
2109
- metrics=[c.ROI, metric], selected_channels=selected_channels
2402
+ paid_summary_metrics,
2403
+ metrics=[c.ROI, metric],
2404
+ selected_channels=selected_channels,
2110
2405
  )
2111
- spend_df = self.paid_summary_metrics[c.SPEND].to_dataframe().reset_index()
2406
+ spend_df = paid_summary_metrics[c.SPEND].to_dataframe().reset_index()
2112
2407
  return metrics_df.merge(spend_df, on=c.CHANNEL)
2113
2408
 
2114
2409
  def _transform_contribution_metrics(
2115
2410
  self,
2116
2411
  selected_channels: Sequence[str] | None = None,
2117
2412
  include_non_paid: bool = False,
2413
+ aggregate_times: bool = True,
2118
2414
  ) -> pd.DataFrame:
2119
2415
  """Transforms the media metrics for the contribution plot.
2120
2416
 
@@ -2127,56 +2423,133 @@ class MediaSummary:
2127
2423
  selected_channels: Optional list of a subset of channels to filter by.
2128
2424
  include_non_paid: If `True`, includes the organic media, organic RF and
2129
2425
  non-media channels in the contribution plot. Defaults to `False`.
2426
+ aggregate_times: If `True`, aggregates the metrics across all time
2427
+ periods. If `False`, returns time-varying metrics.
2130
2428
 
2131
2429
  Returns:
2132
2430
  A dataframe with contributions per channel.
2133
2431
  """
2134
- total_media_criteria = {
2135
- c.DISTRIBUTION: c.POSTERIOR,
2136
- c.METRIC: c.MEAN,
2137
- c.CHANNEL: c.ALL_CHANNELS,
2138
- }
2139
2432
  summary_metrics = (
2140
- self.all_summary_metrics
2433
+ self.get_all_summary_metrics(aggregate_times=aggregate_times)
2141
2434
  if include_non_paid
2142
- else self.paid_summary_metrics
2435
+ else self.get_paid_summary_metrics(aggregate_times=aggregate_times)
2143
2436
  )
2144
- total_media_outcome = (
2145
- summary_metrics[c.INCREMENTAL_OUTCOME].sel(total_media_criteria).item()
2437
+
2438
+ contribution_df = self._calculate_contribution_dataframe(
2439
+ summary_metrics, selected_channels
2146
2440
  )
2147
- total_media_pct = (
2148
- summary_metrics[c.PCT_OF_CONTRIBUTION].sel(total_media_criteria).item()
2149
- / 100
2441
+ baseline_df = self._calculate_baseline_contribution_dataframe(
2442
+ summary_metrics, aggregate_times
2150
2443
  )
2151
- total_outcome = total_media_outcome / total_media_pct
2152
- baseline_pct = 1 - total_media_pct
2153
- baseline_outcome = total_outcome * baseline_pct
2154
2444
 
2155
- baseline_df = pd.DataFrame(
2156
- {
2157
- c.CHANNEL: c.BASELINE,
2158
- c.INCREMENTAL_OUTCOME: baseline_outcome,
2159
- c.PCT_OF_CONTRIBUTION: baseline_pct,
2160
- },
2161
- index=[0],
2445
+ combined_df = pd.concat([baseline_df, contribution_df]).reset_index(
2446
+ drop=True
2162
2447
  )
2163
- outcome_df = self._summary_metrics_to_mean_df(
2448
+ if aggregate_times:
2449
+ combined_df.sort_values(
2450
+ by=c.INCREMENTAL_OUTCOME, ascending=False, inplace=True
2451
+ )
2452
+ else:
2453
+ combined_df.sort_values(
2454
+ by=[c.TIME, c.INCREMENTAL_OUTCOME],
2455
+ ascending=[True, False],
2456
+ inplace=True,
2457
+ )
2458
+ return combined_df
2459
+
2460
+ def _calculate_contribution_dataframe(
2461
+ self,
2462
+ summary_metrics: xr.Dataset,
2463
+ selected_channels: Sequence[str] | None,
2464
+ ) -> pd.DataFrame:
2465
+ """Calculates the contribution dataframe.
2466
+
2467
+ Args:
2468
+ summary_metrics: xarray Dataset of summary metrics.
2469
+ selected_channels: Optional list of channels.
2470
+
2471
+ Returns:
2472
+ pd.DataFrame: Contribution dataframe.
2473
+ Shape:
2474
+ - If `summary_metrics` is time-aggregated: (n_channels, 3)
2475
+ Columns: 'channel', 'incremental_outcome', 'pct_of_contribution'
2476
+ - If `summary_metrics` is time-varying: (n_channels * n_times, 4)
2477
+ Columns: 'time', 'channel', 'incremental_outcome',
2478
+ 'pct_of_contribution'
2479
+ """
2480
+
2481
+ contribution_df = self._summary_metrics_to_mean_df(
2482
+ summary_metrics=summary_metrics,
2164
2483
  metrics=[
2165
2484
  c.INCREMENTAL_OUTCOME,
2166
2485
  c.PCT_OF_CONTRIBUTION,
2167
2486
  ],
2168
2487
  selected_channels=selected_channels,
2169
- include_non_paid=include_non_paid,
2170
2488
  )
2171
2489
  # Convert to percentage values between 0-1.
2172
- outcome_df[c.PCT_OF_CONTRIBUTION] = outcome_df[c.PCT_OF_CONTRIBUTION].div(
2173
- 100
2490
+ contribution_df[c.PCT_OF_CONTRIBUTION] = contribution_df[
2491
+ c.PCT_OF_CONTRIBUTION
2492
+ ].div(100)
2493
+ return contribution_df
2494
+
2495
+ def _calculate_baseline_contribution_dataframe(
2496
+ self, summary_metrics: xr.Dataset, aggregate_times: bool
2497
+ ) -> pd.DataFrame:
2498
+ """Calculates the baseline contribution dataframe.
2499
+
2500
+ Calculates a single total outcome and baseline if aggregating.
2501
+ Calculates time-varying total and baseline if not aggregating.
2502
+
2503
+ Args:
2504
+ summary_metrics: The summary metrics dataset.
2505
+ aggregate_times: Whether to aggregate times.
2506
+
2507
+ Returns:
2508
+ A DataFrame containing the baseline metrics.
2509
+ Shape:
2510
+ - If `aggregate_times=True`: (1, 3)
2511
+ Columns: 'channel', 'incremental_outcome', 'pct_of_contribution'
2512
+ - If `aggregate_times=False`: (n_times, 4)
2513
+ Columns: 'time', 'channel', 'incremental_outcome',
2514
+ 'pct_of_contribution'
2515
+ """
2516
+ total_media_criteria = {
2517
+ c.DISTRIBUTION: c.POSTERIOR,
2518
+ c.METRIC: c.MEAN,
2519
+ c.CHANNEL: c.ALL_CHANNELS,
2520
+ }
2521
+ if not aggregate_times:
2522
+ total_media_criteria[c.TIME] = (
2523
+ self._selected_times
2524
+ or self.get_all_summary_metrics(aggregate_times=False).time
2525
+ )
2526
+
2527
+ total_media_outcome = summary_metrics[c.INCREMENTAL_OUTCOME].sel(
2528
+ total_media_criteria
2174
2529
  )
2175
- outcome_df = pd.concat([baseline_df, outcome_df]).reset_index(drop=True)
2176
- outcome_df.sort_values(
2177
- by=c.INCREMENTAL_OUTCOME, ascending=False, inplace=True
2530
+ total_media_pct = (
2531
+ summary_metrics[c.PCT_OF_CONTRIBUTION].sel(total_media_criteria) / 100
2178
2532
  )
2179
- return outcome_df
2533
+ total_outcome = total_media_outcome / total_media_pct
2534
+ baseline_pct = 1 - total_media_pct
2535
+ baseline_outcome = total_outcome * baseline_pct
2536
+
2537
+ if aggregate_times:
2538
+ return pd.DataFrame(
2539
+ {
2540
+ c.CHANNEL: c.BASELINE,
2541
+ c.INCREMENTAL_OUTCOME: baseline_outcome.item(),
2542
+ c.PCT_OF_CONTRIBUTION: baseline_pct.item(),
2543
+ },
2544
+ index=[0],
2545
+ )
2546
+ else:
2547
+ return pd.DataFrame({
2548
+ c.TIME: self._selected_times or summary_metrics.time.values,
2549
+ c.CHANNEL: c.BASELINE,
2550
+ c.INCREMENTAL_OUTCOME: baseline_outcome.values,
2551
+ c.PCT_OF_CONTRIBUTION: baseline_pct.values,
2552
+ })
2180
2553
 
2181
2554
  def _transform_contribution_spend_metrics(self) -> pd.DataFrame:
2182
2555
  """Transforms the media metrics for the spend vs contribution plot.
@@ -2189,12 +2562,13 @@ class MediaSummary:
2189
2562
  Returns:
2190
2563
  A dataframe of spend and outcome percentages and ROI per channel.
2191
2564
  """
2565
+ paid_summary_metrics = self.get_paid_summary_metrics()
2192
2566
  if self._meridian.input_data.revenue_per_kpi is not None:
2193
2567
  outcome = summary_text.REVENUE_LABEL
2194
2568
  else:
2195
2569
  outcome = summary_text.KPI_LABEL
2196
2570
  total_media_outcome = (
2197
- self.paid_summary_metrics[c.INCREMENTAL_OUTCOME]
2571
+ paid_summary_metrics[c.INCREMENTAL_OUTCOME]
2198
2572
  .sel(
2199
2573
  distribution=c.POSTERIOR,
2200
2574
  metric=c.MEAN,
@@ -2203,7 +2577,7 @@ class MediaSummary:
2203
2577
  .item()
2204
2578
  )
2205
2579
  outcome_pct_df = self._summary_metrics_to_mean_df(
2206
- metrics=[c.INCREMENTAL_OUTCOME]
2580
+ paid_summary_metrics, metrics=[c.INCREMENTAL_OUTCOME]
2207
2581
  )
2208
2582
  outcome_pct_df[c.PCT] = outcome_pct_df[c.INCREMENTAL_OUTCOME].div(
2209
2583
  total_media_outcome
@@ -2211,7 +2585,7 @@ class MediaSummary:
2211
2585
  outcome_pct_df.drop(columns=[c.INCREMENTAL_OUTCOME], inplace=True)
2212
2586
  outcome_pct_df['label'] = f'% {outcome}'
2213
2587
  spend_pct_df = (
2214
- self.paid_summary_metrics[c.PCT_OF_SPEND]
2588
+ paid_summary_metrics[c.PCT_OF_SPEND]
2215
2589
  .drop_sel(channel=[c.ALL_CHANNELS])
2216
2590
  .to_dataframe()
2217
2591
  .reset_index()
@@ -2221,7 +2595,9 @@ class MediaSummary:
2221
2595
  spend_pct_df['label'] = '% Spend'
2222
2596
 
2223
2597
  pct_df = pd.concat([outcome_pct_df, spend_pct_df])
2224
- roi_df = self._summary_metrics_to_mean_df(metrics=[c.ROI])
2598
+ roi_df = self._summary_metrics_to_mean_df(
2599
+ paid_summary_metrics, metrics=[c.ROI]
2600
+ )
2225
2601
  plot_df = pct_df.merge(roi_df, on=c.CHANNEL)
2226
2602
  scale_factor = plot_df[c.PCT].max() / plot_df[c.ROI].max()
2227
2603
  plot_df[c.ROI_SCALED] = plot_df[c.ROI] * scale_factor
@@ -2230,9 +2606,9 @@ class MediaSummary:
2230
2606
 
2231
2607
  def _summary_metrics_to_mean_df(
2232
2608
  self,
2609
+ summary_metrics: xr.Dataset,
2233
2610
  metrics: Sequence[str],
2234
2611
  selected_channels: Sequence[str] | None = None,
2235
- include_non_paid: bool = False,
2236
2612
  ) -> pd.DataFrame:
2237
2613
  """Transforms the summary metrics to a dataframe of mean values.
2238
2614
 
@@ -2241,20 +2617,14 @@ class MediaSummary:
2241
2617
  channels.
2242
2618
 
2243
2619
  Args:
2620
+ summary_metrics: The summary metrics dataset.
2244
2621
  metrics: A list of the metrics to include in the dataframe.
2245
2622
  selected_channels: List of channels to include. If None, all media
2246
2623
  channels will be included.
2247
- include_non_paid: If `True`, includes the organic media, organic RF and
2248
- non-media channels in the dataframe. Defaults to `False`.
2249
2624
 
2250
2625
  Returns:
2251
2626
  A dataframe of posterior mean values for the selected metrics and media.
2252
2627
  """
2253
- summary_metrics = (
2254
- self.all_summary_metrics
2255
- if include_non_paid
2256
- else self.paid_summary_metrics
2257
- )
2258
2628
  metrics_dataset = summary_metrics[metrics].sel(
2259
2629
  distribution=c.POSTERIOR, metric=c.MEAN
2260
2630
  )
@@ -2284,7 +2654,7 @@ class MediaSummary:
2284
2654
  central_tendency = c.MEDIAN if metric == c.CPIK else c.MEAN
2285
2655
  unused_central_tendency = c.MEAN if metric == c.CPIK else c.MEDIAN
2286
2656
  return (
2287
- self.paid_summary_metrics[metric]
2657
+ self.get_paid_summary_metrics()[metric]
2288
2658
  .sel(distribution=c.POSTERIOR)
2289
2659
  .drop_sel(
2290
2660
  channel=c.ALL_CHANNELS,