google-meridian 1.1.6__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2571,6 +2571,9 @@ ADSTOCK_DECAY_MEAN = np.array([1.0, 1.0, 0.8493, 0.8630, 0.7215])
2571
2571
  ORGANIC_ADSTOCK_DECAY_CI_HI = np.array([1.0, 0.9636, 0.9291, 0.8962, 0.8650])
2572
2572
  ORGANIC_ADSTOCK_DECAY_CI_LO = np.array([1.0, 0.6623, 0.4394, 0.2920, 0.1944])
2573
2573
  ORGANIC_ADSTOCK_DECAY_MEAN = np.array([1.0, 0.8076, 0.6633, 0.5537, 0.4693])
2574
+ ORGANIC_RF_ADSTOCK_DECAY_CI_HI = np.array([1.0, 0.9208, 0.8482, 0.781, 0.7202])
2575
+ ORGANIC_RF_ADSTOCK_DECAY_CI_LO = np.array([1.0, 0.6674, 0.4460, 0.2985, 0.2001])
2576
+ ORGANIC_RF_ADSTOCK_DECAY_MEAN = np.array([1.0, 0.8344, 0.7042, 0.6001, 0.5155])
2574
2577
  HILL_CURVES_CI_HI = np.array([0.0, 0.0, 0.00098, 0.00895, 0.00195])
2575
2578
  HILL_CURVES_CI_LO = np.array([0.0, 0.0, 0.00085, 0.00322, 0.00169])
2576
2579
  HILL_CURVES_MEAN = np.array([0.0, 0.0, 0.00091, 0.00606, 0.00183])
@@ -2600,110 +2603,106 @@ PREDICTIVE_ACCURACY_NO_HOLDOUT_ID_TIMES_AND_GEOS = np.array(
2600
2603
  [-13.597, -7.360, 1.634, 0.887, 0.990, 0.757]
2601
2604
  )
2602
2605
  PREDICTIVE_ACCURACY_HOLDOUT_ID_NO_GEOS_OR_TIMES = np.array([
2603
- -2.907,
2604
- -2.356,
2605
- -2.784,
2606
- -3.267,
2607
- -1.431,
2608
- -5.836,
2609
- 2.481,
2610
- 46.381,
2611
- 10.724,
2612
- 0.729,
2613
- 63.022,
2614
- 0.696,
2615
- 0.994,
2616
- 1.038,
2617
- 1.001,
2618
- 0.633,
2619
- 0.906,
2620
- 0.596,
2606
+ -2.690704,
2607
+ -3.231603,
2608
+ -2.784759,
2609
+ -2.866354,
2610
+ -1.595214,
2611
+ -5.836171,
2612
+ 12.76909,
2613
+ 1.634914,
2614
+ 10.724035,
2615
+ 0.720384,
2616
+ 1.306767,
2617
+ 0.696319,
2618
+ 0.993516,
2619
+ 1.036484,
2620
+ 1.00181,
2621
+ 0.595755,
2622
+ 0.945436,
2623
+ 0.596676,
2621
2624
  ])
2622
-
2623
2625
  PREDICTIVE_ACCURACY_HOLDOUT_ID_GEOS_NO_TIMES = np.array([
2624
- -4.6241765,
2625
- -3.13614225,
2626
- -4.29795837,
2627
- -4.48636341,
2628
- -1.84796333,
2629
- -4.71139717,
2630
- 2.49890709,
2631
- 8.42507458,
2632
- 3.6478579,
2633
- 1.22353196,
2634
- 4.74503851,
2635
- 1.15957963,
2636
- 1.09398592,
2637
- 1.03744256,
2638
- 1.08360898,
2639
- 0.91437268,
2640
- 1.00845361,
2641
- 0.82047617,
2626
+ -5.167992,
2627
+ -2.34246,
2628
+ -4.297958,
2629
+ -3.945161,
2630
+ -2.930509,
2631
+ -4.711397,
2632
+ 4.080629,
2633
+ 1.58583,
2634
+ 3.647858,
2635
+ 1.457559,
2636
+ 1.268536,
2637
+ 1.15958,
2638
+ 1.123409,
2639
+ 0.932309,
2640
+ 1.083609,
2641
+ 0.8971,
2642
+ 0.932309,
2643
+ 0.820476,
2642
2644
  ])
