google-meridian 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +50 -46
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/checks.py +118 -116
  7. meridian/analysis/review/constants.py +3 -3
  8. meridian/analysis/review/results.py +131 -68
  9. meridian/analysis/review/reviewer.py +8 -23
  10. meridian/analysis/summarizer.py +6 -1
  11. meridian/analysis/test_utils.py +2898 -2538
  12. meridian/analysis/visualizer.py +28 -9
  13. meridian/backend/__init__.py +106 -0
  14. meridian/constants.py +1 -0
  15. meridian/data/input_data.py +30 -52
  16. meridian/data/input_data_builder.py +2 -9
  17. meridian/data/test_utils.py +25 -41
  18. meridian/data/validator.py +48 -0
  19. meridian/mlflow/autolog.py +19 -9
  20. meridian/model/adstock_hill.py +3 -5
  21. meridian/model/context.py +134 -0
  22. meridian/model/eda/constants.py +334 -4
  23. meridian/model/eda/eda_engine.py +724 -312
  24. meridian/model/eda/eda_outcome.py +177 -33
  25. meridian/model/model.py +159 -110
  26. meridian/model/model_test_data.py +38 -0
  27. meridian/model/posterior_sampler.py +103 -62
  28. meridian/model/prior_sampler.py +114 -94
  29. meridian/model/spec.py +23 -14
  30. meridian/templates/card.html.jinja +9 -7
  31. meridian/templates/chart.html.jinja +1 -6
  32. meridian/templates/finding.html.jinja +19 -0
  33. meridian/templates/findings.html.jinja +33 -0
  34. meridian/templates/formatter.py +41 -5
  35. meridian/templates/formatter_test.py +127 -0
  36. meridian/templates/style.css +66 -9
  37. meridian/templates/style.scss +85 -4
  38. meridian/templates/table.html.jinja +1 -0
  39. meridian/version.py +1 -1
  40. scenarioplanner/linkingapi/constants.py +1 -1
  41. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  42. schema/processors/marketing_processor.py +11 -10
  43. schema/processors/model_processor.py +4 -1
  44. schema/serde/distribution.py +12 -7
  45. schema/serde/hyperparameters.py +54 -107
  46. schema/serde/meridian_serde.py +12 -3
  47. schema/utils/__init__.py +1 -0
  48. schema/utils/proto_enum_converter.py +127 -0
  49. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
  50. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
@@ -48,7 +48,10 @@ class ModelDiagnostics:
48
48
 
49
49
  def __init__(self, meridian: model.Meridian, use_kpi: bool = False):
50
50
  self._meridian = meridian
51
- self._analyzer = analyzer.Analyzer(meridian)
51
+ self._analyzer = analyzer.Analyzer(
52
+ model_context=meridian.model_context,
53
+ inference_data=meridian.inference_data,
54
+ )
52
55
  self._use_kpi = self._analyzer._use_kpi(use_kpi)
53
56
 
54
57
  @functools.lru_cache(maxsize=128)
@@ -271,11 +274,15 @@ class ModelDiagnostics:
271
274
  x=c.INDEPENDENT
272
275
  )
273
276
 
