google-meridian 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,8 @@
16
16
 
17
17
  from collections.abc import Sequence
18
18
  import functools
19
+ from typing import Mapping
20
+ import warnings
19
21
  import altair as alt
20
22
  from meridian import constants as c
21
23
  from meridian.analysis import analyzer
@@ -32,7 +34,9 @@ import xarray as xr
32
34
  __all__ = [
33
35
  'ModelDiagnostics',
34
36
  'ModelFit',
37
+ 'MediaSummary',
35
38
  'ReachAndFrequency',
39
+ 'MediaEffects',
36
40
  ]
37
41
 
38
42
 
@@ -246,7 +250,7 @@ class ModelDiagnostics:
246
250
  .mark_area(opacity=0.7)
247
251
  .encode(
248
252
  x=f'{parameter}:Q',
249
- y='density:Q',
253
+ y=alt.Y(shorthand='density:Q', stack=False),
250
254
  color=f'{c.DISTRIBUTION}:N',
251
255
  )
252
256
  )
@@ -883,11 +887,13 @@ class MediaEffects:
883
887
  A Dataset displaying the response curves data.
884
888
  """
885
889
  selected_times_list = list(selected_times) if selected_times else None
890
+ use_kpi = self._meridian.input_data.revenue_per_kpi is None
886
891
  return self._analyzer.response_curves(
887
892
  spend_multipliers=list(np.arange(0, 2.2, c.RESPONSE_CURVE_STEP_SIZE)),
888
893
  confidence_level=confidence_level,
889
894
  selected_times=selected_times_list,
890
895
  by_reach=by_reach,
896
+ use_kpi=use_kpi,
891
897
  )
892
898
 
893
899
  @functools.lru_cache(maxsize=128)
@@ -1170,7 +1176,7 @@ class MediaEffects:
1170
1176
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL,
1171
1177
  include_prior: bool = True,
1172
1178
  include_ci: bool = True,
1173
- ) -> alt.Chart | list[alt.Chart]:
1179
+ ) -> Mapping[str, alt.Chart]:
1174
1180
  """Plots the Hill curves for each channel.
1175
1181
 
1176
1182
  Args:
@@ -1181,22 +1187,23 @@ class MediaEffects:
1181
1187
  include_ci: If `True`, plots the credible interval. Defaults to `True`.
1182
1188
 
1183
1189
  Returns:
1184
- A faceted Altair plot showing the histogram, prior and posterior lines,
1185
- and bands for the Hill saturation curves. When there are both media and
1186
- RF channels, a list of 2 faceted Altair plots are returned: one
1187
- for the media channels and another for the RF channels.
1190
+ A dictionary mapping channel type constants (`media`, `rf`, and
1191
+ `organic_media`) to their respective Altair chart objects. Keys are only
1192
+ present if charts for that type were generated (i.e., if the
1193
+ corresponding channels exist in the data). Returns an empty dictionary if
1194
+ no relevant channels are found.
1188
1195
  """
1189
1196
  hill_curves_dataframe = self.hill_curves_dataframe(
1190
1197
  confidence_level=confidence_level
1191
1198
  )
1192
1199
  channel_types = list(set(hill_curves_dataframe[c.CHANNEL_TYPE]))
1193
- plot_media, plot_rf = None, None
1200
+ plots: dict[str, alt.Chart] = {}
1194
1201
 
1195
1202
  if c.MEDIA in channel_types:
1196
1203
  media_df = hill_curves_dataframe[
1197
1204
  hill_curves_dataframe[c.CHANNEL_TYPE] == c.MEDIA
1198
1205
  ]
1199
- plot_media = self._plot_hill_curves_helper(
1206
+ plots[c.MEDIA] = self._plot_hill_curves_helper(
1200
1207
  media_df, include_prior, include_ci
1201
1208
  )
1202
1209
 
@@ -1204,14 +1211,19 @@ class MediaEffects:
1204
1211
  rf_df = hill_curves_dataframe[
1205
1212
  hill_curves_dataframe[c.CHANNEL_TYPE] == c.RF
1206
1213
  ]
1207
- plot_rf = self._plot_hill_curves_helper(rf_df, include_prior, include_ci)
1214
+ plots[c.RF] = self._plot_hill_curves_helper(
1215
+ rf_df, include_prior, include_ci
1216
+ )
1208
1217
 
1209
- if plot_media and plot_rf:
1210
- return [plot_media, plot_rf]
1211
- elif plot_media:
1212
- return plot_media
1213
- else:
1214
- return plot_rf
1218
+ if c.ORGANIC_MEDIA in channel_types:
1219
+ organic_media_df = hill_curves_dataframe[
1220
+ hill_curves_dataframe[c.CHANNEL_TYPE] == c.ORGANIC_MEDIA
1221
+ ]
1222
+ plots[c.ORGANIC_MEDIA] = self._plot_hill_curves_helper(
1223
+ organic_media_df, include_prior, include_ci
1224
+ )
1225
+
1226
+ return plots
1215
1227
 
1216
1228
  def _plot_hill_curves_helper(
1217
1229
  self,
@@ -1231,13 +1243,26 @@ class MediaEffects:
1231
1243
  Returns:
1232
1244
  A faceted Altair plot showing the histogram and prior+posterior lines and
1233
1245
  bands for the Hill curves.
1246
+
1247
+ Raises:
1248
+ ValueError: If the input DataFrame is empty, missing the channel_type
1249
+ column, or contains an unsupported channel type.
1234
1250
  """