2643
-
2644
2645
  PREDICTIVE_ACCURACY_HOLDOUT_ID_TIMES_NO_GEOS = np.array([
2645
- -1.23524213,
2646
- -9.06220913,
2647
- -1.39263272,
2648
- 0.3634333,
2649
- -8.32915783,
2650
- -0.81341398,
2651
- 1.35193038,
2652
- 4.68957043,
2653
- 2.24196768,
2654
- 0.75223929,
2655
- 3.71988177,
2656
- 0.99283481,
2657
- 1.01492858,
2658
- 2.17278171,
2659
- 1.17750895,
2660
- 0.45635709,
2661
- 2.17278171,
2662
- 0.60861236,
2646
+ -1.398977,
2647
+ 0.791522,
2648
+ -1.392633,
2649
+ -0.294972,
2650
+ 0.791522,
2651
+ -0.813414,
2652
+ 2.577664,
2653
+ 0.059942,
2654
+ 2.241968,
2655
+ 1.445928,
2656
+ 0.059942,
2657
+ 0.992835,
2658
+ 1.349253,
2659
+ 0.051477,
2660
+ 1.177509,
2661
+ 0.693587,
2662
+ 0.051477,
2663
+ 0.608612,
2663
2664
  ])
2664
2665
  PREDICTIVE_ACCURACY_HOLDOUT_ID_TIMES_AND_GEO = np.array([
2665
- -38.25726318,
2666
- float("-inf"),
2667
- -13.59724903,
2668
- -1.61034203,
2669
- float("-inf"),
2670
- -7.37024498,
2671
- 0.80568475,
2672
- 5.78146744,
2673
- 1.63498175,
2674
- 0.43890095,
2675
- 5.78146744,
2676
- 0.88831377,
2677
- 0.7162329,
2678
- 5.78146744,
2679
- 0.99158531,
2680
- 0.46946037,
2681
- 5.78146744,
2682
- 0.75822759,
2666
+ -20.432268,
2667
+ 0.791522,
2668
+ -13.597249,
2669
+ -4.614312,
2670
+ 0.791522,
2671
+ -7.370245,
2672
+ 2.422502,
2673
+ 0.059942,
2674
+ 1.634982,
2675
+ 2.440133,
2676
+ 0.059942,
2677
+ 0.888314,
2678
+ 1.646492,
2679
+ 0.051477,
2680
+ 0.991585,
2681
+ 1.25399,
2682
+ 0.051477,
2683
+ 0.758228,
2683
2684
  ])
2684
-
2685
2685
  PREDICTIVE_ACCURACY_HOLDOUT_ID_NATIONAL_NO_TIMES = np.array([
2686
- 0.42883771657943726,
2687
- 0.4715208411216736,
2688
- 0.45594334602355957,
2689
- 0.8378637433052063,
2690
- 13.80582332611084,
2691
- 2.9550814628601074,
2692
- 0.34947845339775085,
2693
- 0.4262354075908661,
2694
- 0.3586611747741699,
2686
+ -15.619549,
2687
+ -28.130356,
2688
+ -17.316074,
2689
+ 16.30377,
2690
+ 10.817584,
2691
+ 15.296103,
2692
+ 2.40538,
2693
+ 2.640707,
2694
+ 2.449049,
2695
2695
  ])
2696
-
2697
2696
  PREDICTIVE_ACCURACY_HOLDOUT_ID_NATIONAL_TIMES = np.array([
2698
- -0.30289185,
2699
- float("-inf"),
2700
- 0.15624052,
2701
- 0.86708003,
2702
- 107.46259308,
2703
- 36.39891815,
2704
- 0.61977416,
2705
- 107.46259308,
2706
- 0.88497865,
2697
+ -22.270792,
2698
+ np.nan,
2699
+ -22.270792,
2700
+ 161.652573,
2701
+ np.nan,
2702
+ 161.652573,
2703
+ 4.597788,
2704
+ np.nan,
2705
+ 4.597788,
2707
2706
  ])
2708
2707
 
