google_meridian-1.3.2-py3-none-any.whl → google_meridian-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
  2. google_meridian-1.5.0.dist-info/RECORD +112 -0
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
  5. meridian/analysis/analyzer.py +558 -398
  6. meridian/analysis/optimizer.py +90 -68
  7. meridian/analysis/review/reviewer.py +4 -1
  8. meridian/analysis/summarizer.py +13 -3
  9. meridian/analysis/test_utils.py +2911 -2102
  10. meridian/analysis/visualizer.py +37 -14
  11. meridian/backend/__init__.py +106 -0
  12. meridian/constants.py +2 -0
  13. meridian/data/input_data.py +30 -52
  14. meridian/data/input_data_builder.py +2 -9
  15. meridian/data/test_utils.py +107 -51
  16. meridian/data/validator.py +48 -0
  17. meridian/mlflow/autolog.py +19 -9
  18. meridian/model/__init__.py +2 -0
  19. meridian/model/adstock_hill.py +3 -5
  20. meridian/model/context.py +1059 -0
  21. meridian/model/eda/constants.py +335 -4
  22. meridian/model/eda/eda_engine.py +723 -312
  23. meridian/model/eda/eda_outcome.py +177 -33
  24. meridian/model/equations.py +418 -0
  25. meridian/model/knots.py +58 -47
  26. meridian/model/model.py +228 -878
  27. meridian/model/model_test_data.py +38 -0
  28. meridian/model/posterior_sampler.py +103 -62
  29. meridian/model/prior_sampler.py +114 -94
  30. meridian/model/spec.py +23 -14
  31. meridian/templates/card.html.jinja +9 -7
  32. meridian/templates/chart.html.jinja +1 -6
  33. meridian/templates/finding.html.jinja +19 -0
  34. meridian/templates/findings.html.jinja +33 -0
  35. meridian/templates/formatter.py +41 -5
  36. meridian/templates/formatter_test.py +127 -0
  37. meridian/templates/style.css +66 -9
  38. meridian/templates/style.scss +85 -4
  39. meridian/templates/table.html.jinja +1 -0
  40. meridian/version.py +1 -1
  41. scenarioplanner/__init__.py +42 -0
  42. scenarioplanner/converters/__init__.py +25 -0
  43. scenarioplanner/converters/dataframe/__init__.py +28 -0
  44. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  45. scenarioplanner/converters/dataframe/common.py +71 -0
  46. scenarioplanner/converters/dataframe/constants.py +137 -0
  47. scenarioplanner/converters/dataframe/converter.py +42 -0
  48. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  49. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  50. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  51. scenarioplanner/converters/mmm.py +743 -0
  52. scenarioplanner/converters/mmm_converter.py +58 -0
  53. scenarioplanner/converters/sheets.py +156 -0
  54. scenarioplanner/converters/test_data.py +714 -0
  55. scenarioplanner/linkingapi/__init__.py +47 -0
  56. scenarioplanner/linkingapi/constants.py +27 -0
  57. scenarioplanner/linkingapi/url_generator.py +131 -0
  58. scenarioplanner/mmm_ui_proto_generator.py +355 -0
  59. schema/__init__.py +5 -2
  60. schema/mmm_proto_generator.py +71 -0
  61. schema/model_consumer.py +133 -0
  62. schema/processors/__init__.py +77 -0
  63. schema/processors/budget_optimization_processor.py +832 -0
  64. schema/processors/common.py +64 -0
  65. schema/processors/marketing_processor.py +1137 -0
  66. schema/processors/model_fit_processor.py +367 -0
  67. schema/processors/model_kernel_processor.py +117 -0
  68. schema/processors/model_processor.py +415 -0
  69. schema/processors/reach_frequency_optimization_processor.py +584 -0
  70. schema/serde/distribution.py +12 -7
  71. schema/serde/hyperparameters.py +54 -107
  72. schema/serde/meridian_serde.py +6 -1
  73. schema/test_data.py +380 -0
  74. schema/utils/__init__.py +2 -0
  75. schema/utils/date_range_bucketing.py +117 -0
  76. schema/utils/proto_enum_converter.py +127 -0
  77. google_meridian-1.3.2.dist-info/RECORD +0 -76
  78. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -48,7 +48,10 @@ class ModelDiagnostics:
48
48
 
49
49
  def __init__(self, meridian: model.Meridian, use_kpi: bool = False):
50
50
  self._meridian = meridian
51
- self._analyzer = analyzer.Analyzer(meridian)
51
+ self._analyzer = analyzer.Analyzer(
52
+ model_context=meridian.model_context,
53
+ inference_data=meridian.inference_data,
54
+ )
52
55
  self._use_kpi = self._analyzer._use_kpi(use_kpi)
53
56
 
54
57
  @functools.lru_cache(maxsize=128)
@@ -243,6 +246,12 @@ class ModelDiagnostics:
243
246
 
244
247
  groupby = posterior_df.columns.tolist()
245
248
  groupby.remove(parameter)
249
+
250
+ parameter_99_max = prior_posterior_df[parameter].quantile(0.99)
251
+ # Remove outliers that make the chart hard to read.
252
+ prior_posterior_df[parameter] = prior_posterior_df[parameter].clip(
253
+ upper=parameter_99_max * c.OUTLIER_CLIP_FACTOR
254
+ )
246
255
  plot = (
247
256
  alt.Chart(prior_posterior_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
248
257
  .transform_density(
@@ -265,11 +274,15 @@ class ModelDiagnostics:
265
274
  x=c.INDEPENDENT
266
275
  )
267
276
 
268
- return plot.properties(
269
- title=formatter.custom_title_params(
270
- summary_text.PRIOR_POSTERIOR_DIST_CHART_TITLE
277
+ return (
278
+ plot.properties(
279
+ title=formatter.custom_title_params(
280
+ summary_text.PRIOR_POSTERIOR_DIST_CHART_TITLE
281
+ )
271
282
  )
272
- ).configure_axis(**formatter.TEXT_CONFIG)
283
+ .configure_axis(**formatter.TEXT_CONFIG)
284
+ .interactive()
285
+ )
273
286
 
274
287
  def plot_rhat_boxplot(self) -> alt.Chart:
275
288
  """Plots the R-hat box plot.
@@ -381,7 +394,10 @@ class ModelFit:
381
394
  represented as a value between zero and one. Default is `0.9`.
382
395
  """
383
396
  self._meridian = meridian
384
- self._analyzer = analyzer.Analyzer(meridian)
397
+ self._analyzer = analyzer.Analyzer(
398
+ model_context=meridian.model_context,
399
+ inference_data=meridian.inference_data,
400
+ )
385
401
  self._use_kpi = self._analyzer._use_kpi(use_kpi)
386
402
  self._model_fit_data = self._analyzer.expected_vs_actual_data(
387
403
  use_kpi=self._use_kpi, confidence_level=confidence_level
@@ -651,7 +667,10 @@ class ReachAndFrequency:
651
667
  use_kpi: If `True`, KPI is used instead of revenue.
652
668
  """
653
669
  self._meridian = meridian
654
- self._analyzer = analyzer.Analyzer(meridian)
670
+ self._analyzer = analyzer.Analyzer(
671
+ model_context=meridian.model_context,
672
+ inference_data=meridian.inference_data,
673
+ )
655
674
  self._selected_times = selected_times
656
675
  self._use_kpi = self._analyzer._use_kpi(use_kpi)
657
676
  self._optimal_frequency_data = self._analyzer.optimal_freq(
@@ -851,7 +870,10 @@ class MediaEffects:
851
870
  the incremental revenue using the revenue per KPI (if available).
852
871
  """
853
872
  self._meridian = meridian
854
- self._analyzer = analyzer.Analyzer(meridian)
873
+ self._analyzer = analyzer.Analyzer(
874
+ model_context=meridian.model_context,
875
+ inference_data=meridian.inference_data,
876
+ )
855
877
  self._by_reach = by_reach
856
878
  self._use_kpi = self._analyzer._use_kpi(use_kpi)
857
879
 
@@ -1425,7 +1447,10 @@ class MediaSummary:
1425
1447
  use_kpi: If `True`, use KPI instead of revenue.
1426
1448
  """
1427
1449
  self._meridian = meridian
1428
- self._analyzer = analyzer.Analyzer(meridian)
1450
+ self._analyzer = analyzer.Analyzer(
1451
+ model_context=meridian.model_context,
1452
+ inference_data=meridian.inference_data,
1453
+ )
1429
1454
  self._confidence_level = confidence_level
1430
1455
  self._selected_times = selected_times
1431
1456
  self._marginal_roi_by_reach = marginal_roi_by_reach
@@ -1450,17 +1475,15 @@ class MediaSummary:
1450
1475
 
1451
1476
  Args:
1452
1477
  aggregate_times: If `True`, aggregates the metrics across all time
1453
- periods. If `False`, returns time-varying metrics.
1478
+ periods. If `False`, returns time-varying metrics.
1454
1479
 
1455
1480
  Returns:
1456
1481
  An `xarray.Dataset` containing the following:
1457
1482
  - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
1458
- `ci_hi`),
1459
- `distribution` (`prior`, `posterior`)
1483
+ `ci_hi`), `distribution` (`prior`, `posterior`)
1460
1484
  - **Data variables:** `impressions`, `pct_of_impressions`, `spend`,
1461
1485
  `pct_of_spend`, `CPM`, `incremental_outcome`, `pct_of_contribution`,
1462
- `roi`,
1463
- `effectiveness`, `mroi`.
1486
+ `roi`, `effectiveness`, `mroi`.
1464
1487
  """
1465
1488
  return self._analyzer.summary_metrics(
1466
1489
  selected_times=self._selected_times,
@@ -909,6 +909,77 @@ if _BACKEND == config.Backend.JAX:
909
909
 
910
910
  xla_windowed_adaptive_nuts = _jax_xla_windowed_adaptive_nuts
911
911
 
912
+ def _jax_adstock_process(
913
+ media: "_jax.Array", weights: "_jax.Array", n_times_output: int
914
+ ) -> "_jax.Array":
915
+ """JAX implementation for adstock_process using convolution.
916
+
917
+ This function applies an adstock process to media spend data using a
918
+ convolutional approach. The weights represent the adstock decay over time.
919
+
920
+ Args:
921
+ media: A JAX array of media spend. Expected shape is
922
+ `(batch_dims, n_geos, n_times_in, n_channels)`.
923
+ weights: A JAX array of adstock weights. Expected shape is
924
+ `(batch_dims, n_channels, window_size)`, where `batch_dims` must be
925
+ broadcastable to the batch dimensions of `media`.
926
+ n_times_output: The number of time periods in the output. This corresponds
927
+ to `n_times_in - window_size + 1`.
928
+
929
+ Returns:
930
+ A JAX array representing the adstocked media, with shape
931
+ `(batch_dims, n_geos, n_times_output, n_channels)`.
932
+ """
933
+
934
+ batch_dims = weights.shape[:-2]
935
+ if media.shape[:-3] != batch_dims:
936
+ media = jax_ops.broadcast_to(media, batch_dims + media.shape[-3:])
937
+
938
+ n_geos = media.shape[-3]
939
+ n_times_in = media.shape[-2]
940
+ n_channels = media.shape[-1]
941
+ window_size = weights.shape[-1]
942
+
943
+ perm = list(range(media.ndim))
944
+ perm[-2], perm[-1] = perm[-1], perm[-2]
945
+ media_transposed = jax_ops.transpose(media, perm)
946
+ media_reshaped = jax_ops.reshape(media_transposed, (1, -1, n_times_in))
947
+
948
+ total_channels = media_reshaped.shape[1]
949
+ weights_expanded = jax_ops.expand_dims(weights, -3)
950
+ weights_tiled = jax_ops.broadcast_to(
951
+ weights_expanded, batch_dims + (n_geos, n_channels, window_size)
952
+ )
953
+ kernel_reshaped = jax_ops.reshape(
954
+ weights_tiled, (total_channels, 1, window_size)
955
+ )
956
+
957
+ dn = jax.lax.conv_dimension_numbers(
958
+ media_reshaped.shape, kernel_reshaped.shape, ("NCH", "OIH", "NCH")
959
+ )
960
+
961
+ out = jax.lax.conv_general_dilated(
962
+ lhs=media_reshaped,
963
+ rhs=kernel_reshaped,
964
+ window_strides=(1,),
965
+ padding="VALID",
966
+ lhs_dilation=(1,),
967
+ rhs_dilation=(1,),
968
+ dimension_numbers=dn,
969
+ feature_group_count=total_channels,
970
+ precision=jax.lax.Precision.HIGHEST,
971
+ )
972
+
973
+ t_out = out.shape[-1]
974
+ out_reshaped = jax_ops.reshape(
975
+ out, batch_dims + (n_geos, n_channels, t_out)
976
+ )
977
+ perm_back = list(range(out_reshaped.ndim))
978
+ perm_back[-2], perm_back[-1] = perm_back[-1], perm_back[-2]
979
+ out_final = jax_ops.transpose(out_reshaped, perm_back)
980
+
981
+ return out_final[..., :n_times_output, :]
982
+
912
983
  _ops = jax_ops
913
984
  errors = _JaxErrors()
914
985
  Tensor = jax.Array
@@ -920,6 +991,7 @@ if _BACKEND == config.Backend.JAX:
920
991
 
921
992
  # Standardized Public API
922
993
  absolute = _ops.abs
994
+ adstock_process = _jax_adstock_process
923
995
  allclose = _ops.allclose
924
996
  arange = _jax_arange
925
997
  argmax = _jax_argmax
@@ -1059,6 +1131,39 @@ elif _BACKEND == config.Backend.TENSORFLOW:
1059
1131
 
1060
1132
  xla_windowed_adaptive_nuts = _tf_xla_windowed_adaptive_nuts
1061
1133
 
1134
+ def _tf_adstock_process(
1135
+ media: "_tf.Tensor", weights: "_tf.Tensor", n_times_output: int
1136
+ ) -> "_tf.Tensor":
1137
+ """TensorFlow implementation for adstock_process using loop/einsum.
1138
+
1139
+ This function applies an adstock process to media spend data. It achieves
1140
+ this by creating a windowed view of the `media` tensor and then using
1141
+ `tf.einsum` to efficiently compute the weighted sum based on the provided
1142
+ `weights`. The `weights` tensor defines the decay effect over a specific
1143
+ `window_size`. The output is truncated to `n_times_output` periods.
1144
+
1145
+ Args:
1146
+ media: Input media tensor. Expected shape is `(..., num_geos,
1147
+ num_times_in, num_channels)`. The `...` represents optional batch
1148
+ dimensions.
1149
+ weights: Adstock weights tensor. Expected shape is `(..., num_channels,
1150
+ window_size)`. The batch dimensions must be broadcast-compatible with
1151
+ those in `media`.
1152
+ n_times_output: The number of time periods to output. This should be less
1153
+ than or equal to `num_times_in - window_size + 1`.
1154
+
1155
+ Returns:
1156
+ A tensor of shape `(..., num_geos, n_times_output, num_channels)`
1157
+ representing the adstocked media.
1158
+ """
1159
+
1160
+ window_size = weights.shape[-1]
1161
+ window_list = [
1162
+ media[..., i : i + n_times_output, :] for i in range(window_size)
1163
+ ]
1164
+ windowed = tf_backend.stack(window_list)
1165
+ return tf_backend.einsum("...cw,w...gtc->...gtc", weights, windowed)
1166
+
1062
1167
  tfd = tfp.distributions
1063
1168
  bijectors = tfp.bijectors
1064
1169
  experimental = tfp.experimental
@@ -1067,6 +1172,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
1067
1172
 
1068
1173
  # Standardized Public API
1069
1174
  absolute = _ops.math.abs
1175
+ adstock_process = _tf_adstock_process
1070
1176
  allclose = _ops.experimental.numpy.allclose
1071
1177
  arange = _tf_arange
1072
1178
  argmax = _tf_argmax
meridian/constants.py CHANGED
@@ -392,6 +392,7 @@ ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
392
392
  ETA_RF,
393
393
  ETA_OM,
394
394
  ETA_ORF,
395
+ TAU_G,
395
396
  )
396
397
 
397
398
  MEDIA_PARAMETERS = (
@@ -755,6 +756,7 @@ STROKE_DASH = (4, 2)
755
756
  POINT_SIZE = 80
756
757
  INDEPENDENT = 'independent'
757
758
  RESPONSE_CURVE_STEP_SIZE = 0.01
759
+ OUTLIER_CLIP_FACTOR = 1.2
758
760
 
759
761
 
760
762
  # Font names.
@@ -20,13 +20,13 @@ The `InputData` class is used to store all the input data to the model.
20
20
  from collections import abc
21
21
  from collections.abc import Sequence
22
22
  import dataclasses
23
- import datetime as dt
24
23
  import functools
25
24
  import warnings
26
25
 
27
26
  from meridian import constants
28
27
  from meridian.data import arg_builder
29
28
  from meridian.data import time_coordinates as tc
29
+ from meridian.data import validator
30
30
  import numpy as np
31
31
  import xarray as xr
32
32
 
@@ -298,6 +298,7 @@ class InputData:
298
298
  self._validate_time_formats()
299
299
  self._validate_times()
300
300
  self._validate_geos()
301
+ self._validate_no_negative_values()
301
302
 
302
303
  def _convert_geos_to_strings(self):
303
304
  """Converts geo coordinates to strings in all relevant DataArrays."""
@@ -542,17 +543,36 @@ class InputData:
542
543
  f" `{constants.REVENUE}` or `{constants.NON_REVENUE}`."
543
544
  )
544
545
 
545
- if (self.kpi.values < 0).any():
546
- raise ValueError("KPI values must be non-negative.")
547
-
548
546
  if (
549
547
  self.revenue_per_kpi is not None
550
- and (self.revenue_per_kpi.values <= 0).all()
548
+ and (self.revenue_per_kpi.values == 0).all()
551
549
  ):
552
550
  raise ValueError(
553
- "Revenue per KPI values must not be all zero or negative."
551
+ "All Revenue per KPI values are 0, which can break the ROI"
552
+ " computation. If this is not a data error, please consider setting"
553
+ " revenue_per_kpi to None or follow the instructions at"
554
+ " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-default#default-total-paid-media-contribution-prior."
554
555
  )
555
556
 
557
+ def _validate_no_negative_values(self) -> None:
558
+ """Validates no negative values for applicable fields."""
559
+
560
+ fields_to_loggable_name = {
561
+ constants.MEDIA_SPEND: "Media Spend",
562
+ constants.RF_SPEND: "RF Spend",
563
+ constants.REACH: "Reach",
564
+ constants.FREQUENCY: "Frequency",
565
+ constants.ORGANIC_REACH: "Organic Reach",
566
+ constants.ORGANIC_FREQUENCY: "Organic Frequency",
567
+ constants.REVENUE_PER_KPI: "Revenue per KPI",
568
+ constants.KPI: "KPI",
569
+ }
570
+
571
+ for field, loggable_field in fields_to_loggable_name.items():
572
+ da = getattr(self, field)
573
+ if da is not None and (da.values < 0).any():
574
+ raise ValueError(f"{loggable_field} values must be non-negative.")
575
+
556
576
  def _validate_names(self):
557
577
  """Verifies that the names of the data arrays are correct."""
558
578
  # Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
@@ -762,52 +782,10 @@ class InputData:
762
782
 
763
783
  def _validate_time_formats(self):
764
784
  """Validates the time coordinate format for all variables."""
765
- self._validate_time_coord_format(self.kpi)
766
- self._validate_time_coord_format(self.revenue_per_kpi)
767
- self._validate_time_coord_format(self.controls)
768
- self._validate_time_coord_format(self.media)
769
- self._validate_time_coord_format(self.media_spend)
770
- self._validate_time_coord_format(self.reach)
771
- self._validate_time_coord_format(self.frequency)
772
- self._validate_time_coord_format(self.rf_spend)
773
- self._validate_time_coord_format(self.organic_media)
774
- self._validate_time_coord_format(self.organic_reach)
775
- self._validate_time_coord_format(self.organic_frequency)
776
- self._validate_time_coord_format(self.non_media_treatments)
777
-
778
- def _validate_time_coord_format(self, array: xr.DataArray | None):
779
- """Validates the `time` dimensions format of the selected DataArray.
780
-
781
- The `time` dimension of the selected array must have labels that are
782
- formatted in the Meridian conventional `"yyyy-mm-dd"` format.
783
-
784
- Args:
785
- array: An optional DataArray to validate.
786
- """
787
- if array is None:
788
- return
789
-
790
- time_values = array.coords.get(constants.TIME, None)
791
- if time_values is not None:
792
- for time in time_values:
793
- try:
794
- _ = dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
795
- except (TypeError, ValueError) as exc:
796
- raise ValueError(
797
- f"Invalid time label: {time.item()}. Expected format:"
798
- f" {constants.DATE_FORMAT}"
799
- ) from exc
800
-
801
- media_time_values = array.coords.get(constants.MEDIA_TIME, None)
802
- if media_time_values is not None:
803
- for time in media_time_values:
804
- try:
805
- _ = dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
806
- except (TypeError, ValueError) as exc:
807
- raise ValueError(
808
- f"Invalid media_time label: {time.item()}. Expected format:"
809
- f" {constants.DATE_FORMAT}"
810
- ) from exc
785
+ for field in dataclasses.fields(self):
786
+ attr = getattr(self, field.name)
787
+ if field.name != constants.POPULATION and isinstance(attr, xr.DataArray):
788
+ validator.validate_time_coord_format(attr)
811
789
 
812
790
  def _check_unique_names(self, dim: str, array: xr.DataArray | None):
813
791
  """Checks if a DataArray contains unique names on the specified dimension."""
@@ -21,11 +21,11 @@ validation logic and an overall final validation logic before a valid
21
21
 
22
22
  import abc
23
23
  from collections.abc import Sequence
24
- import datetime
25
24
  import warnings
26
25
  from meridian import constants
27
26
  from meridian.data import input_data
28
27
  from meridian.data import time_coordinates as tc
28
+ from meridian.data import validator
29
29
  import natsort
30
30
  import numpy as np
31
31
  import xarray as xr
@@ -676,14 +676,7 @@ class InputDataBuilder(abc.ABC):
676
676
 
677
677
  # Assume that the time coordinate labels are date-formatted strings.
678
678
  # We don't currently support other, arbitrary object types in the builder.
679
- for time in da.coords[time_dimension_name].values:
680
- try:
681
- _ = datetime.datetime.strptime(time, constants.DATE_FORMAT)
682
- except ValueError as exc:
683
- raise ValueError(
684
- f"Invalid time label: '{time}'. Expected format:"
685
- f" '{constants.DATE_FORMAT}'"
686
- ) from exc
679
+ validator.validate_time_coord_format(da)
687
680
 
688
681
  if len(da.coords[constants.GEO].values.tolist()) == 1:
689
682
  da = da.assign_coords(