1235
- if c.MEDIA in list(df_channel_type[c.CHANNEL_TYPE]):
1251
+ channel_type = df_channel_type[c.CHANNEL_TYPE].iloc[0]
1252
+ if channel_type == c.MEDIA:
1236
1253
  x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1237
1254
  shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1238
- else:
1255
+ elif channel_type == c.RF:
1239
1256
  x_axis_title = summary_text.HILL_X_AXIS_RF_LABEL
1240
1257
  shaded_area_title = summary_text.HILL_SHADED_REGION_RF_LABEL
1258
+ elif channel_type == c.ORGANIC_MEDIA:
1259
+ x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1260
+ shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1261
+ else:
1262
+ raise ValueError(
1263
+ f"Unsupported channel type '{channel_type}' found in Hill curve data."
1264
+ ' Expected one of: {c.MEDIA}, {c.RF}, {c.ORGANIC_MEDIA}.'
1265
+ )
1241
1266
  domain_list = [
1242
1267
  c.POSTERIOR,
1243
1268
  c.PRIOR,
@@ -1410,18 +1435,35 @@ class MediaSummary:
1410
1435
  self._marginal_roi_by_reach = marginal_roi_by_reach
1411
1436
  self._non_media_baseline_values = non_media_baseline_values
1412
1437
 
1413
- @functools.cached_property
1414
- def paid_summary_metrics(self) -> xr.Dataset:
1438
+ @property
1439
+ def paid_summary_metrics(self):
1440
+ warnings.warn(
1441
+ 'The `paid_summary_metrics` property is deprecated. Use the'
1442
+ ' `get_paid_summary_metrics()` method instead.',
1443
+ DeprecationWarning,
1444
+ stacklevel=2,
1445
+ )
1446
+ return self.get_paid_summary_metrics()
1447
+
1448
+ @functools.lru_cache(maxsize=128)
1449
+ def get_paid_summary_metrics(
1450
+ self, aggregate_times: bool = True
1451
+ ) -> xr.Dataset:
1415
1452
  """Dataset holding the calculated summary metrics for the paid channels.
1416
1453
 
1417
- The dataset contains the following:
1454
+ Args:
1455
+ aggregate_times: If `True`, aggregates the metrics across all time
1456
+ periods. If `False`, returns time-varying metrics.
1418
1457
 
1419
- - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`, `ci_hi`),
1420
- `distribution` (`prior`, `posterior`)
1421
- - **Data variables:** `impressions`, `pct_of_impressions`, `spend`,
1422
- `pct_of_spend`, `CPM`, `incremental_outcome`, `pct_of_contribution`,
1423
- `roi`,
1424
- `effectiveness`, `mroi`.
1458
+ Returns:
1459
+ An `xarray.Dataset` containing the following:
1460
+ - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
1461
+ `ci_hi`),
1462
+ `distribution` (`prior`, `posterior`)
1463
+ - **Data variables:** `impressions`, `pct_of_impressions`, `spend`,
1464
+ `pct_of_spend`, `CPM`, `incremental_outcome`, `pct_of_contribution`,
1465
+ `roi`,
1466
+ `effectiveness`, `mroi`.
1425
1467
  """
1426
1468
  return self._analyzer.summary_metrics(
1427
1469
  selected_times=self._selected_times,
@@ -1429,18 +1471,34 @@ class MediaSummary:
1429
1471
  use_kpi=self._meridian.input_data.revenue_per_kpi is None,
1430
1472
  confidence_level=self._confidence_level,
1431
1473
  include_non_paid_channels=False,
1474
+ aggregate_times=aggregate_times,
1432
1475
  )
1433
1476
 
1434
- @functools.cached_property
1435
- def all_summary_metrics(self) -> xr.Dataset:
1477
+ @property
1478
+ def all_summary_metrics(self):
1479
+ warnings.warn(
1480
+ 'The `all_summary_metrics` property is deprecated. Use the'
1481
+ ' `get_all_summary_metrics()` method instead.',
1482
+ DeprecationWarning,
1483
+ stacklevel=2,
1484
+ )
1485
+ return self.get_all_summary_metrics()
1486
+
1487
+ @functools.lru_cache(maxsize=128)
1488
+ def get_all_summary_metrics(self, aggregate_times: bool = True) -> xr.Dataset:
1436
1489
  """Dataset holding the calculated summary metrics for all channels.
1437
1490
 
1438
- The dataset contains the following:
1491
+ Args:
1492
+ aggregate_times: If `True`, aggregates the metrics across all time
1493
+ periods. If `False`, returns time-varying metrics.
1439
1494
 
1440
- - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`, `ci_hi`),
1441
- `distribution` (`prior`, `posterior`)
1442
- - **Data variables:** `incremental_outcome`, `pct_of_contribution`,
1443
- `effectiveness`.
1495
+ Returns:
1496
+ An `xarray.Dataset` containing the following:
1497
+ - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
1498
+ `ci_hi`),
1499
+ `distribution` (`prior`, `posterior`)
1500
+ - **Data variables:** `incremental_outcome`, `pct_of_contribution`,
1501
+ `effectiveness`.
1444
1502
  """
1445
1503
  return self._analyzer.summary_metrics(
1446
1504
  selected_times=self._selected_times,
@@ -1448,6 +1506,7 @@ class MediaSummary:
1448
1506
  confidence_level=self._confidence_level,
1449
1507
  include_non_paid_channels=True,
1450
1508
  non_media_baseline_values=self._non_media_baseline_values,
1509
+ aggregate_times=aggregate_times,
1451
1510
  )
1452
1511
 
1453
1512
  def summary_table(
@@ -1488,7 +1547,7 @@ class MediaSummary:
1488
1547
  ]
1489
1548
  if include_non_paid_channels:
1490
1549
  monetary_metrics = [c.INCREMENTAL_OUTCOME] * use_revenue
1491
- summary_metrics = self.all_summary_metrics
1550
+ summary_metrics = self.get_all_summary_metrics()
1492
1551
  columns_rename_dict = {
1493
1552
  c.PCT_OF_CONTRIBUTION: summary_text.PCT_CONTRIBUTION_COL,
1494
1553
  c.INCREMENTAL_OUTCOME: (
@@ -1512,7 +1571,7 @@ class MediaSummary:
1512
1571
  c.SPEND,
1513
1572
  c.INCREMENTAL_OUTCOME,
1514
1573
  ] * use_revenue
1515
- summary_metrics = self.paid_summary_metrics
1574
+ summary_metrics = self.get_paid_summary_metrics()
1516
1575
  columns_rename_dict = {
1517
1576
  c.PCT_OF_IMPRESSIONS: summary_text.PCT_IMPRESSIONS_COL,
1518
1577
  c.PCT_OF_SPEND: summary_text.PCT_SPEND_COL,
@@ -1598,6 +1657,194 @@ class MediaSummary:
1598
1657
  self._marginal_roi_by_reach = marginal_roi_by_reach
1599
1658
  self._non_media_baseline_values = non_media_baseline_values
1600
1659
 
1660
+ def plot_channel_contribution_area_chart(self) -> alt.Chart:
1661
+ """Plots a stacked area chart of the contribution share per channel by time.
1662
+
1663
+ Returns:
1664
+ An Altair plot showing the contribution share per channel by time.
1665
+ """
1666
+ outcome_df = self._transform_contribution_metrics(
1667
+ include_non_paid=True, aggregate_times=False
1668
+ )
1669
+
1670
+ # Ensure proper ordering for the stacked area chart. Baseline should be at
1671
+ # the bottom. Separate the *stacking* order from the *legend* order.
1672
+ stack_order = sorted([
1673
+ channel
1674
+ for channel in outcome_df[c.CHANNEL].unique()
1675
+ if channel != c.BASELINE
1676
+ ]) + [c.BASELINE]
1677
+
1678
+ legend_order = [c.BASELINE] + sorted([
1679
+ channel
1680
+ for channel in outcome_df[c.CHANNEL].unique()
1681
+ if channel != c.BASELINE
1682
+ ])
1683
+
1684
+ # Get the minimum incremental outcome for baseline across all time periods
1685
+ # as the lower bound for the stacked area chart.
1686
+ min_y = (
1687
+ outcome_df[outcome_df[c.CHANNEL] == c.BASELINE]
1688
+ .groupby(c.TIME)[c.INCREMENTAL_OUTCOME]
1689
+ .min()
1690
+ .min()
1691
+ )
1692
+
1693
+ plot = (
1694
+ alt.Chart(outcome_df, width=c.VEGALITE_FACET_LARGE_WIDTH)
1695
+ .mark_area()
1696
+ .transform_calculate(
1697
+ sort_channel=f'indexof({stack_order}, datum.channel)'
1698
+ )
1699
+ .encode(
1700
+ x=alt.X(
1701
+ f'{c.TIME}:T',
1702
+ title='Time period',
1703
+ axis=alt.Axis(
1704
+ format='%Y Q%q',
1705
+ grid=False,
1706
+ tickCount=8,
1707
+ domainColor=c.GREY_300,
1708
+ ),
1709
+ ),
1710
+ y=alt.Y(
1711
+ f'{c.INCREMENTAL_OUTCOME}:Q',
1712
+ title=(
1713
+ c.REVENUE.title()
1714
+ if self._meridian.input_data.revenue_per_kpi is not None
1715
+ else c.KPI.upper()
1716
+ ),
1717
+ axis=alt.Axis(
1718
+ ticks=False,
1719
+ domain=False,
1720
+ tickCount=5,
1721
+ labelPadding=c.PADDING_10,
1722
+ labelExpr=formatter.compact_number_expr(),
1723
+ **formatter.Y_AXIS_TITLE_CONFIG,
1724
+ ),
1725
+ scale=alt.Scale(domainMin=min_y, clamp=True),
1726
+ ),
1727
+ color=alt.Color(
1728
+ f'{c.CHANNEL}:N',
1729
+ legend=alt.Legend(
1730
+ labelFontSize=c.AXIS_FONT_SIZE,
1731
+ labelFont=c.FONT_ROBOTO,
1732
+ title=None,
1733
+ ),
1734
+ scale=alt.Scale(domain=legend_order),
1735
+ sort=legend_order,
1736
+ ),
1737
+ tooltip=[
1738
+ alt.Tooltip(f'{c.TIME}:T', format='%Y-%m-%d'),
1739
+ c.CHANNEL,
1740
+ alt.Tooltip(f'{c.INCREMENTAL_OUTCOME}:Q', format=',.2f'),
1741
+ ],
1742
+ order=alt.Order('sort_channel:N', sort='descending'),
1743
+ )
1744
+ .properties(
1745
+ title=formatter.custom_title_params(
1746
+ summary_text.CHANNEL_CONTRIB_BY_TIME_CHART_TITLE
1747
+ ),
1748
+ )
1749
+ .configure_axis(titlePadding=c.PADDING_10, **formatter.TEXT_CONFIG)
1750
+ .configure_view(strokeOpacity=0)
1751
+ )
1752
+ return plot
1753
+
1754
+ def plot_channel_contribution_bump_chart(self) -> alt.Chart:
1755
+ """Plots a bump chart of channel contribution rank over time (Quarterly).
1756
+
1757
+ This chart shows the relative rank of each channel's contribution,
1758
+ including the baseline, based on incremental outcome at the end of each
1759
+ quarter. Rank 1 represents the highest contribution.
1760
+
1761
+ Returns:
1762
+ An Altair plot showing the contribution rank per channel by quarter.
1763
+ """
1764
+ outcome_df = self._transform_contribution_metrics(
1765
+ include_non_paid=True, aggregate_times=False
1766
+ )
1767
+ outcome_df[c.TIME] = pd.to_datetime(outcome_df[c.TIME])
1768
+
1769
+ outcome_df['rank'] = outcome_df.groupby(c.TIME)[c.INCREMENTAL_OUTCOME].rank(
1770
+ method='first', ascending=False
1771
+ )
1772
+
1773
+ # Filter data to keep only the last available date within each quarter
1774
+ # for a quarterly view of ranking changes.
1775
+ unique_times = pd.Series(outcome_df[c.TIME].unique()).sort_values()
1776
+ quarters = unique_times.dt.to_period('Q')
1777
+ quarterly_dates = unique_times[~quarters.duplicated(keep='last')]
1778
+ quarterly_rank_df = outcome_df[
1779
+ outcome_df[c.TIME].isin(quarterly_dates)
1780
+ ].copy()
1781
+
1782
+ legend_order = [c.BASELINE] + sorted([
1783
+ channel
1784
+ for channel in quarterly_rank_df[c.CHANNEL].unique()
1785
+ if channel != c.BASELINE
1786
+ ])
1787
+
1788
+ plot = (
1789
+ alt.Chart(quarterly_rank_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
1790
+ .mark_line(point=True)
1791
+ .encode(
1792
+ x=alt.X(
1793
+ f'{c.TIME}:T',
1794
+ title='Time period',
1795
+ axis=alt.Axis(
1796
+ format='%Y Q%q',
1797
+ grid=False,
1798
+ domainColor=c.GREY_300,
1799
+ ),
1800
+ ),
1801
+ y=alt.Y(
1802
+ 'rank:Q',
1803
+ title='Contribution Rank',
1804
+ axis=alt.Axis(
1805
+ ticks=False,
1806
+ domain=False,
1807
+ labelPadding=c.PADDING_10,
1808
+ tickMinStep=1,
1809
+ format='d',
1810
+ ),
1811
+ scale=alt.Scale(
1812
+ zero=False,
1813
+ reverse=True,
1814
+ ),
1815
+ ),
1816
+ color=alt.Color(
1817
+ f'{c.CHANNEL}:N',
1818
+ legend=alt.Legend(
1819
+ labelFontSize=c.AXIS_FONT_SIZE,
1820
+ labelFont=c.FONT_ROBOTO,
1821
+ title=None,
1822
+ ),
1823
+ scale=alt.Scale(domain=legend_order),
1824
+ sort=legend_order,
1825
+ ),
1826
+ tooltip=[
1827
+ alt.Tooltip(f'{c.TIME}:T', format='%Y Q%q', title='Quarter'),
1828
+ alt.Tooltip(f'{c.CHANNEL}:N', title='Channel'),
1829
+ alt.Tooltip('rank:O', title='Rank'),
1830
+ alt.Tooltip(
1831
+ f'{c.INCREMENTAL_OUTCOME}:Q',
1832
+ format=',.0f',
1833
+ title='Incremental Outcome',
1834
+ ),
1835
+ ],
1836
+ )
1837
+ .properties(
1838
+ title=formatter.custom_title_params(
1839
+ summary_text.CHANNEL_CONTRIB_RANK_CHART_TITLE
1840
+ )
1841
+ )
1842
+ .configure_axis(titlePadding=c.PADDING_10, **formatter.TEXT_CONFIG)
1843
+ .configure_view(strokeOpacity=0)
1844
+ )
1845
+
1846
+ return plot
1847
+
1601
1848
  def plot_contribution_waterfall_chart(self) -> alt.Chart:
1602
1849
  """Plots a waterfall chart of the contribution share per channel.
1603
1850
 
@@ -1621,7 +1868,7 @@ class MediaSummary:
1621
1868
  num_channels = len(outcome_df[c.CHANNEL])
1622
1869
 
1623
1870
  base = (
1624
- alt.Chart(outcome_df)
1871
+ alt.Chart(outcome_df, width=c.VEGALITE_FACET_LARGE_WIDTH)
1625
1872
  .transform_window(
1626
1873
  sum_outcome=f'sum({c.PCT_OF_CONTRIBUTION})',
1627
1874
  kwargs=f'lead({c.CHANNEL})',
@@ -1682,7 +1929,6 @@ class MediaSummary:
1682
1929
  ),
1683
1930
  height=c.BAR_SIZE * num_channels
1684
1931
  + c.BAR_SIZE * 2 * c.SCALED_PADDING,
1685
- width=500,
1686
1932
  )
1687
1933
  .configure_axis(titlePadding=c.PADDING_10, **formatter.TEXT_CONFIG)
1688
1934
  .configure_view(strokeOpacity=0)
@@ -1968,7 +2214,7 @@ class MediaSummary:
1968
2214
  An Altair bubble plot showing the ROI, spend, and another metric.
1969
2215
  """
1970
2216
  if selected_channels:
1971
- channels = self.paid_summary_metrics.channel
2217
+ channels = self.get_paid_summary_metrics().channel
1972
2218
  if any(channel not in channels for channel in selected_channels):
1973
2219
  raise ValueError(
1974
2220
  '`selected_channels` should match the channel dimension names from '
@@ -2105,16 +2351,20 @@ class MediaSummary:
2105
2351
  Returns:
2106
2352
  A dataframe filtered based on the specifications.
2107
2353
  """
2354
+ paid_summary_metrics = self.get_paid_summary_metrics()
2108
2355
  metrics_df = self._summary_metrics_to_mean_df(
2109
- metrics=[c.ROI, metric], selected_channels=selected_channels
2356
+ paid_summary_metrics,
2357
+ metrics=[c.ROI, metric],
2358
+ selected_channels=selected_channels,
2110
2359
  )
2111
- spend_df = self.paid_summary_metrics[c.SPEND].to_dataframe().reset_index()
2360
+ spend_df = paid_summary_metrics[c.SPEND].to_dataframe().reset_index()
2112
2361
  return metrics_df.merge(spend_df, on=c.CHANNEL)
2113
2362
 
2114
2363
  def _transform_contribution_metrics(
2115
2364
  self,
2116
2365
  selected_channels: Sequence[str] | None = None,
2117
2366
  include_non_paid: bool = False,
2367
+ aggregate_times: bool = True,
2118
2368
  ) -> pd.DataFrame:
2119
2369
  """Transforms the media metrics for the contribution plot.
2120
2370
 
@@ -2127,56 +2377,133 @@ class MediaSummary:
2127
2377
  selected_channels: Optional list of a subset of channels to filter by.
2128
2378
  include_non_paid: If `True`, includes the organic media, organic RF and
2129
2379
  non-media channels in the contribution plot. Defaults to `False`.
2380
+ aggregate_times: If `True`, aggregates the metrics across all time
2381
+ periods. If `False`, returns time-varying metrics.
2130
2382
 
2131
2383
  Returns:
2132
2384
  A dataframe with contributions per channel.
2133
2385
  """
2134
- total_media_criteria = {
2135
- c.DISTRIBUTION: c.POSTERIOR,
2136
- c.METRIC: c.MEAN,
2137
- c.CHANNEL: c.ALL_CHANNELS,
2138
- }
2139
2386
  summary_metrics = (
2140
- self.all_summary_metrics
2387
+ self.get_all_summary_metrics(aggregate_times=aggregate_times)
2141
2388
  if include_non_paid
2142
- else self.paid_summary_metrics
2389
+ else self.get_paid_summary_metrics(aggregate_times=aggregate_times)
2143
2390
  )
2144
- total_media_outcome = (
2145
- summary_metrics[c.INCREMENTAL_OUTCOME].sel(total_media_criteria).item()
2391
+
2392
+ contribution_df = self._calculate_contribution_dataframe(
2393
+ summary_metrics, selected_channels
2146
2394
  )
2147
- total_media_pct = (
2148
- summary_metrics[c.PCT_OF_CONTRIBUTION].sel(total_media_criteria).item()
2149
- / 100
2395
+ baseline_df = self._calculate_baseline_contribution_dataframe(
2396
+ summary_metrics, aggregate_times
2150
2397
  )
2151
- total_outcome = total_media_outcome / total_media_pct
2152
- baseline_pct = 1 - total_media_pct
2153
- baseline_outcome = total_outcome * baseline_pct
2154
2398
 
2155
- baseline_df = pd.DataFrame(
2156
- {
2157
- c.CHANNEL: c.BASELINE,
2158
- c.INCREMENTAL_OUTCOME: baseline_outcome,
2159
- c.PCT_OF_CONTRIBUTION: baseline_pct,
2160
- },
2161
- index=[0],
2399
+ combined_df = pd.concat([baseline_df, contribution_df]).reset_index(
2400
+ drop=True
2162
2401
  )
2163
- outcome_df = self._summary_metrics_to_mean_df(
2402
+ if aggregate_times:
2403
+ combined_df.sort_values(
2404
+ by=c.INCREMENTAL_OUTCOME, ascending=False, inplace=True
2405
+ )
2406
+ else:
2407
+ combined_df.sort_values(
2408
+ by=[c.TIME, c.INCREMENTAL_OUTCOME],
2409
+ ascending=[True, False],
2410
+ inplace=True,
2411
+ )
2412
+ return combined_df
2413
+
2414
+ def _calculate_contribution_dataframe(
2415
+ self,
2416
+ summary_metrics: xr.Dataset,
2417
+ selected_channels: Sequence[str] | None,
2418
+ ) -> pd.DataFrame:
2419
+ """Calculates the contribution dataframe.
2420
+
2421
+ Args:
2422
+ summary_metrics: xarray Dataset of summary metrics.
2423
+ selected_channels: Optional list of channels.
2424
+
2425
+ Returns:
2426
+ pd.DataFrame: Contribution dataframe.
2427
+ Shape:
2428
+ - If `aggregate_times=True`: (n_channels, 3)
2429
+ Columns: 'channel', 'incremental_outcome', 'pct_of_contribution'
2430
+ - If `aggregate_times=False`: (n_channels * n_times, 4)
2431
+ Columns: 'time', 'channel', 'incremental_outcome',
2432
+ 'pct_of_contribution'
2433
+ """
2434
+
2435
+ contribution_df = self._summary_metrics_to_mean_df(
2436
+ summary_metrics=summary_metrics,
2164
2437
  metrics=[
2165
2438
  c.INCREMENTAL_OUTCOME,
2166
2439
  c.PCT_OF_CONTRIBUTION,
2167
2440
  ],
2168
2441
  selected_channels=selected_channels,
2169
- include_non_paid=include_non_paid,
2170
2442
  )
2171
2443
  # Convert to percentage values between 0-1.
2172
- outcome_df[c.PCT_OF_CONTRIBUTION] = outcome_df[c.PCT_OF_CONTRIBUTION].div(
2173
- 100
2444
+ contribution_df[c.PCT_OF_CONTRIBUTION] = contribution_df[
2445
+ c.PCT_OF_CONTRIBUTION
2446
+ ].div(100)
2447
+ return contribution_df
2448
+
2449
+ def _calculate_baseline_contribution_dataframe(
2450
+ self, summary_metrics: xr.Dataset, aggregate_times: bool
2451
+ ) -> pd.DataFrame:
2452
+ """Calculates the baseline contribution dataframe.
2453
+
2454
+ Calculates a single total outcome and baseline if aggregating.
2455
+ Calculates time-varying total and baseline if not aggregating.
2456
+
2457
+ Args:
2458
+ summary_metrics: The summary metrics dataset.
2459
+ aggregate_times: Whether to aggregate times.
2460
+
2461
+ Returns:
2462
+ A DataFrame containing the baseline metrics.
2463
+ Shape:
2464
+ - If `aggregate_times=True`: (1, 3)
2465
+ Columns: 'channel', 'incremental_outcome', 'pct_of_contribution'
2466
+ - If `aggregate_times=False`: (n_times, 4)
2467
+ Columns: 'time', 'channel', 'incremental_outcome',
2468
+ 'pct_of_contribution'
2469
+ """
2470
+ total_media_criteria = {
2471
+ c.DISTRIBUTION: c.POSTERIOR,
2472
+ c.METRIC: c.MEAN,
2473
+ c.CHANNEL: c.ALL_CHANNELS,
2474
+ }
2475
+ if not aggregate_times:
2476
+ total_media_criteria[c.TIME] = (
2477
+ self._selected_times
2478
+ or self.get_all_summary_metrics(aggregate_times=False).time
2479
+ )
2480
+
2481
+ total_media_outcome = summary_metrics[c.INCREMENTAL_OUTCOME].sel(
2482
+ total_media_criteria
2174
2483
  )
2175
- outcome_df = pd.concat([baseline_df, outcome_df]).reset_index(drop=True)
2176
- outcome_df.sort_values(
2177
- by=c.INCREMENTAL_OUTCOME, ascending=False, inplace=True
2484
+ total_media_pct = (
2485
+ summary_metrics[c.PCT_OF_CONTRIBUTION].sel(total_media_criteria) / 100
2178
2486
  )
2179
- return outcome_df
2487
+ total_outcome = total_media_outcome / total_media_pct
2488
+ baseline_pct = 1 - total_media_pct
2489
+ baseline_outcome = total_outcome * baseline_pct
2490
+
2491
+ if aggregate_times:
2492
+ return pd.DataFrame(
2493
+ {
2494
+ c.CHANNEL: c.BASELINE,
2495
+ c.INCREMENTAL_OUTCOME: baseline_outcome.item(),
2496
+ c.PCT_OF_CONTRIBUTION: baseline_pct.item(),
2497
+ },
2498
+ index=[0],
2499
+ )
2500
+ else:
2501
+ return pd.DataFrame({
2502
+ c.TIME: self._selected_times or summary_metrics.time.values,
2503
+ c.CHANNEL: c.BASELINE,
2504
+ c.INCREMENTAL_OUTCOME: baseline_outcome.values,
2505
+ c.PCT_OF_CONTRIBUTION: baseline_pct.values,
2506
+ })
2180
2507
 
2181
2508
  def _transform_contribution_spend_metrics(self) -> pd.DataFrame:
2182
2509
  """Transforms the media metrics for the spend vs contribution plot.
@@ -2189,12 +2516,13 @@ class MediaSummary:
2189
2516
  Returns:
2190
2517
  A dataframe of spend and outcome percentages and ROI per channel.
2191
2518
  """
2519
+ paid_summary_metrics = self.get_paid_summary_metrics()
2192
2520
  if self._meridian.input_data.revenue_per_kpi is not None:
2193
2521
  outcome = summary_text.REVENUE_LABEL
2194
2522
  else:
2195
2523
  outcome = summary_text.KPI_LABEL
2196
2524
  total_media_outcome = (
2197
- self.paid_summary_metrics[c.INCREMENTAL_OUTCOME]
2525
+ paid_summary_metrics[c.INCREMENTAL_OUTCOME]
2198
2526
  .sel(
2199
2527
  distribution=c.POSTERIOR,
2200
2528
  metric=c.MEAN,
@@ -2203,7 +2531,7 @@ class MediaSummary:
2203
2531
  .item()
2204
2532
  )
2205
2533
  outcome_pct_df = self._summary_metrics_to_mean_df(
2206
- metrics=[c.INCREMENTAL_OUTCOME]
2534
+ paid_summary_metrics, metrics=[c.INCREMENTAL_OUTCOME]
2207
2535
  )
2208
2536
  outcome_pct_df[c.PCT] = outcome_pct_df[c.INCREMENTAL_OUTCOME].div(
2209
2537
  total_media_outcome
@@ -2211,7 +2539,7 @@ class MediaSummary:
2211
2539
  outcome_pct_df.drop(columns=[c.INCREMENTAL_OUTCOME], inplace=True)
2212
2540
  outcome_pct_df['label'] = f'% {outcome}'
2213
2541
  spend_pct_df = (
2214
- self.paid_summary_metrics[c.PCT_OF_SPEND]
2542
+ paid_summary_metrics[c.PCT_OF_SPEND]
2215
2543
  .drop_sel(channel=[c.ALL_CHANNELS])
2216
2544
  .to_dataframe()
2217
2545
  .reset_index()
@@ -2221,7 +2549,9 @@ class MediaSummary:
2221
2549
  spend_pct_df['label'] = '% Spend'
2222
2550
 
2223
2551
  pct_df = pd.concat([outcome_pct_df, spend_pct_df])
2224
- roi_df = self._summary_metrics_to_mean_df(metrics=[c.ROI])
2552
+ roi_df = self._summary_metrics_to_mean_df(
2553
+ paid_summary_metrics, metrics=[c.ROI]
2554
+ )
2225
2555
  plot_df = pct_df.merge(roi_df, on=c.CHANNEL)
2226
2556
  scale_factor = plot_df[c.PCT].max() / plot_df[c.ROI].max()
2227
2557
  plot_df[c.ROI_SCALED] = plot_df[c.ROI] * scale_factor
@@ -2230,9 +2560,9 @@ class MediaSummary:
2230
2560
 
2231
2561
  def _summary_metrics_to_mean_df(
2232
2562
  self,
2563
+ summary_metrics: xr.Dataset,
2233
2564
  metrics: Sequence[str],
2234
2565
  selected_channels: Sequence[str] | None = None,
2235
- include_non_paid: bool = False,
2236
2566
  ) -> pd.DataFrame:
2237
2567
  """Transforms the summary metrics to a dataframe of mean values.
2238
2568
 
@@ -2241,20 +2571,14 @@ class MediaSummary:
2241
2571
  channels.
2242
2572
 
2243
2573
  Args:
2574
+ summary_metrics: The summary metrics dataset.
2244
2575
  metrics: A list of the metrics to include in the dataframe.
2245
2576
  selected_channels: List of channels to include. If None, all media
2246
2577
  channels will be included.
2247
- include_non_paid: If `True`, includes the organic media, organic RF and
2248
- non-media channels in the dataframe. Defaults to `False`.
2249
2578
 
2250
2579
  Returns:
2251
2580
  A dataframe of posterior mean values for the selected metrics and media.
2252
2581
  """
2253
- summary_metrics = (
2254
- self.all_summary_metrics
2255
- if include_non_paid
2256
- else self.paid_summary_metrics
2257
- )
2258
2582
  metrics_dataset = summary_metrics[metrics].sel(
2259
2583
  distribution=c.POSTERIOR, metric=c.MEAN
2260
2584
  )
@@ -2284,7 +2608,7 @@ class MediaSummary:
2284
2608
  central_tendency = c.MEDIAN if metric == c.CPIK else c.MEAN
2285
2609
  unused_central_tendency = c.MEAN if metric == c.CPIK else c.MEDIAN
2286
2610
  return (
2287
- self.paid_summary_metrics[metric]
2611
+ self.get_paid_summary_metrics()[metric]
2288
2612
  .sel(distribution=c.POSTERIOR)
2289
2613
  .drop_sel(
2290
2614
  channel=c.ALL_CHANNELS,