google-meridian 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/METADATA +8 -4
  2. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/RECORD +49 -17
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/summarizer.py +7 -2
  5. meridian/analysis/test_utils.py +934 -485
  6. meridian/analysis/visualizer.py +10 -6
  7. meridian/constants.py +1 -0
  8. meridian/data/test_utils.py +82 -10
  9. meridian/model/__init__.py +2 -0
  10. meridian/model/context.py +925 -0
  11. meridian/model/eda/constants.py +1 -0
  12. meridian/model/equations.py +418 -0
  13. meridian/model/knots.py +58 -47
  14. meridian/model/model.py +93 -792
  15. meridian/version.py +1 -1
  16. scenarioplanner/__init__.py +42 -0
  17. scenarioplanner/converters/__init__.py +25 -0
  18. scenarioplanner/converters/dataframe/__init__.py +28 -0
  19. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  20. scenarioplanner/converters/dataframe/common.py +71 -0
  21. scenarioplanner/converters/dataframe/constants.py +137 -0
  22. scenarioplanner/converters/dataframe/converter.py +42 -0
  23. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  24. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  25. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  26. scenarioplanner/converters/mmm.py +743 -0
  27. scenarioplanner/converters/mmm_converter.py +58 -0
  28. scenarioplanner/converters/sheets.py +156 -0
  29. scenarioplanner/converters/test_data.py +714 -0
  30. scenarioplanner/linkingapi/__init__.py +47 -0
  31. scenarioplanner/linkingapi/constants.py +27 -0
  32. scenarioplanner/linkingapi/url_generator.py +131 -0
  33. scenarioplanner/mmm_ui_proto_generator.py +354 -0
  34. schema/__init__.py +5 -2
  35. schema/mmm_proto_generator.py +71 -0
  36. schema/model_consumer.py +133 -0
  37. schema/processors/__init__.py +77 -0
  38. schema/processors/budget_optimization_processor.py +832 -0
  39. schema/processors/common.py +64 -0
  40. schema/processors/marketing_processor.py +1136 -0
  41. schema/processors/model_fit_processor.py +367 -0
  42. schema/processors/model_kernel_processor.py +117 -0
  43. schema/processors/model_processor.py +412 -0
  44. schema/processors/reach_frequency_optimization_processor.py +584 -0
  45. schema/test_data.py +380 -0
  46. schema/utils/__init__.py +1 -0
  47. schema/utils/date_range_bucketing.py +117 -0
  48. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
  49. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -243,6 +243,12 @@ class ModelDiagnostics:
243
243
 
244
244
  groupby = posterior_df.columns.tolist()
245
245
  groupby.remove(parameter)
