google-meridian 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +14 -11
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +50 -46
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +1 -1
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/checks.py +118 -116
- meridian/analysis/review/constants.py +3 -3
- meridian/analysis/review/results.py +131 -68
- meridian/analysis/review/reviewer.py +8 -23
- meridian/analysis/summarizer.py +6 -1
- meridian/analysis/test_utils.py +2898 -2538
- meridian/analysis/visualizer.py +28 -9
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +1 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +25 -41
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +134 -0
- meridian/model/eda/constants.py +334 -4
- meridian/model/eda/eda_engine.py +724 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/model.py +159 -110
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/linkingapi/constants.py +1 -1
- scenarioplanner/mmm_ui_proto_generator.py +1 -0
- schema/processors/marketing_processor.py +11 -10
- schema/processors/model_processor.py +4 -1
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +12 -3
- schema/utils/__init__.py +1 -0
- schema/utils/proto_enum_converter.py +127 -0
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
meridian/analysis/visualizer.py
CHANGED
|
@@ -48,7 +48,10 @@ class ModelDiagnostics:
|
|
|
48
48
|
|
|
49
49
|
def __init__(self, meridian: model.Meridian, use_kpi: bool = False):
|
|
50
50
|
self._meridian = meridian
|
|
51
|
-
self._analyzer = analyzer.Analyzer(
|
|
51
|
+
self._analyzer = analyzer.Analyzer(
|
|
52
|
+
model_context=meridian.model_context,
|
|
53
|
+
inference_data=meridian.inference_data,
|
|
54
|
+
)
|
|
52
55
|
self._use_kpi = self._analyzer._use_kpi(use_kpi)
|
|
53
56
|
|
|
54
57
|
@functools.lru_cache(maxsize=128)
|
|
@@ -271,11 +274,15 @@ class ModelDiagnostics:
|
|
|
271
274
|
x=c.INDEPENDENT
|
|
272
275
|
)
|
|
273
276
|
|
|
274
|
-
return
|
|
275
|
-
|
|
276
|
-
|
|
277
|
+
return (
|
|
278
|
+
plot.properties(
|
|
279
|
+
title=formatter.custom_title_params(
|
|
280
|
+
summary_text.PRIOR_POSTERIOR_DIST_CHART_TITLE
|
|
281
|
+
)
|
|
277
282
|
)
|
|
278
|
-
|
|
283
|
+
.configure_axis(**formatter.TEXT_CONFIG)
|
|
284
|
+
.interactive()
|
|
285
|
+
)
|
|
279
286
|
|
|
280
287
|
def plot_rhat_boxplot(self) -> alt.Chart:
|
|
281
288
|
"""Plots the R-hat box plot.
|
|
@@ -387,7 +394,10 @@ class ModelFit:
|
|
|
387
394
|
represented as a value between zero and one. Default is `0.9`.
|
|
388
395
|
"""
|
|
389
396
|
self._meridian = meridian
|
|
390
|
-
self._analyzer = analyzer.Analyzer(
|
|
397
|
+
self._analyzer = analyzer.Analyzer(
|
|
398
|
+
model_context=meridian.model_context,
|
|
399
|
+
inference_data=meridian.inference_data,
|
|
400
|
+
)
|
|
391
401
|
self._use_kpi = self._analyzer._use_kpi(use_kpi)
|
|
392
402
|
self._model_fit_data = self._analyzer.expected_vs_actual_data(
|
|
393
403
|
use_kpi=self._use_kpi, confidence_level=confidence_level
|
|
@@ -657,7 +667,10 @@ class ReachAndFrequency:
|
|
|
657
667
|
use_kpi: If `True`, KPI is used instead of revenue.
|
|
658
668
|
"""
|
|
659
669
|
self._meridian = meridian
|
|
660
|
-
self._analyzer = analyzer.Analyzer(
|
|
670
|
+
self._analyzer = analyzer.Analyzer(
|
|
671
|
+
model_context=meridian.model_context,
|
|
672
|
+
inference_data=meridian.inference_data,
|
|
673
|
+
)
|
|
661
674
|
self._selected_times = selected_times
|
|
662
675
|
self._use_kpi = self._analyzer._use_kpi(use_kpi)
|
|
663
676
|
self._optimal_frequency_data = self._analyzer.optimal_freq(
|
|
@@ -857,7 +870,10 @@ class MediaEffects:
|
|
|
857
870
|
the incremental revenue using the revenue per KPI (if available).
|
|
858
871
|
"""
|
|
859
872
|
self._meridian = meridian
|
|
860
|
-
self._analyzer = analyzer.Analyzer(
|
|
873
|
+
self._analyzer = analyzer.Analyzer(
|
|
874
|
+
model_context=meridian.model_context,
|
|
875
|
+
inference_data=meridian.inference_data,
|
|
876
|
+
)
|
|
861
877
|
self._by_reach = by_reach
|
|
862
878
|
self._use_kpi = self._analyzer._use_kpi(use_kpi)
|
|
863
879
|
|
|
@@ -1431,7 +1447,10 @@ class MediaSummary:
|
|
|
1431
1447
|
use_kpi: If `True`, use KPI instead of revenue.
|
|
1432
1448
|
"""
|
|
1433
1449
|
self._meridian = meridian
|
|
1434
|
-
self._analyzer = analyzer.Analyzer(
|
|
1450
|
+
self._analyzer = analyzer.Analyzer(
|
|
1451
|
+
model_context=meridian.model_context,
|
|
1452
|
+
inference_data=meridian.inference_data,
|
|
1453
|
+
)
|
|
1435
1454
|
self._confidence_level = confidence_level
|
|
1436
1455
|
self._selected_times = selected_times
|
|
1437
1456
|
self._marginal_roi_by_reach = marginal_roi_by_reach
|
meridian/backend/__init__.py
CHANGED
|
@@ -909,6 +909,77 @@ if _BACKEND == config.Backend.JAX:
|
|
|
909
909
|
|
|
910
910
|
xla_windowed_adaptive_nuts = _jax_xla_windowed_adaptive_nuts
|
|
911
911
|
|
|
912
|
+
def _jax_adstock_process(
|
|
913
|
+
media: "_jax.Array", weights: "_jax.Array", n_times_output: int
|
|
914
|
+
) -> "_jax.Array":
|
|
915
|
+
"""JAX implementation for adstock_process using convolution.
|
|
916
|
+
|
|
917
|
+
This function applies an adstock process to media spend data using a
|
|
918
|
+
convolutional approach. The weights represent the adstock decay over time.
|
|
919
|
+
|
|
920
|
+
Args:
|
|
921
|
+
media: A JAX array of media spend. Expected shape is
|
|
922
|
+
`(batch_dims, n_geos, n_times_in, n_channels)`.
|
|
923
|
+
weights: A JAX array of adstock weights. Expected shape is
|
|
924
|
+
`(batch_dims, n_channels, window_size)`, where `batch_dims` must be
|
|
925
|
+
broadcastable to the batch dimensions of `media`.
|
|
926
|
+
n_times_output: The number of time periods in the output. This corresponds
|
|
927
|
+
to `n_times_in - window_size + 1`.
|
|
928
|
+
|
|
929
|
+
Returns:
|
|
930
|
+
A JAX array representing the adstocked media, with shape
|
|
931
|
+
`(batch_dims, n_geos, n_times_output, n_channels)`.
|
|
932
|
+
"""
|
|
933
|
+
|
|
934
|
+
batch_dims = weights.shape[:-2]
|
|
935
|
+
if media.shape[:-3] != batch_dims:
|
|
936
|
+
media = jax_ops.broadcast_to(media, batch_dims + media.shape[-3:])
|
|
937
|
+
|
|
938
|
+
n_geos = media.shape[-3]
|
|
939
|
+
n_times_in = media.shape[-2]
|
|
940
|
+
n_channels = media.shape[-1]
|
|
941
|
+
window_size = weights.shape[-1]
|
|
942
|
+
|
|
943
|
+
perm = list(range(media.ndim))
|
|
944
|
+
perm[-2], perm[-1] = perm[-1], perm[-2]
|
|
945
|
+
media_transposed = jax_ops.transpose(media, perm)
|
|
946
|
+
media_reshaped = jax_ops.reshape(media_transposed, (1, -1, n_times_in))
|
|
947
|
+
|
|
948
|
+
total_channels = media_reshaped.shape[1]
|
|
949
|
+
weights_expanded = jax_ops.expand_dims(weights, -3)
|
|
950
|
+
weights_tiled = jax_ops.broadcast_to(
|
|
951
|
+
weights_expanded, batch_dims + (n_geos, n_channels, window_size)
|
|
952
|
+
)
|
|
953
|
+
kernel_reshaped = jax_ops.reshape(
|
|
954
|
+
weights_tiled, (total_channels, 1, window_size)
|
|
955
|
+
)
|
|
956
|
+
|
|
957
|
+
dn = jax.lax.conv_dimension_numbers(
|
|
958
|
+
media_reshaped.shape, kernel_reshaped.shape, ("NCH", "OIH", "NCH")
|
|
959
|
+
)
|
|
960
|
+
|
|
961
|
+
out = jax.lax.conv_general_dilated(
|
|
962
|
+
lhs=media_reshaped,
|
|
963
|
+
rhs=kernel_reshaped,
|
|
964
|
+
window_strides=(1,),
|
|
965
|
+
padding="VALID",
|
|
966
|
+
lhs_dilation=(1,),
|
|
967
|
+
rhs_dilation=(1,),
|
|
968
|
+
dimension_numbers=dn,
|
|
969
|
+
feature_group_count=total_channels,
|
|
970
|
+
precision=jax.lax.Precision.HIGHEST,
|
|
971
|
+
)
|
|
972
|
+
|
|
973
|
+
t_out = out.shape[-1]
|
|
974
|
+
out_reshaped = jax_ops.reshape(
|
|
975
|
+
out, batch_dims + (n_geos, n_channels, t_out)
|
|
976
|
+
)
|
|
977
|
+
perm_back = list(range(out_reshaped.ndim))
|
|
978
|
+
perm_back[-2], perm_back[-1] = perm_back[-1], perm_back[-2]
|
|
979
|
+
out_final = jax_ops.transpose(out_reshaped, perm_back)
|
|
980
|
+
|
|
981
|
+
return out_final[..., :n_times_output, :]
|
|
982
|
+
|
|
912
983
|
_ops = jax_ops
|
|
913
984
|
errors = _JaxErrors()
|
|
914
985
|
Tensor = jax.Array
|
|
@@ -920,6 +991,7 @@ if _BACKEND == config.Backend.JAX:
|
|
|
920
991
|
|
|
921
992
|
# Standardized Public API
|
|
922
993
|
absolute = _ops.abs
|
|
994
|
+
adstock_process = _jax_adstock_process
|
|
923
995
|
allclose = _ops.allclose
|
|
924
996
|
arange = _jax_arange
|
|
925
997
|
argmax = _jax_argmax
|
|
@@ -1059,6 +1131,39 @@ elif _BACKEND == config.Backend.TENSORFLOW:
|
|
|
1059
1131
|
|
|
1060
1132
|
xla_windowed_adaptive_nuts = _tf_xla_windowed_adaptive_nuts
|
|
1061
1133
|
|
|
1134
|
+
def _tf_adstock_process(
|
|
1135
|
+
media: "_tf.Tensor", weights: "_tf.Tensor", n_times_output: int
|
|
1136
|
+
) -> "_tf.Tensor":
|
|
1137
|
+
"""TensorFlow implementation for adstock_process using loop/einsum.
|
|
1138
|
+
|
|
1139
|
+
This function applies an adstock process to media spend data. It achieves
|
|
1140
|
+
this by creating a windowed view of the `media` tensor and then using
|
|
1141
|
+
`tf.einsum` to efficiently compute the weighted sum based on the provided
|
|
1142
|
+
`weights`. The `weights` tensor defines the decay effect over a specific
|
|
1143
|
+
`window_size`. The output is truncated to `n_times_output` periods.
|
|
1144
|
+
|
|
1145
|
+
Args:
|
|
1146
|
+
media: Input media tensor. Expected shape is `(..., num_geos,
|
|
1147
|
+
num_times_in, num_channels)`. The `...` represents optional batch
|
|
1148
|
+
dimensions.
|
|
1149
|
+
weights: Adstock weights tensor. Expected shape is `(..., num_channels,
|
|
1150
|
+
window_size)`. The batch dimensions must be broadcast-compatible with
|
|
1151
|
+
those in `media`.
|
|
1152
|
+
n_times_output: The number of time periods to output. This should be less
|
|
1153
|
+
than or equal to `num_times_in - window_size + 1`.
|
|
1154
|
+
|
|
1155
|
+
Returns:
|
|
1156
|
+
A tensor of shape `(..., num_geos, n_times_output, num_channels)`
|
|
1157
|
+
representing the adstocked media.
|
|
1158
|
+
"""
|
|
1159
|
+
|
|
1160
|
+
window_size = weights.shape[-1]
|
|
1161
|
+
window_list = [
|
|
1162
|
+
media[..., i : i + n_times_output, :] for i in range(window_size)
|
|
1163
|
+
]
|
|
1164
|
+
windowed = tf_backend.stack(window_list)
|
|
1165
|
+
return tf_backend.einsum("...cw,w...gtc->...gtc", weights, windowed)
|
|
1166
|
+
|
|
1062
1167
|
tfd = tfp.distributions
|
|
1063
1168
|
bijectors = tfp.bijectors
|
|
1064
1169
|
experimental = tfp.experimental
|
|
@@ -1067,6 +1172,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
|
|
|
1067
1172
|
|
|
1068
1173
|
# Standardized Public API
|
|
1069
1174
|
absolute = _ops.math.abs
|
|
1175
|
+
adstock_process = _tf_adstock_process
|
|
1070
1176
|
allclose = _ops.experimental.numpy.allclose
|
|
1071
1177
|
arange = _tf_arange
|
|
1072
1178
|
argmax = _tf_argmax
|
meridian/constants.py
CHANGED
meridian/data/input_data.py
CHANGED
|
@@ -20,13 +20,13 @@ The `InputData` class is used to store all the input data to the model.
|
|
|
20
20
|
from collections import abc
|
|
21
21
|
from collections.abc import Sequence
|
|
22
22
|
import dataclasses
|
|
23
|
-
import datetime as dt
|
|
24
23
|
import functools
|
|
25
24
|
import warnings
|
|
26
25
|
|
|
27
26
|
from meridian import constants
|
|
28
27
|
from meridian.data import arg_builder
|
|
29
28
|
from meridian.data import time_coordinates as tc
|
|
29
|
+
from meridian.data import validator
|
|
30
30
|
import numpy as np
|
|
31
31
|
import xarray as xr
|
|
32
32
|
|
|
@@ -298,6 +298,7 @@ class InputData:
|
|
|
298
298
|
self._validate_time_formats()
|
|
299
299
|
self._validate_times()
|
|
300
300
|
self._validate_geos()
|
|
301
|
+
self._validate_no_negative_values()
|
|
301
302
|
|
|
302
303
|
def _convert_geos_to_strings(self):
|
|
303
304
|
"""Converts geo coordinates to strings in all relevant DataArrays."""
|
|
@@ -542,17 +543,36 @@ class InputData:
|
|
|
542
543
|
f" `{constants.REVENUE}` or `{constants.NON_REVENUE}`."
|
|
543
544
|
)
|
|
544
545
|
|
|
545
|
-
if (self.kpi.values < 0).any():
|
|
546
|
-
raise ValueError("KPI values must be non-negative.")
|
|
547
|
-
|
|
548
546
|
if (
|
|
549
547
|
self.revenue_per_kpi is not None
|
|
550
|
-
and (self.revenue_per_kpi.values
|
|
548
|
+
and (self.revenue_per_kpi.values == 0).all()
|
|
551
549
|
):
|
|
552
550
|
raise ValueError(
|
|
553
|
-
"Revenue per KPI values
|
|
551
|
+
"All Revenue per KPI values are 0, which can break the ROI"
|
|
552
|
+
" computation. If this is not a data error, please consider setting"
|
|
553
|
+
" revenue_per_kpi to None or follow the instructions at"
|
|
554
|
+
" https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-default#default-total-paid-media-contribution-prior."
|
|
554
555
|
)
|
|
555
556
|
|
|
557
|
+
def _validate_no_negative_values(self) -> None:
|
|
558
|
+
"""Validates no negative values for applicable fields."""
|
|
559
|
+
|
|
560
|
+
fields_to_loggable_name = {
|
|
561
|
+
constants.MEDIA_SPEND: "Media Spend",
|
|
562
|
+
constants.RF_SPEND: "RF Spend",
|
|
563
|
+
constants.REACH: "Reach",
|
|
564
|
+
constants.FREQUENCY: "Frequency",
|
|
565
|
+
constants.ORGANIC_REACH: "Organic Reach",
|
|
566
|
+
constants.ORGANIC_FREQUENCY: "Organic Frequency",
|
|
567
|
+
constants.REVENUE_PER_KPI: "Revenue per KPI",
|
|
568
|
+
constants.KPI: "KPI",
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
for field, loggable_field in fields_to_loggable_name.items():
|
|
572
|
+
da = getattr(self, field)
|
|
573
|
+
if da is not None and (da.values < 0).any():
|
|
574
|
+
raise ValueError(f"{loggable_field} values must be non-negative.")
|
|
575
|
+
|
|
556
576
|
def _validate_names(self):
|
|
557
577
|
"""Verifies that the names of the data arrays are correct."""
|
|
558
578
|
# Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
|
|
@@ -762,52 +782,10 @@ class InputData:
|
|
|
762
782
|
|
|
763
783
|
def _validate_time_formats(self):
|
|
764
784
|
"""Validates the time coordinate format for all variables."""
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
self._validate_time_coord_format(self.media_spend)
|
|
770
|
-
self._validate_time_coord_format(self.reach)
|
|
771
|
-
self._validate_time_coord_format(self.frequency)
|
|
772
|
-
self._validate_time_coord_format(self.rf_spend)
|
|
773
|
-
self._validate_time_coord_format(self.organic_media)
|
|
774
|
-
self._validate_time_coord_format(self.organic_reach)
|
|
775
|
-
self._validate_time_coord_format(self.organic_frequency)
|
|
776
|
-
self._validate_time_coord_format(self.non_media_treatments)
|
|
777
|
-
|
|
778
|
-
def _validate_time_coord_format(self, array: xr.DataArray | None):
|
|
779
|
-
"""Validates the `time` dimensions format of the selected DataArray.
|
|
780
|
-
|
|
781
|
-
The `time` dimension of the selected array must have labels that are
|
|
782
|
-
formatted in the Meridian conventional `"yyyy-mm-dd"` format.
|
|
783
|
-
|
|
784
|
-
Args:
|
|
785
|
-
array: An optional DataArray to validate.
|
|
786
|
-
"""
|
|
787
|
-
if array is None:
|
|
788
|
-
return
|
|
789
|
-
|
|
790
|
-
time_values = array.coords.get(constants.TIME, None)
|
|
791
|
-
if time_values is not None:
|
|
792
|
-
for time in time_values:
|
|
793
|
-
try:
|
|
794
|
-
_ = dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
|
|
795
|
-
except (TypeError, ValueError) as exc:
|
|
796
|
-
raise ValueError(
|
|
797
|
-
f"Invalid time label: {time.item()}. Expected format:"
|
|
798
|
-
f" {constants.DATE_FORMAT}"
|
|
799
|
-
) from exc
|
|
800
|
-
|
|
801
|
-
media_time_values = array.coords.get(constants.MEDIA_TIME, None)
|
|
802
|
-
if media_time_values is not None:
|
|
803
|
-
for time in media_time_values:
|
|
804
|
-
try:
|
|
805
|
-
_ = dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
|
|
806
|
-
except (TypeError, ValueError) as exc:
|
|
807
|
-
raise ValueError(
|
|
808
|
-
f"Invalid media_time label: {time.item()}. Expected format:"
|
|
809
|
-
f" {constants.DATE_FORMAT}"
|
|
810
|
-
) from exc
|
|
785
|
+
for field in dataclasses.fields(self):
|
|
786
|
+
attr = getattr(self, field.name)
|
|
787
|
+
if field.name != constants.POPULATION and isinstance(attr, xr.DataArray):
|
|
788
|
+
validator.validate_time_coord_format(attr)
|
|
811
789
|
|
|
812
790
|
def _check_unique_names(self, dim: str, array: xr.DataArray | None):
|
|
813
791
|
"""Checks if a DataArray contains unique names on the specified dimension."""
|
|
@@ -21,11 +21,11 @@ validation logic and an overall final validation logic before a valid
|
|
|
21
21
|
|
|
22
22
|
import abc
|
|
23
23
|
from collections.abc import Sequence
|
|
24
|
-
import datetime
|
|
25
24
|
import warnings
|
|
26
25
|
from meridian import constants
|
|
27
26
|
from meridian.data import input_data
|
|
28
27
|
from meridian.data import time_coordinates as tc
|
|
28
|
+
from meridian.data import validator
|
|
29
29
|
import natsort
|
|
30
30
|
import numpy as np
|
|
31
31
|
import xarray as xr
|
|
@@ -676,14 +676,7 @@ class InputDataBuilder(abc.ABC):
|
|
|
676
676
|
|
|
677
677
|
# Assume that the time coordinate labels are date-formatted strings.
|
|
678
678
|
# We don't currently support other, arbitrary object types in the builder.
|
|
679
|
-
|
|
680
|
-
try:
|
|
681
|
-
_ = datetime.datetime.strptime(time, constants.DATE_FORMAT)
|
|
682
|
-
except ValueError as exc:
|
|
683
|
-
raise ValueError(
|
|
684
|
-
f"Invalid time label: '{time}'. Expected format:"
|
|
685
|
-
f" '{constants.DATE_FORMAT}'"
|
|
686
|
-
) from exc
|
|
679
|
+
validator.validate_time_coord_format(da)
|
|
687
680
|
|
|
688
681
|
if len(da.coords[constants.GEO].values.tolist()) == 1:
|
|
689
682
|
da = da.assign_coords(
|
meridian/data/test_utils.py
CHANGED
|
@@ -1476,53 +1476,37 @@ def random_dataset(
|
|
|
1476
1476
|
constant_value=constant_population_value,
|
|
1477
1477
|
)
|
|
1478
1478
|
|
|
1479
|
-
|
|
1480
|
-
|
|
1481
|
-
|
|
1479
|
+
to_merge = [kpi, population]
|
|
1480
|
+
if controls is not None:
|
|
1481
|
+
to_merge.append(controls)
|
|
1482
1482
|
if revenue_per_kpi is not None:
|
|
1483
|
-
|
|
1483
|
+
to_merge.append(revenue_per_kpi)
|
|
1484
1484
|
if media is not None:
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
)
|
|
1488
|
-
|
|
1489
|
-
dataset = xr.combine_by_coords([dataset, media_renamed, media_spend])
|
|
1485
|
+
if remove_media_time:
|
|
1486
|
+
media = media.rename({'media_time': 'time'})
|
|
1487
|
+
to_merge.append(media)
|
|
1488
|
+
to_merge.append(media_spend)
|
|
1490
1489
|
if reach is not None:
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
else frequency
|
|
1498
|
-
)
|
|
1499
|
-
dataset = xr.combine_by_coords(
|
|
1500
|
-
[dataset, reach_renamed, frequency_renamed, rf_spend]
|
|
1501
|
-
)
|
|
1490
|
+
if remove_media_time:
|
|
1491
|
+
reach = reach.rename({'media_time': 'time'})
|
|
1492
|
+
frequency = frequency.rename({'media_time': 'time'})
|
|
1493
|
+
to_merge.append(reach)
|
|
1494
|
+
to_merge.append(frequency)
|
|
1495
|
+
to_merge.append(rf_spend)
|
|
1502
1496
|
if non_media_treatments is not None:
|
|
1503
|
-
|
|
1497
|
+
to_merge.append(non_media_treatments)
|
|
1504
1498
|
if organic_media is not None:
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
else organic_media
|
|
1509
|
-
)
|
|
1510
|
-
dataset = xr.combine_by_coords([dataset, organic_media_renamed])
|
|
1499
|
+
if remove_media_time:
|
|
1500
|
+
organic_media = organic_media.rename({'media_time': 'time'})
|
|
1501
|
+
to_merge.append(organic_media)
|
|
1511
1502
|
if organic_reach is not None:
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
)
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
if remove_media_time
|
|
1520
|
-
else organic_frequency
|
|
1521
|
-
)
|
|
1522
|
-
dataset = xr.combine_by_coords(
|
|
1523
|
-
[dataset, organic_reach_renamed, organic_frequency_renamed]
|
|
1524
|
-
)
|
|
1525
|
-
return dataset
|
|
1503
|
+
if remove_media_time:
|
|
1504
|
+
organic_reach = organic_reach.rename({'media_time': 'time'})
|
|
1505
|
+
organic_frequency = organic_frequency.rename({'media_time': 'time'})
|
|
1506
|
+
to_merge.append(organic_reach)
|
|
1507
|
+
to_merge.append(organic_frequency)
|
|
1508
|
+
|
|
1509
|
+
return xr.merge(to_merge, join='outer', compat='no_conflicts')
|
|
1526
1510
|
|
|
1527
1511
|
|
|
1528
1512
|
def dataset_to_dataframe(
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""This module contains common validation functions for Meridian data."""
|
|
16
|
+
|
|
17
|
+
import datetime as dt
|
|
18
|
+
from meridian import constants
|
|
19
|
+
import xarray as xr
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def validate_time_coord_format(array: xr.DataArray | None):
|
|
23
|
+
"""Validates the `time` dimensions format of the selected DataArray.
|
|
24
|
+
|
|
25
|
+
The `time` dimension of the selected array must have labels that are
|
|
26
|
+
formatted in the Meridian conventional `"yyyy-mm-dd"` format.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
array: An optional DataArray to validate.
|
|
30
|
+
"""
|
|
31
|
+
if array is None:
|
|
32
|
+
return
|
|
33
|
+
|
|
34
|
+
# The component data arrays from the input data builders that call this helper
|
|
35
|
+
# method should only have one of either `media_time` or `time` as its time
|
|
36
|
+
# dimension.
|
|
37
|
+
target_coords = [constants.TIME, constants.MEDIA_TIME]
|
|
38
|
+
|
|
39
|
+
for coord_name in target_coords:
|
|
40
|
+
if (values := array.coords.get(coord_name)) is not None:
|
|
41
|
+
for time in values:
|
|
42
|
+
try:
|
|
43
|
+
dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
|
|
44
|
+
except (TypeError, ValueError) as exc:
|
|
45
|
+
raise ValueError(
|
|
46
|
+
f"Invalid {coord_name} label: {time.item()!r}. "
|
|
47
|
+
f"Expected format: '{constants.DATE_FORMAT}'"
|
|
48
|
+
) from exc
|
meridian/mlflow/autolog.py
CHANGED
|
@@ -70,6 +70,7 @@ import dataclasses
|
|
|
70
70
|
import inspect
|
|
71
71
|
import json
|
|
72
72
|
from typing import Any, Callable
|
|
73
|
+
import warnings
|
|
73
74
|
|
|
74
75
|
import arviz as az
|
|
75
76
|
from meridian import backend
|
|
@@ -180,16 +181,25 @@ def autolog(
|
|
|
180
181
|
f"sample_posterior.{param}", kwargs.get(param, "default")
|
|
181
182
|
)
|
|
182
183
|
|
|
183
|
-
original(self, *args, **kwargs)
|
|
184
|
+
result = original(self, *args, **kwargs)
|
|
184
185
|
if log_metrics:
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
186
|
+
# TODO: Direct injection of `model.Meridian` object into
|
|
187
|
+
# `PosteriorMCMCSampler` is deprecated. Revisit patching method here.
|
|
188
|
+
if self.model is not None:
|
|
189
|
+
model_diagnostics = visualizer.ModelDiagnostics(self.model)
|
|
190
|
+
df_diag = model_diagnostics.predictive_accuracy_table()
|
|
191
|
+
|
|
192
|
+
get_metric = lambda n: df_diag[df_diag.metric == n].value.to_list()[0]
|
|
193
|
+
|
|
194
|
+
mlflow.log_metric("R_Squared", get_metric("R_Squared"))
|
|
195
|
+
mlflow.log_metric("MAPE", get_metric("MAPE"))
|
|
196
|
+
mlflow.log_metric("wMAPE", get_metric("wMAPE"))
|
|
197
|
+
else:
|
|
198
|
+
warnings.warn(
|
|
199
|
+
"log_metrics=True is not supported when PosteriorMCMCSampler is"
|
|
200
|
+
" initialized with model_context."
|
|
201
|
+
)
|
|
202
|
+
return result
|
|
193
203
|
|
|
194
204
|
safe_patch(FLAVOR_NAME, model.Meridian, "__init__", patch_meridian_init)
|
|
195
205
|
safe_patch(
|
meridian/model/adstock_hill.py
CHANGED
|
@@ -279,10 +279,6 @@ def _adstock(
|
|
|
279
279
|
media = backend.concatenate([backend.zeros(pad_shape), media], axis=-2)
|
|
280
280
|
|
|
281
281
|
# Adstock calculation.
|
|
282
|
-
window_list = [None] * window_size
|
|
283
|
-
for i in range(window_size):
|
|
284
|
-
window_list[i] = media[..., i : i + n_times_output, :]
|
|
285
|
-
windowed = backend.stack(window_list)
|
|
286
282
|
l_range = backend.arange(window_size - 1, -1, -1, dtype=backend.float32)
|
|
287
283
|
weights = compute_decay_weights(
|
|
288
284
|
alpha=alpha,
|
|
@@ -291,7 +287,9 @@ def _adstock(
|
|
|
291
287
|
decay_functions=decay_functions,
|
|
292
288
|
normalize=True,
|
|
293
289
|
)
|
|
294
|
-
return backend.
|
|
290
|
+
return backend.adstock_process(
|
|
291
|
+
media=media, weights=weights, n_times_output=n_times_output
|
|
292
|
+
)
|
|
295
293
|
|
|
296
294
|
|
|
297
295
|
def _map_alpha_for_binomial_decay(x: backend.Tensor):
|