274
- return plot.properties(
275
- title=formatter.custom_title_params(
276
- summary_text.PRIOR_POSTERIOR_DIST_CHART_TITLE
277
+ return (
278
+ plot.properties(
279
+ title=formatter.custom_title_params(
280
+ summary_text.PRIOR_POSTERIOR_DIST_CHART_TITLE
281
+ )
277
282
  )
278
- ).configure_axis(**formatter.TEXT_CONFIG).interactive()
283
+ .configure_axis(**formatter.TEXT_CONFIG)
284
+ .interactive()
285
+ )
279
286
 
280
287
  def plot_rhat_boxplot(self) -> alt.Chart:
281
288
  """Plots the R-hat box plot.
@@ -387,7 +394,10 @@ class ModelFit:
387
394
  represented as a value between zero and one. Default is `0.9`.
388
395
  """
389
396
  self._meridian = meridian
390
- self._analyzer = analyzer.Analyzer(meridian)
397
+ self._analyzer = analyzer.Analyzer(
398
+ model_context=meridian.model_context,
399
+ inference_data=meridian.inference_data,
400
+ )
391
401
  self._use_kpi = self._analyzer._use_kpi(use_kpi)
392
402
  self._model_fit_data = self._analyzer.expected_vs_actual_data(
393
403
  use_kpi=self._use_kpi, confidence_level=confidence_level
@@ -657,7 +667,10 @@ class ReachAndFrequency:
657
667
  use_kpi: If `True`, KPI is used instead of revenue.
658
668
  """
659
669
  self._meridian = meridian
660
- self._analyzer = analyzer.Analyzer(meridian)
670
+ self._analyzer = analyzer.Analyzer(
671
+ model_context=meridian.model_context,
672
+ inference_data=meridian.inference_data,
673
+ )
661
674
  self._selected_times = selected_times
662
675
  self._use_kpi = self._analyzer._use_kpi(use_kpi)
663
676
  self._optimal_frequency_data = self._analyzer.optimal_freq(
@@ -857,7 +870,10 @@ class MediaEffects:
857
870
  the incremental revenue using the revenue per KPI (if available).
858
871
  """
859
872
  self._meridian = meridian
860
- self._analyzer = analyzer.Analyzer(meridian)
873
+ self._analyzer = analyzer.Analyzer(
874
+ model_context=meridian.model_context,
875
+ inference_data=meridian.inference_data,
876
+ )
861
877
  self._by_reach = by_reach
862
878
  self._use_kpi = self._analyzer._use_kpi(use_kpi)
863
879
 
@@ -1431,7 +1447,10 @@ class MediaSummary:
1431
1447
  use_kpi: If `True`, use KPI instead of revenue.
1432
1448
  """
1433
1449
  self._meridian = meridian
1434
- self._analyzer = analyzer.Analyzer(meridian)
1450
+ self._analyzer = analyzer.Analyzer(
1451
+ model_context=meridian.model_context,
1452
+ inference_data=meridian.inference_data,
1453
+ )
1435
1454
  self._confidence_level = confidence_level
1436
1455
  self._selected_times = selected_times
1437
1456
  self._marginal_roi_by_reach = marginal_roi_by_reach
@@ -909,6 +909,77 @@ if _BACKEND == config.Backend.JAX:
909
909
 
910
910
  xla_windowed_adaptive_nuts = _jax_xla_windowed_adaptive_nuts
911
911
 
912
def _jax_adstock_process(
    media: "_jax.Array", weights: "_jax.Array", n_times_output: int
) -> "_jax.Array":
  """Applies the adstock transformation via a depthwise 1-D convolution.

  Every (batch, geo, channel) combination of `media` is flattened into an
  independent time series and convolved against its per-channel decay
  weights using a grouped `jax.lax.conv_general_dilated` call whose
  `feature_group_count` equals the number of flattened series — i.e. a
  depthwise convolution.

  Args:
    media: Media values with shape `(batch_dims, n_geos, n_times_in,
      n_channels)`.
    weights: Adstock decay weights with shape `(batch_dims, n_channels,
      window_size)`; the batch dimensions must be broadcastable to those of
      `media`.
    n_times_output: Number of output time periods; corresponds to
      `n_times_in - window_size + 1`.

  Returns:
    The adstocked media, shaped `(batch_dims, n_geos, n_times_output,
    n_channels)`.
  """
  leading = weights.shape[:-2]
  if media.shape[:-3] != leading:
    media = jax_ops.broadcast_to(media, leading + media.shape[-3:])

  geos, times_in, channels = media.shape[-3:]
  window = weights.shape[-1]

  def _swap_tail_axes(ndim: int) -> list[int]:
    # Permutation exchanging the last two axes, identity elsewhere.
    axes = list(range(ndim))
    axes[-1], axes[-2] = axes[-2], axes[-1]
    return axes

  # Flatten every (batch, geo, channel) combination into one row so the
  # grouped convolution treats each series independently.
  series = jax_ops.reshape(
      jax_ops.transpose(media, _swap_tail_axes(media.ndim)),
      (1, -1, times_in),
  )
  n_series = series.shape[1]

  # Replicate the per-channel kernels across geos so there is exactly one
  # depthwise kernel per flattened series, in matching order.
  kernel = jax_ops.reshape(
      jax_ops.broadcast_to(
          jax_ops.expand_dims(weights, -3),
          leading + (geos, channels, window),
      ),
      (n_series, 1, window),
  )

  dim_nums = jax.lax.conv_dimension_numbers(
      series.shape, kernel.shape, ("NCH", "OIH", "NCH")
  )
  convolved = jax.lax.conv_general_dilated(
      lhs=series,
      rhs=kernel,
      window_strides=(1,),
      padding="VALID",
      lhs_dilation=(1,),
      rhs_dilation=(1,),
      dimension_numbers=dim_nums,
      feature_group_count=n_series,
      precision=jax.lax.Precision.HIGHEST,
  )

  # Restore the (..., geos, time, channels) layout and trim to the
  # requested number of output periods.
  restored = jax_ops.reshape(
      convolved, leading + (geos, channels, convolved.shape[-1])
  )
  result = jax_ops.transpose(restored, _swap_tail_axes(restored.ndim))
  return result[..., :n_times_output, :]
982
+
912
983
  _ops = jax_ops
913
984
  errors = _JaxErrors()
914
985
  Tensor = jax.Array
@@ -920,6 +991,7 @@ if _BACKEND == config.Backend.JAX:
920
991
 
921
992
  # Standardized Public API
922
993
  absolute = _ops.abs
994
+ adstock_process = _jax_adstock_process
923
995
  allclose = _ops.allclose
924
996
  arange = _jax_arange
925
997
  argmax = _jax_argmax
@@ -1059,6 +1131,39 @@ elif _BACKEND == config.Backend.TENSORFLOW:
1059
1131
 
1060
1132
  xla_windowed_adaptive_nuts = _tf_xla_windowed_adaptive_nuts
1061
1133
 
1134
def _tf_adstock_process(
    media: "_tf.Tensor", weights: "_tf.Tensor", n_times_output: int
) -> "_tf.Tensor":
  """Applies the adstock transformation with a windowed einsum.

  Shifted views of `media` are stacked along a new leading lag axis and
  reduced with `einsum` against the per-channel decay `weights`, producing
  a weighted sum over the adstock window for each geo/time/channel.

  Args:
    media: Input media tensor shaped `(..., num_geos, num_times_in,
      num_channels)`; `...` are optional batch dimensions.
    weights: Adstock weights shaped `(..., num_channels, window_size)`;
      batch dimensions must be broadcast-compatible with `media`.
    n_times_output: Number of output time periods; at most
      `num_times_in - window_size + 1`.

  Returns:
    A tensor shaped `(..., num_geos, n_times_output, num_channels)` with
    the adstocked media.
  """
  window_size = weights.shape[-1]
  # Axis 0 of the stacked tensor indexes the lag offset within the window.
  lagged_views = tf_backend.stack(
      [media[..., lag : lag + n_times_output, :] for lag in range(window_size)]
  )
  # Weighted sum over the lag axis, matched per channel.
  return tf_backend.einsum("...cw,w...gtc->...gtc", weights, lagged_views)
1166
+
1062
1167
  tfd = tfp.distributions
1063
1168
  bijectors = tfp.bijectors
1064
1169
  experimental = tfp.experimental
@@ -1067,6 +1172,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
1067
1172
 
1068
1173
  # Standardized Public API
1069
1174
  absolute = _ops.math.abs
1175
+ adstock_process = _tf_adstock_process
1070
1176
  allclose = _ops.experimental.numpy.allclose
1071
1177
  arange = _tf_arange
1072
1178
  argmax = _tf_argmax
meridian/constants.py CHANGED
@@ -392,6 +392,7 @@ ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
392
392
  ETA_RF,
393
393
  ETA_OM,
394
394
  ETA_ORF,
395
+ TAU_G,
395
396
  )
396
397
 
397
398
  MEDIA_PARAMETERS = (
@@ -20,13 +20,13 @@ The `InputData` class is used to store all the input data to the model.
20
20
  from collections import abc
21
21
  from collections.abc import Sequence
22
22
  import dataclasses
23
- import datetime as dt
24
23
  import functools
25
24
  import warnings
26
25
 
27
26
  from meridian import constants
28
27
  from meridian.data import arg_builder
29
28
  from meridian.data import time_coordinates as tc
29
+ from meridian.data import validator
30
30
  import numpy as np
31
31
  import xarray as xr
32
32
 
@@ -298,6 +298,7 @@ class InputData:
298
298
  self._validate_time_formats()
299
299
  self._validate_times()
300
300
  self._validate_geos()
301
+ self._validate_no_negative_values()
301
302
 
302
303
  def _convert_geos_to_strings(self):
303
304
  """Converts geo coordinates to strings in all relevant DataArrays."""
@@ -542,17 +543,36 @@ class InputData:
542
543
  f" `{constants.REVENUE}` or `{constants.NON_REVENUE}`."
543
544
  )
544
545
 
545
- if (self.kpi.values < 0).any():
546
- raise ValueError("KPI values must be non-negative.")
547
-
548
546
  if (
549
547
  self.revenue_per_kpi is not None
550
- and (self.revenue_per_kpi.values <= 0).all()
548
+ and (self.revenue_per_kpi.values == 0).all()
551
549
  ):
552
550
  raise ValueError(
553
- "Revenue per KPI values must not be all zero or negative."
551
+ "All Revenue per KPI values are 0, which can break the ROI"
552
+ " computation. If this is not a data error, please consider setting"
553
+ " revenue_per_kpi to None or follow the instructions at"
554
+ " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-default#default-total-paid-media-contribution-prior."
554
555
  )
555
556
 
557
+ def _validate_no_negative_values(self) -> None:
558
+ """Validates no negative values for applicable fields."""
559
+
560
+ fields_to_loggable_name = {
561
+ constants.MEDIA_SPEND: "Media Spend",
562
+ constants.RF_SPEND: "RF Spend",
563
+ constants.REACH: "Reach",
564
+ constants.FREQUENCY: "Frequency",
565
+ constants.ORGANIC_REACH: "Organic Reach",
566
+ constants.ORGANIC_FREQUENCY: "Organic Frequency",
567
+ constants.REVENUE_PER_KPI: "Revenue per KPI",
568
+ constants.KPI: "KPI",
569
+ }
570
+
571
+ for field, loggable_field in fields_to_loggable_name.items():
572
+ da = getattr(self, field)
573
+ if da is not None and (da.values < 0).any():
574
+ raise ValueError(f"{loggable_field} values must be non-negative.")
575
+
556
576
  def _validate_names(self):
557
577
  """Verifies that the names of the data arrays are correct."""
558
578
  # Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
@@ -762,52 +782,10 @@ class InputData:
762
782
 
763
783
  def _validate_time_formats(self):
764
784
  """Validates the time coordinate format for all variables."""
765
- self._validate_time_coord_format(self.kpi)
766
- self._validate_time_coord_format(self.revenue_per_kpi)
767
- self._validate_time_coord_format(self.controls)
768
- self._validate_time_coord_format(self.media)
769
- self._validate_time_coord_format(self.media_spend)
770
- self._validate_time_coord_format(self.reach)
771
- self._validate_time_coord_format(self.frequency)
772
- self._validate_time_coord_format(self.rf_spend)
773
- self._validate_time_coord_format(self.organic_media)
774
- self._validate_time_coord_format(self.organic_reach)
775
- self._validate_time_coord_format(self.organic_frequency)
776
- self._validate_time_coord_format(self.non_media_treatments)
777
-
778
- def _validate_time_coord_format(self, array: xr.DataArray | None):
779
- """Validates the `time` dimensions format of the selected DataArray.
780
-
781
- The `time` dimension of the selected array must have labels that are
782
- formatted in the Meridian conventional `"yyyy-mm-dd"` format.
783
-
784
- Args:
785
- array: An optional DataArray to validate.
786
- """
787
- if array is None:
788
- return
789
-
790
- time_values = array.coords.get(constants.TIME, None)
791
- if time_values is not None:
792
- for time in time_values:
793
- try:
794
- _ = dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
795
- except (TypeError, ValueError) as exc:
796
- raise ValueError(
797
- f"Invalid time label: {time.item()}. Expected format:"
798
- f" {constants.DATE_FORMAT}"
799
- ) from exc
800
-
801
- media_time_values = array.coords.get(constants.MEDIA_TIME, None)
802
- if media_time_values is not None:
803
- for time in media_time_values:
804
- try:
805
- _ = dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
806
- except (TypeError, ValueError) as exc:
807
- raise ValueError(
808
- f"Invalid media_time label: {time.item()}. Expected format:"
809
- f" {constants.DATE_FORMAT}"
810
- ) from exc
785
+ for field in dataclasses.fields(self):
786
+ attr = getattr(self, field.name)
787
+ if field.name != constants.POPULATION and isinstance(attr, xr.DataArray):
788
+ validator.validate_time_coord_format(attr)
811
789
 
812
790
  def _check_unique_names(self, dim: str, array: xr.DataArray | None):
813
791
  """Checks if a DataArray contains unique names on the specified dimension."""
@@ -21,11 +21,11 @@ validation logic and an overall final validation logic before a valid
21
21
 
22
22
  import abc
23
23
  from collections.abc import Sequence
24
- import datetime
25
24
  import warnings
26
25
  from meridian import constants
27
26
  from meridian.data import input_data
28
27
  from meridian.data import time_coordinates as tc
28
+ from meridian.data import validator
29
29
  import natsort
30
30
  import numpy as np
31
31
  import xarray as xr
@@ -676,14 +676,7 @@ class InputDataBuilder(abc.ABC):
676
676
 
677
677
  # Assume that the time coordinate labels are date-formatted strings.
678
678
  # We don't currently support other, arbitrary object types in the builder.
679
- for time in da.coords[time_dimension_name].values:
680
- try:
681
- _ = datetime.datetime.strptime(time, constants.DATE_FORMAT)
682
- except ValueError as exc:
683
- raise ValueError(
684
- f"Invalid time label: '{time}'. Expected format:"
685
- f" '{constants.DATE_FORMAT}'"
686
- ) from exc
679
+ validator.validate_time_coord_format(da)
687
680
 
688
681
  if len(da.coords[constants.GEO].values.tolist()) == 1:
689
682
  da = da.assign_coords(
@@ -1476,53 +1476,37 @@ def random_dataset(
1476
1476
  constant_value=constant_population_value,
1477
1477
  )
1478
1478
 
1479
- dataset = xr.combine_by_coords(
1480
- [kpi, population] + ([controls] if controls is not None else [])
1481
- )
1479
+ to_merge = [kpi, population]
1480
+ if controls is not None:
1481
+ to_merge.append(controls)
1482
1482
  if revenue_per_kpi is not None:
1483
- dataset = xr.combine_by_coords([dataset, revenue_per_kpi])
1483
+ to_merge.append(revenue_per_kpi)
1484
1484
  if media is not None:
1485
- media_renamed = (
1486
- media.rename({'media_time': 'time'}) if remove_media_time else media
1487
- )
1488
-
1489
- dataset = xr.combine_by_coords([dataset, media_renamed, media_spend])
1485
+ if remove_media_time:
1486
+ media = media.rename({'media_time': 'time'})
1487
+ to_merge.append(media)
1488
+ to_merge.append(media_spend)
1490
1489
  if reach is not None:
1491
- reach_renamed = (
1492
- reach.rename({'media_time': 'time'}) if remove_media_time else reach
1493
- )
1494
- frequency_renamed = (
1495
- frequency.rename({'media_time': 'time'})
1496
- if remove_media_time
1497
- else frequency
1498
- )
1499
- dataset = xr.combine_by_coords(
1500
- [dataset, reach_renamed, frequency_renamed, rf_spend]
1501
- )
1490
+ if remove_media_time:
1491
+ reach = reach.rename({'media_time': 'time'})
1492
+ frequency = frequency.rename({'media_time': 'time'})
1493
+ to_merge.append(reach)
1494
+ to_merge.append(frequency)
1495
+ to_merge.append(rf_spend)
1502
1496
  if non_media_treatments is not None:
1503
- dataset = xr.combine_by_coords([dataset, non_media_treatments])
1497
+ to_merge.append(non_media_treatments)
1504
1498
  if organic_media is not None:
1505
- organic_media_renamed = (
1506
- organic_media.rename({'media_time': 'time'})
1507
- if remove_media_time
1508
- else organic_media
1509
- )
1510
- dataset = xr.combine_by_coords([dataset, organic_media_renamed])
1499
+ if remove_media_time:
1500
+ organic_media = organic_media.rename({'media_time': 'time'})
1501
+ to_merge.append(organic_media)
1511
1502
  if organic_reach is not None:
1512
- organic_reach_renamed = (
1513
- organic_reach.rename({'media_time': 'time'})
1514
- if remove_media_time
1515
- else organic_reach
1516
- )
1517
- organic_frequency_renamed = (
1518
- organic_frequency.rename({'media_time': 'time'})
1519
- if remove_media_time
1520
- else organic_frequency
1521
- )
1522
- dataset = xr.combine_by_coords(
1523
- [dataset, organic_reach_renamed, organic_frequency_renamed]
1524
- )
1525
- return dataset
1503
+ if remove_media_time:
1504
+ organic_reach = organic_reach.rename({'media_time': 'time'})
1505
+ organic_frequency = organic_frequency.rename({'media_time': 'time'})
1506
+ to_merge.append(organic_reach)
1507
+ to_merge.append(organic_frequency)
1508
+
1509
+ return xr.merge(to_merge, join='outer', compat='no_conflicts')
1526
1510
 
1527
1511
 
1528
1512
  def dataset_to_dataframe(
@@ -0,0 +1,48 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """This module contains common validation functions for Meridian data."""
16
+
17
+ import datetime as dt
18
+ from meridian import constants
19
+ import xarray as xr
20
+
21
+
22
def validate_time_coord_format(array: xr.DataArray | None):
  """Validates that time coordinate labels use the `"yyyy-mm-dd"` format.

  Both the `time` and `media_time` coordinates — whichever are present on
  `array` — are checked label by label against `constants.DATE_FORMAT`.
  Component arrays built by the input data builders are expected to carry
  only one of these two time dimensions.

  Args:
    array: An optional DataArray to validate; `None` is a no-op.

  Raises:
    ValueError: If any time label cannot be parsed in the expected format.
  """
  if array is None:
    return

  for coord_name in (constants.TIME, constants.MEDIA_TIME):
    coord_values = array.coords.get(coord_name)
    if coord_values is None:
      continue
    for label in coord_values:
      try:
        dt.datetime.strptime(label.item(), constants.DATE_FORMAT)
      except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Invalid {coord_name} label: {label.item()!r}. "
            f"Expected format: '{constants.DATE_FORMAT}'"
        ) from exc
@@ -70,6 +70,7 @@ import dataclasses
70
70
  import inspect
71
71
  import json
72
72
  from typing import Any, Callable
73
+ import warnings
73
74
 
74
75
  import arviz as az
75
76
  from meridian import backend
@@ -180,16 +181,25 @@ def autolog(
180
181
  f"sample_posterior.{param}", kwargs.get(param, "default")
181
182
  )
182
183
 
183
- original(self, *args, **kwargs)
184
+ result = original(self, *args, **kwargs)
184
185
  if log_metrics:
185
- model_diagnostics = visualizer.ModelDiagnostics(self.model)
186
- df_diag = model_diagnostics.predictive_accuracy_table()
187
-
188
- get_metric = lambda n: df_diag[df_diag.metric == n].value.to_list()[0]
189
-
190
- mlflow.log_metric("R_Squared", get_metric("R_Squared"))
191
- mlflow.log_metric("MAPE", get_metric("MAPE"))
192
- mlflow.log_metric("wMAPE", get_metric("wMAPE"))
186
+ # TODO: Direct injection of `model.Meridian` object into
187
+ # `PosteriorMCMCSampler` is deprecated. Revisit patching method here.
188
+ if self.model is not None:
189
+ model_diagnostics = visualizer.ModelDiagnostics(self.model)
190
+ df_diag = model_diagnostics.predictive_accuracy_table()
191
+
192
+ get_metric = lambda n: df_diag[df_diag.metric == n].value.to_list()[0]
193
+
194
+ mlflow.log_metric("R_Squared", get_metric("R_Squared"))
195
+ mlflow.log_metric("MAPE", get_metric("MAPE"))
196
+ mlflow.log_metric("wMAPE", get_metric("wMAPE"))
197
+ else:
198
+ warnings.warn(
199
+ "log_metrics=True is not supported when PosteriorMCMCSampler is"
200
+ " initialized with model_context."
201
+ )
202
+ return result
193
203
 
194
204
  safe_patch(FLAVOR_NAME, model.Meridian, "__init__", patch_meridian_init)
195
205
  safe_patch(
@@ -279,10 +279,6 @@ def _adstock(
279
279
  media = backend.concatenate([backend.zeros(pad_shape), media], axis=-2)
280
280
 
281
281
  # Adstock calculation.
282
- window_list = [None] * window_size
283
- for i in range(window_size):
284
- window_list[i] = media[..., i : i + n_times_output, :]
285
- windowed = backend.stack(window_list)
286
282
  l_range = backend.arange(window_size - 1, -1, -1, dtype=backend.float32)
287
283
  weights = compute_decay_weights(
288
284
  alpha=alpha,
@@ -291,7 +287,9 @@ def _adstock(
291
287
  decay_functions=decay_functions,
292
288
  normalize=True,
293
289
  )
294
- return backend.einsum('...mw,w...gtm->...gtm', weights, windowed)
290
+ return backend.adstock_process(
291
+ media=media, weights=weights, n_times_output=n_times_output
292
+ )
295
293
 
296
294
 
297
295
  def _map_alpha_for_binomial_decay(x: backend.Tensor):