246
+
247
+ parameter_99_max = prior_posterior_df[parameter].quantile(0.99)
248
+ # Remove outliers that make the chart hard to read.
249
+ prior_posterior_df[parameter] = prior_posterior_df[parameter].clip(
250
+ upper=parameter_99_max * c.OUTLIER_CLIP_FACTOR
251
+ )
246
252
  plot = (
247
253
  alt.Chart(prior_posterior_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
248
254
  .transform_density(
@@ -269,7 +275,7 @@ class ModelDiagnostics:
269
275
  title=formatter.custom_title_params(
270
276
  summary_text.PRIOR_POSTERIOR_DIST_CHART_TITLE
271
277
  )
272
- ).configure_axis(**formatter.TEXT_CONFIG)
278
+ ).configure_axis(**formatter.TEXT_CONFIG).interactive()
273
279
 
274
280
  def plot_rhat_boxplot(self) -> alt.Chart:
275
281
  """Plots the R-hat box plot.
@@ -1450,17 +1456,15 @@ class MediaSummary:
1450
1456
 
1451
1457
  Args:
1452
1458
  aggregate_times: If `True`, aggregates the metrics across all time
1453
- periods. If `False`, returns time-varying metrics.
1459
+ periods. If `False`, returns time-varying metrics.
1454
1460
 
1455
1461
  Returns:
1456
1462
  An `xarray.Dataset` containing the following:
1457
1463
  - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
1458
- `ci_hi`),
1459
- `distribution` (`prior`, `posterior`)
1464
+ `ci_hi`), `distribution` (`prior`, `posterior`)
1460
1465
  - **Data variables:** `impressions`, `pct_of_impressions`, `spend`,
1461
1466
  `pct_of_spend`, `CPM`, `incremental_outcome`, `pct_of_contribution`,
1462
- `roi`,
1463
- `effectiveness`, `mroi`.
1467
+ `roi`, `effectiveness`, `mroi`.
1464
1468
  """
1465
1469
  return self._analyzer.summary_metrics(
1466
1470
  selected_times=self._selected_times,
meridian/constants.py CHANGED
@@ -755,6 +755,7 @@ STROKE_DASH = (4, 2)
755
755
  POINT_SIZE = 80
756
756
  INDEPENDENT = 'independent'
757
757
  RESPONSE_CURVE_STEP_SIZE = 0.01
758
+ OUTLIER_CLIP_FACTOR = 1.2
758
759
 
759
760
 
760
761
  # Font names.
@@ -642,6 +642,7 @@ def random_media_da(
642
642
  channel_variable_name: str = 'media_channel',
643
643
  channel_prefix: str = 'ch_',
644
644
  integer_geos: bool = False,
645
+ nonzero_shift: float = 0.0,
645
646
  ) -> xr.DataArray:
646
647
  """Generates a sample `media` DataArray.
647
648
 
@@ -662,6 +663,7 @@ def random_media_da(
662
663
  channel_variable_name: The name of the channel variable
663
664
  channel_prefix: The prefix of the channel names
664
665
  integer_geos: If True, the geos will be integers.
666
+ nonzero_shift: A scalar value to add to the generated data.
665
667
 
666
668
  Returns:
667
669
  A DataArray containing random data.
@@ -695,6 +697,8 @@ def random_media_da(
695
697
  )
696
698
  )
697
699
 
700
+ media = media + nonzero_shift
701
+
698
702
  if explicit_geo_names is None:
699
703
  geos = sample_geos(n_geos, integer_geos)
700
704
  else:
@@ -736,6 +740,7 @@ def random_organic_media_da(
736
740
  explicit_time_index: Sequence[str] | None = None,
737
741
  explicit_media_channel_names: Sequence[str] | None = None,
738
742
  integer_geos: bool = False,
743
+ nonzero_shift: float = 0.0,
739
744
  ) -> xr.DataArray:
740
745
  """Generates a sample `organic_media` DataArray."""
741
746
  return random_media_da(
@@ -751,6 +756,7 @@ def random_organic_media_da(
751
756
  channel_variable_name='organic_media_channel',
752
757
  channel_prefix='organic_media_',
753
758
  integer_geos=integer_geos,
759
+ nonzero_shift=nonzero_shift,
754
760
  )
755
761
 
756
762
 
@@ -761,6 +767,7 @@ def random_media_spend_nd_da(
761
767
  seed=0,
762
768
  integer_geos: bool = False,
763
769
  explicit_media_channel_names: Sequence[str] | None = None,
770
+ nonzero_shift: float = 0.0,
764
771
  ) -> xr.DataArray:
765
772
  """Generates a sample N-dimensional `media_spend` DataArray.
766
773
 
@@ -781,6 +788,7 @@ def random_media_spend_nd_da(
781
788
  integer_geos: If True, the geos will be integers.
782
789
  explicit_media_channel_names: If given, ignore `n_media_channels` and use
783
790
  this as is.
791
+ nonzero_shift: A scalar value to add to the generated data.
784
792
 
785
793
  Returns:
786
794
  A DataArray containing the generated `media_spend` data with the given
@@ -818,7 +826,7 @@ def random_media_spend_nd_da(
818
826
  f'Shape {dims} not supported by the random_media_spend_nd_da function.'
819
827
  )
820
828
 
821
- media_spend = abs(np.random.normal(1, 1, size=shape))
829
+ media_spend = abs(np.random.normal(1, 1, size=shape)) + nonzero_shift
822
830
 
823
831
  return xr.DataArray(
824
832
  media_spend,
@@ -1007,8 +1015,27 @@ def random_reach_da(
1007
1015
  channel_variable_name: str = 'rf_channel',
1008
1016
  channel_prefix: str = 'rf_ch_',
1009
1017
  integer_geos: bool = False,
1018
+ nonzero_shift: float = 0.0,
1010
1019
  ) -> xr.DataArray:
1011
- """Generates a sample `reach` DataArray."""
1020
+ """Generates a sample `reach` DataArray.
1021
+
1022
+ Args:
1023
+ n_geos: Number of geos
1024
+ n_times: Number of time periods
1025
+ n_media_times: Number of media time periods
1026
+ n_rf_channels: Number of reach and frequency channels
1027
+ seed: Random seed used by `np.random.seed()`
1028
+ explicit_rf_channel_names: If given, ignore `n_rf_channels` and use this as
1029
+ is
1030
+ array_name: The name of the array to be created
1031
+ channel_variable_name: The name of the channel variable
1032
+ channel_prefix: The prefix of the channel names
1033
+ integer_geos: If True, the geos will be integers.
1034
+ nonzero_shift: A scalar value to add to the generated data.
1035
+
1036
+ Returns:
1037
+ A DataArray containing random data.
1038
+ """
1012
1039
 
1013
1040
  np.random.seed(seed)
1014
1041
 
@@ -1016,12 +1043,15 @@ def random_reach_da(
1016
1043
  if n_times < n_media_times:
1017
1044
  start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
1018
1045
 
1019
- reach = np.round(
1020
- abs(
1021
- np.random.normal(
1022
- 3000, 100, size=(n_geos, n_media_times, n_rf_channels)
1046
+ reach = (
1047
+ np.round(
1048
+ abs(
1049
+ np.random.normal(
1050
+ 3000, 100, size=(n_geos, n_media_times, n_rf_channels)
1051
+ )
1023
1052
  )
1024
1053
  )
1054
+ + nonzero_shift
1025
1055
  )
1026
1056
 
1027
1057
  channels = (
@@ -1051,6 +1081,7 @@ def random_organic_reach_da(
1051
1081
  seed: int = 0,
1052
1082
  explicit_organic_rf_channel_names: Sequence[str] | None = None,
1053
1083
  integer_geos: bool = False,
1084
+ nonzero_shift: float = 0.0,
1054
1085
  ) -> xr.DataArray:
1055
1086
  """Generates a sample `organic_reach` DataArray."""
1056
1087
  return random_reach_da(
@@ -1064,6 +1095,7 @@ def random_organic_reach_da(
1064
1095
  channel_variable_name='organic_rf_channel',
1065
1096
  channel_prefix='organic_rf_ch_',
1066
1097
  integer_geos=integer_geos,
1098
+ nonzero_shift=nonzero_shift,
1067
1099
  )
1068
1100
 
1069
1101
 
@@ -1078,8 +1110,27 @@ def random_frequency_da(
1078
1110
  channel_variable_name: str = 'rf_channel',
1079
1111
  channel_prefix: str = 'rf_ch_',
1080
1112
  integer_geos: bool = False,
1113
+ nonzero_shift: float = 0.0,
1081
1114
  ) -> xr.DataArray:
1082
- """Generates a sample `frequency` DataArray."""
1115
+ """Generates a sample `frequency` DataArray.
1116
+
1117
+ Args:
1118
+ n_geos: Number of geos
1119
+ n_times: Number of time periods
1120
+ n_media_times: Number of media time periods
1121
+ n_rf_channels: Number of reach and frequency channels
1122
+ seed: Random seed used by `np.random.seed()`
1123
+ explicit_rf_channel_names: If given, ignore `n_rf_channels` and use this as
1124
+ is
1125
+ array_name: The name of the array to be created
1126
+ channel_variable_name: The name of the channel variable
1127
+ channel_prefix: The prefix of the channel names
1128
+ integer_geos: If True, the geos will be integers.
1129
+ nonzero_shift: A scalar value to add to the generated data.
1130
+
1131
+ Returns:
1132
+ A DataArray containing random data.
1133
+ """
1083
1134
 
1084
1135
  np.random.seed(seed)
1085
1136
 
@@ -1087,8 +1138,9 @@ def random_frequency_da(
1087
1138
  if n_times < n_media_times:
1088
1139
  start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
1089
1140
 
1090
- frequency = abs(
1091
- np.random.normal(3, 5, size=(n_geos, n_media_times, n_rf_channels))
1141
+ frequency = (
1142
+ abs(np.random.normal(3, 5, size=(n_geos, n_media_times, n_rf_channels)))
1143
+ + nonzero_shift
1092
1144
  )
1093
1145
 
1094
1146
  channels = (
@@ -1119,6 +1171,7 @@ def random_organic_frequency_da(
1119
1171
  seed: int = 0,
1120
1172
  explicit_organic_rf_channel_names: Sequence[str] | None = None,
1121
1173
  integer_geos: bool = False,
1174
+ nonzero_shift: float = 0.0,
1122
1175
  ) -> xr.DataArray:
1123
1176
  """Generates a sample `organic_frequency` DataArray."""
1124
1177
  return random_frequency_da(
@@ -1132,6 +1185,7 @@ def random_organic_frequency_da(
1132
1185
  channel_variable_name='organic_rf_channel',
1133
1186
  channel_prefix='organic_rf_ch_',
1134
1187
  integer_geos=integer_geos,
1188
+ nonzero_shift=nonzero_shift,
1135
1189
  )
1136
1190
 
1137
1191
 
@@ -1141,6 +1195,7 @@ def random_rf_spend_nd_da(
1141
1195
  n_rf_channels: int | None = None,
1142
1196
  seed=0,
1143
1197
  integer_geos: bool = False,
1198
+ nonzero_shift: float = 0.0,
1144
1199
  ) -> xr.DataArray:
1145
1200
  """Generates a sample N-dimensional `rf_spend` DataArray.
1146
1201
 
@@ -1157,6 +1212,7 @@ def random_rf_spend_nd_da(
1157
1212
  n_rf_channels: Number of channels in the created `rf_spend` array.
1158
1213
  seed: Random seed used by `np.random.seed()`.
1159
1214
  integer_geos: If True, the geos will be integers.
1215
+ nonzero_shift: A scalar value to add to the generated data.
1160
1216
 
1161
1217
  Returns:
1162
1218
  A DataArray containing the generated `rf_spend` data with the given
@@ -1187,7 +1243,7 @@ def random_rf_spend_nd_da(
1187
1243
  f'Shape {dims} not supported by the random_rf_spend_nd_da function.'
1188
1244
  )
1189
1245
 
1190
- rf_spend = abs(np.random.normal(1, 1, size=shape))
1246
+ rf_spend = abs(np.random.normal(1, 1, size=shape)) + nonzero_shift
1191
1247
 
1192
1248
  return xr.DataArray(
1193
1249
  rf_spend,
@@ -1206,6 +1262,7 @@ def random_non_media_treatments_da(
1206
1262
  date_format: str = c.DATE_FORMAT,
1207
1263
  explicit_time_index: Sequence[str] | None = None,
1208
1264
  integer_geos: bool = False,
1265
+ nonzero_shift: float = 0.0,
1209
1266
  ) -> xr.DataArray:
1210
1267
  """Generates a sample `non_media_treatments` DataArray.
1211
1268
 
@@ -1218,6 +1275,7 @@ def random_non_media_treatments_da(
1218
1275
  date_format: The date format to use for time coordinate labels
1219
1276
  explicit_time_index: If given, ignore `date_format` and use this as is
1220
1277
  integer_geos: If True, the geos will be integers.
1278
+ nonzero_shift: A scalar value to add to the generated data.
1221
1279
 
1222
1280
  Returns:
1223
1281
  A DataArray containing random non-media variable.
@@ -1232,6 +1290,8 @@ def random_non_media_treatments_da(
1232
1290
  non_media_channel,
1233
1291
  size=(n_geos, n_times, n_non_media_channels),
1234
1292
  )
1293
+ non_media_treatments = non_media_treatments + nonzero_shift
1294
+
1235
1295
  return xr.DataArray(
1236
1296
  non_media_treatments,
1237
1297
  dims=['geo', 'time', 'non_media_channel'],
@@ -1268,6 +1328,7 @@ def random_dataset(
1268
1328
  remove_media_time: bool = False,
1269
1329
  integer_geos: bool = False,
1270
1330
  kpi_data_pattern: str = '',
1331
+ nonzero_shift: float = 0.0,
1271
1332
  ) -> xr.Dataset:
1272
1333
  """Generates a random dataset."""
1273
1334
  if n_media_channels:
@@ -1280,6 +1341,7 @@ def random_dataset(
1280
1341
  integer_geos=integer_geos,
1281
1342
  explicit_media_channel_names=explicit_media_channel_names,
1282
1343
  media_value_scales=media_value_scales,
1344
+ nonzero_shift=nonzero_shift,
1283
1345
  )
1284
1346
  media_spend = random_media_spend_nd_da(
1285
1347
  n_geos=n_geos,
@@ -1288,6 +1350,7 @@ def random_dataset(
1288
1350
  explicit_media_channel_names=explicit_media_channel_names,
1289
1351
  seed=seed,
1290
1352
  integer_geos=integer_geos,
1353
+ nonzero_shift=nonzero_shift,
1291
1354
  )
1292
1355
  else:
1293
1356
  media = None
@@ -1301,6 +1364,7 @@ def random_dataset(
1301
1364
  n_rf_channels=n_rf_channels,
1302
1365
  seed=seed,
1303
1366
  integer_geos=integer_geos,
1367
+ nonzero_shift=nonzero_shift,
1304
1368
  )
1305
1369
  frequency = random_frequency_da(
1306
1370
  n_geos=n_geos,
@@ -1309,6 +1373,7 @@ def random_dataset(
1309
1373
  n_rf_channels=n_rf_channels,
1310
1374
  seed=seed,
1311
1375
  integer_geos=integer_geos,
1376
+ nonzero_shift=nonzero_shift,
1312
1377
  )
1313
1378
  rf_spend = random_rf_spend_nd_da(
1314
1379
  n_geos=n_geos,
@@ -1316,6 +1381,7 @@ def random_dataset(
1316
1381
  n_rf_channels=n_rf_channels,
1317
1382
  seed=seed,
1318
1383
  integer_geos=integer_geos,
1384
+ nonzero_shift=nonzero_shift,
1319
1385
  )
1320
1386
  else:
1321
1387
  reach = None
@@ -1352,6 +1418,7 @@ def random_dataset(
1352
1418
  n_non_media_channels=n_non_media_channels,
1353
1419
  seed=seed,
1354
1420
  integer_geos=integer_geos,
1421
+ nonzero_shift=nonzero_shift,
1355
1422
  )
1356
1423
  else:
1357
1424
  non_media_treatments = None
@@ -1364,6 +1431,7 @@ def random_dataset(
1364
1431
  n_organic_media_channels=n_organic_media_channels,
1365
1432
  seed=seed,
1366
1433
  integer_geos=integer_geos,
1434
+ nonzero_shift=nonzero_shift,
1367
1435
  )
1368
1436
  else:
1369
1437
  organic_media = None
@@ -1376,6 +1444,7 @@ def random_dataset(
1376
1444
  n_organic_rf_channels=n_organic_rf_channels,
1377
1445
  seed=seed,
1378
1446
  integer_geos=integer_geos,
1447
+ nonzero_shift=nonzero_shift,
1379
1448
  )
1380
1449
  organic_frequency = random_organic_frequency_da(
1381
1450
  n_geos=n_geos,
@@ -1384,6 +1453,7 @@ def random_dataset(
1384
1453
  n_organic_rf_channels=n_organic_rf_channels,
1385
1454
  seed=seed,
1386
1455
  integer_geos=integer_geos,
1456
+ nonzero_shift=nonzero_shift,
1387
1457
  )
1388
1458
  else:
1389
1459
  organic_reach = None
@@ -1794,6 +1864,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1794
1864
  n_organic_media_channels: int | None = None,
1795
1865
  n_organic_rf_channels: int | None = None,
1796
1866
  seed: int = 0,
1867
+ nonzero_shift: float = 0.0,
1797
1868
  ) -> input_data.InputData:
1798
1869
  """Generates sample InputData for `non_revenue` KPI w/ revenue_per_kpi."""
1799
1870
  dataset = random_dataset(
@@ -1807,6 +1878,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1807
1878
  n_organic_media_channels=n_organic_media_channels,
1808
1879
  n_organic_rf_channels=n_organic_rf_channels,
1809
1880
  seed=seed,
1881
+ nonzero_shift=nonzero_shift,
1810
1882
  )
1811
1883
  return input_data.InputData(
1812
1884
  kpi=dataset.kpi,
@@ -15,7 +15,9 @@
15
15
  """The Meridian API module that models the data."""
16
16
 
17
17
  from meridian.model import adstock_hill
18
+ from meridian.model import context
18
19
  from meridian.model import eda
20
+ from meridian.model import equations
19
21
  from meridian.model import knots
20
22
  from meridian.model import media
21
23
  from meridian.model import model