2709
2708
  SAMPLE_IMPRESSIONS = np.array([
@@ -3142,6 +3141,7 @@ def generate_hill_curves_dataframe() -> pd.DataFrame:
3142
3141
  [f"ch_{i}" for i in range(3)]
3143
3142
  + [f"rf_ch_{i}" for i in range(2)]
3144
3143
  + [f"organic_ch_{i}" for i in range(2)]
3144
+ + [f"organic_rf_ch_{i}" for i in range(1)]
3145
3145
  )
3146
3146
  channel_array = []
3147
3147
  channel_type_array = []
@@ -3154,6 +3154,8 @@ def generate_hill_curves_dataframe() -> pd.DataFrame:
3154
3154
  channel_type_array.append(c.RF)
3155
3155
  elif channel_name.startswith("organic_ch_"):
3156
3156
  channel_type_array.append(c.ORGANIC_MEDIA)
3157
+ elif channel_name.startswith("organic_rf_ch_"):
3158
+ channel_type_array.append(c.ORGANIC_RF)
3157
3159
 
3158
3160
  np.random.seed(0)
3159
3161
  media_units_array = [
@@ -19,6 +19,7 @@ import functools
19
19
  from typing import Mapping
20
20
  import warnings
21
21
  import altair as alt
22
+ from meridian import backend
22
23
  from meridian import constants as c
23
24
  from meridian.analysis import analyzer
24
25
  from meridian.analysis import formatter
@@ -26,8 +27,6 @@ from meridian.analysis import summary_text
26
27
  from meridian.model import model
27
28
  import numpy as np
28
29
  import pandas as pd
29
- import tensorflow as tf
30
- import tensorflow_probability as tfp
31
30
  import xarray as xr
32
31
 
33
32
 
@@ -312,10 +311,10 @@ class ModelDiagnostics:
312
311
  k: v.values
313
312
  for k, v in self._meridian.inference_data.posterior.data_vars.items()
314
313
  }
315
- for k, v in tfp.mcmc.potential_scale_reduction(
316
- {k: tf.einsum('ij...->ji...', v) for k, v in mcmc_states.items()}
314
+ for k, v in backend.mcmc.potential_scale_reduction(
315
+ {k: backend.einsum('ij...->ji...', v) for k, v in mcmc_states.items()}
317
316
  ).items():
318
- rhat_temp = v.numpy().flatten()
317
+ rhat_temp = np.asarray(v).flatten()
319
318
  rhat = pd.concat([
320
319
  rhat,
321
320
  pd.DataFrame({
@@ -1197,41 +1196,32 @@ class MediaEffects:
1197
1196
  include_ci: If `True`, plots the credible interval. Defaults to `True`.
1198
1197
 
1199
1198
  Returns:
1200
- A dictionary mapping channel type constants (`media`, `rf`, and
1201
- `organic_media`) to their respective Altair chart objects. Keys are only
1202
- present if charts for that type were generated (i.e., if the
1203
- corresponding channels exist in the data). Returns an empty dictionary if
1204
- no relevant channels are found.
1199
+ A dictionary mapping channel type constants (`media`, `rf`,
1200
+ `organic_media`, and `organic_rf`) to their respective Altair chart
1201
+ objects. Keys are only present if charts for that type were generated
1202
+ (i.e., if the corresponding channels exist in the data). Returns an empty
1203
+ dictionary if no relevant channels are found.
1205
1204
  """
1206
1205
  hill_curves_dataframe = self.hill_curves_dataframe(
1207
1206
  confidence_level=confidence_level
1208
1207
  )
1209
- channel_types = list(set(hill_curves_dataframe[c.CHANNEL_TYPE]))
1208
+ all_channel_types = set(hill_curves_dataframe[c.CHANNEL_TYPE])
1210
1209
  plots: dict[str, alt.Chart] = {}
1211
1210
 
1212
- if c.MEDIA in channel_types:
1213
- media_df = hill_curves_dataframe[
1214
- hill_curves_dataframe[c.CHANNEL_TYPE] == c.MEDIA
1215
- ]
1216
- plots[c.MEDIA] = self._plot_hill_curves_helper(
1217
- media_df, include_prior, include_ci
1218
- )
1219
-
1220
- if c.RF in channel_types:
1221
- rf_df = hill_curves_dataframe[
1222
- hill_curves_dataframe[c.CHANNEL_TYPE] == c.RF
1223
- ]
1224
- plots[c.RF] = self._plot_hill_curves_helper(
1225
- rf_df, include_prior, include_ci
1226
- )
1227
-
1228
- if c.ORGANIC_MEDIA in channel_types:
1229
- organic_media_df = hill_curves_dataframe[
1230
- hill_curves_dataframe[c.CHANNEL_TYPE] == c.ORGANIC_MEDIA
1231
- ]
1232
- plots[c.ORGANIC_MEDIA] = self._plot_hill_curves_helper(
1233
- organic_media_df, include_prior, include_ci
1234
- )
1211
+ supported_channel_types = [
1212
+ c.MEDIA,
1213
+ c.RF,
1214
+ c.ORGANIC_MEDIA,
1215
+ c.ORGANIC_RF,
1216
+ ]
1217
+ for channel_type in supported_channel_types:
1218
+ if channel_type in all_channel_types:
1219
+ df_for_type = hill_curves_dataframe[
1220
+ hill_curves_dataframe[c.CHANNEL_TYPE] == channel_type
1221
+ ]
1222
+ plots[channel_type] = self._plot_hill_curves_helper(
1223
+ df_for_type, include_prior, include_ci
1224
+ )
1235
1225
 
1236
1226
  return plots
1237
1227
 
@@ -1259,19 +1249,17 @@ class MediaEffects:
1259
1249
  column, or contains an unsupported channel type.
1260
1250
  """
1261
1251
  channel_type = df_channel_type[c.CHANNEL_TYPE].iloc[0]
1262
- if channel_type == c.MEDIA:
1252
+ if channel_type in [c.MEDIA, c.ORGANIC_MEDIA]:
1263
1253
  x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1264
1254
  shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1265
- elif channel_type == c.RF:
1255
+ elif channel_type in [c.RF, c.ORGANIC_RF]:
1266
1256
  x_axis_title = summary_text.HILL_X_AXIS_RF_LABEL
1267
1257
  shaded_area_title = summary_text.HILL_SHADED_REGION_RF_LABEL
1268
- elif channel_type == c.ORGANIC_MEDIA:
1269
- x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1270
- shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1271
1258
  else:
1272
1259
  raise ValueError(
1273
1260
  f"Unsupported channel type '{channel_type}' found in Hill curve data."
1274
- ' Expected one of: {c.MEDIA}, {c.RF}, {c.ORGANIC_MEDIA}.'
1261
+ ' Expected one of: {c.MEDIA}, {c.RF}, {c.ORGANIC_MEDIA},'
1262
+ ' {c.ORGANIC_RF}.'
1275
1263
  )
1276
1264
  domain_list = [
1277
1265
  c.POSTERIOR,
@@ -1433,8 +1421,8 @@ class MediaSummary:
1433
1421
  non_media_baseline_values: Optional list of shape
1434
1422
  `(n_non_media_channels,)`. Each element is a float denoting the fixed
1435
1423
  value which will be used as baseline for the given channel. If `None`,
1436
- the values defined with `ModelSpec.non_media_baseline_values`
1437
- will be used.
1424
+ the values defined with `ModelSpec.non_media_baseline_values` will be
1425
+ used.
1438
1426
  """
1439
1427
  self._meridian = meridian
1440
1428
  self._analyzer = analyzer.Analyzer(meridian)
@@ -1654,8 +1642,8 @@ class MediaSummary:
1654
1642
  non_media_baseline_values: Optional list of shape
1655
1643
  `(n_non_media_channels,)`. Each element is a float denoting the fixed
1656
1644
  value which will be used as baseline for the given channel. If `None`,
1657
- the values defined with `ModelSpec.non_media_baseline_values`
1658
- will be used.
1645
+ the values defined with `ModelSpec.non_media_baseline_values` will be
1646
+ used.
1659
1647
  """
1660
1648
  self._confidence_level = confidence_level or self._confidence_level
1661
1649
  self._selected_times = selected_times
@@ -1686,7 +1674,7 @@ class MediaSummary:
1686
1674
  c.DATE_FORMAT if time_granularity == c.WEEKLY else c.QUARTER_FORMAT
1687
1675
  )
1688
1676
 
1689
- outcome_df = self._transform_contribution_metrics(
1677
+ outcome_df = self.contribution_metrics(
1690
1678
  include_non_paid=True, aggregate_times=False
1691
1679
  )
1692
1680
 
@@ -1800,7 +1788,7 @@ class MediaSummary:
1800
1788
  f'time_granularity must be one of {c.TIME_GRANULARITIES}'
1801
1789
  )
1802
1790
 
1803
- outcome_df = self._transform_contribution_metrics(
1791
+ outcome_df = self.contribution_metrics(
1804
1792
  include_non_paid=True, aggregate_times=False
1805
1793
  )
1806
1794
  outcome_df[c.TIME] = pd.to_datetime(outcome_df[c.TIME])
@@ -1907,7 +1895,7 @@ class MediaSummary:
1907
1895
  if self._meridian.input_data.revenue_per_kpi is not None
1908
1896
  else c.KPI.upper()
1909
1897
  )
1910
- outcome_df = self._transform_contribution_metrics(include_non_paid=True)
1898
+ outcome_df = self.contribution_metrics(include_non_paid=True)
1911
1899
  pct = c.PCT_OF_CONTRIBUTION
1912
1900
  value = c.INCREMENTAL_OUTCOME
1913
1901
  outcome_df['outcome_text'] = outcome_df.apply(
@@ -1991,7 +1979,7 @@ class MediaSummary:
1991
1979
  Returns:
1992
1980
  An Altair plot showing the contributions for all channels.
1993
1981
  """
1994
- outcome_df = self._transform_contribution_metrics(
1982
+ outcome_df = self.contribution_metrics(
1995
1983
  [c.ALL_CHANNELS], include_non_paid=True
1996
1984
  )
1997
1985
 
@@ -2411,7 +2399,7 @@ class MediaSummary:
2411
2399
  spend_df = paid_summary_metrics[c.SPEND].to_dataframe().reset_index()
2412
2400
  return metrics_df.merge(spend_df, on=c.CHANNEL)
2413
2401
 
2414
- def _transform_contribution_metrics(
2402
+ def contribution_metrics(
2415
2403
  self,
2416
2404
  selected_channels: Sequence[str] | None = None,
2417
2405
  include_non_paid: bool = False,