google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
meridian/analysis/visualizer.py
CHANGED
```diff
@@ -48,7 +48,10 @@ class ModelDiagnostics:
 
   def __init__(self, meridian: model.Meridian, use_kpi: bool = False):
     self._meridian = meridian
-    self._analyzer = analyzer.Analyzer(
+    self._analyzer = analyzer.Analyzer(
+        model_context=meridian.model_context,
+        inference_data=meridian.inference_data,
+    )
     self._use_kpi = self._analyzer._use_kpi(use_kpi)
 
   @functools.lru_cache(maxsize=128)
@@ -243,6 +246,12 @@ class ModelDiagnostics:
 
     groupby = posterior_df.columns.tolist()
     groupby.remove(parameter)
+
+    parameter_99_max = prior_posterior_df[parameter].quantile(0.99)
+    # Remove outliers that make the chart hard to read.
+    prior_posterior_df[parameter] = prior_posterior_df[parameter].clip(
+        upper=parameter_99_max * c.OUTLIER_CLIP_FACTOR
+    )
     plot = (
         alt.Chart(prior_posterior_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
         .transform_density(
@@ -265,11 +274,15 @@ class ModelDiagnostics:
            x=c.INDEPENDENT
        )
 
-    return
-
-
+    return (
+        plot.properties(
+            title=formatter.custom_title_params(
+                summary_text.PRIOR_POSTERIOR_DIST_CHART_TITLE
+            )
+        )
         )
-
+        .configure_axis(**formatter.TEXT_CONFIG)
+        .interactive()
+    )
 
   def plot_rhat_boxplot(self) -> alt.Chart:
     """Plots the R-hat box plot.
@@ -381,7 +394,10 @@ class ModelFit:
       represented as a value between zero and one. Default is `0.9`.
     """
     self._meridian = meridian
-    self._analyzer = analyzer.Analyzer(
+    self._analyzer = analyzer.Analyzer(
+        model_context=meridian.model_context,
+        inference_data=meridian.inference_data,
+    )
     self._use_kpi = self._analyzer._use_kpi(use_kpi)
     self._model_fit_data = self._analyzer.expected_vs_actual_data(
         use_kpi=self._use_kpi, confidence_level=confidence_level
@@ -651,7 +667,10 @@ class ReachAndFrequency:
       use_kpi: If `True`, KPI is used instead of revenue.
     """
     self._meridian = meridian
-    self._analyzer = analyzer.Analyzer(
+    self._analyzer = analyzer.Analyzer(
+        model_context=meridian.model_context,
+        inference_data=meridian.inference_data,
+    )
     self._selected_times = selected_times
     self._use_kpi = self._analyzer._use_kpi(use_kpi)
     self._optimal_frequency_data = self._analyzer.optimal_freq(
@@ -851,7 +870,10 @@ class MediaEffects:
       the incremental revenue using the revenue per KPI (if available).
     """
     self._meridian = meridian
-    self._analyzer = analyzer.Analyzer(
+    self._analyzer = analyzer.Analyzer(
+        model_context=meridian.model_context,
+        inference_data=meridian.inference_data,
+    )
     self._by_reach = by_reach
     self._use_kpi = self._analyzer._use_kpi(use_kpi)
 
@@ -1425,7 +1447,10 @@ class MediaSummary:
       use_kpi: If `True`, use KPI instead of revenue.
     """
     self._meridian = meridian
-    self._analyzer = analyzer.Analyzer(
+    self._analyzer = analyzer.Analyzer(
+        model_context=meridian.model_context,
+        inference_data=meridian.inference_data,
+    )
     self._confidence_level = confidence_level
     self._selected_times = selected_times
     self._marginal_roi_by_reach = marginal_roi_by_reach
@@ -1450,17 +1475,15 @@ class MediaSummary:
 
     Args:
       aggregate_times: If `True`, aggregates the metrics across all time
-        periods.
+        periods. If `False`, returns time-varying metrics.
 
     Returns:
       An `xarray.Dataset` containing the following:
       - **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
-
-        `distribution` (`prior`, `posterior`)
+        `ci_hi`), `distribution` (`prior`, `posterior`)
       - **Data variables:** `impressions`, `pct_of_impressions`, `spend`,
         `pct_of_spend`, `CPM`, `incremental_outcome`, `pct_of_contribution`,
-        `roi`,
-        `effectiveness`, `mroi`.
+        `roi`, `effectiveness`, `mroi`.
     """
     return self._analyzer.summary_metrics(
         selected_times=self._selected_times,
```
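In 1.5.0 the visualization classes construct `analyzer.Analyzer` from `meridian.model_context` and `meridian.inference_data`, and the prior/posterior distribution plot now clips each parameter at `OUTLIER_CLIP_FACTOR` (1.2) times its 99th-percentile value so that a few extreme draws do not stretch the density chart's x-axis. A minimal pandas sketch of that clipping step (the column name and values here are made up for illustration):

```python
import pandas as pd

# Hypothetical posterior draws with one extreme outlier.
draws = pd.DataFrame({"beta_m": [0.10, 0.22, 0.31, 0.27, 40.0]})

OUTLIER_CLIP_FACTOR = 1.2  # mirrors the constant added in meridian/constants.py
p99 = draws["beta_m"].quantile(0.99)
# Cap extreme draws so the plotted density stays readable.
draws["beta_m"] = draws["beta_m"].clip(upper=p99 * OUTLIER_CLIP_FACTOR)
```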
meridian/backend/__init__.py
CHANGED
```diff
@@ -909,6 +909,77 @@ if _BACKEND == config.Backend.JAX:
 
   xla_windowed_adaptive_nuts = _jax_xla_windowed_adaptive_nuts
 
+  def _jax_adstock_process(
+      media: "_jax.Array", weights: "_jax.Array", n_times_output: int
+  ) -> "_jax.Array":
+    """JAX implementation for adstock_process using convolution.
+
+    This function applies an adstock process to media spend data using a
+    convolutional approach. The weights represent the adstock decay over time.
+
+    Args:
+      media: A JAX array of media spend. Expected shape is
+        `(batch_dims, n_geos, n_times_in, n_channels)`.
+      weights: A JAX array of adstock weights. Expected shape is
+        `(batch_dims, n_channels, window_size)`, where `batch_dims` must be
+        broadcastable to the batch dimensions of `media`.
+      n_times_output: The number of time periods in the output. This corresponds
+        to `n_times_in - window_size + 1`.
+
+    Returns:
+      A JAX array representing the adstocked media, with shape
+      `(batch_dims, n_geos, n_times_output, n_channels)`.
+    """
+
+    batch_dims = weights.shape[:-2]
+    if media.shape[:-3] != batch_dims:
+      media = jax_ops.broadcast_to(media, batch_dims + media.shape[-3:])
+
+    n_geos = media.shape[-3]
+    n_times_in = media.shape[-2]
+    n_channels = media.shape[-1]
+    window_size = weights.shape[-1]
+
+    perm = list(range(media.ndim))
+    perm[-2], perm[-1] = perm[-1], perm[-2]
+    media_transposed = jax_ops.transpose(media, perm)
+    media_reshaped = jax_ops.reshape(media_transposed, (1, -1, n_times_in))
+
+    total_channels = media_reshaped.shape[1]
+    weights_expanded = jax_ops.expand_dims(weights, -3)
+    weights_tiled = jax_ops.broadcast_to(
+        weights_expanded, batch_dims + (n_geos, n_channels, window_size)
+    )
+    kernel_reshaped = jax_ops.reshape(
+        weights_tiled, (total_channels, 1, window_size)
+    )
+
+    dn = jax.lax.conv_dimension_numbers(
+        media_reshaped.shape, kernel_reshaped.shape, ("NCH", "OIH", "NCH")
+    )
+
+    out = jax.lax.conv_general_dilated(
+        lhs=media_reshaped,
+        rhs=kernel_reshaped,
+        window_strides=(1,),
+        padding="VALID",
+        lhs_dilation=(1,),
+        rhs_dilation=(1,),
+        dimension_numbers=dn,
+        feature_group_count=total_channels,
+        precision=jax.lax.Precision.HIGHEST,
+    )
+
+    t_out = out.shape[-1]
+    out_reshaped = jax_ops.reshape(
+        out, batch_dims + (n_geos, n_channels, t_out)
+    )
+    perm_back = list(range(out_reshaped.ndim))
+    perm_back[-2], perm_back[-1] = perm_back[-1], perm_back[-2]
+    out_final = jax_ops.transpose(out_reshaped, perm_back)
+
+    return out_final[..., :n_times_output, :]
+
   _ops = jax_ops
   errors = _JaxErrors()
   Tensor = jax.Array
@@ -920,6 +991,7 @@ if _BACKEND == config.Backend.JAX:
 
   # Standardized Public API
   absolute = _ops.abs
+  adstock_process = _jax_adstock_process
   allclose = _ops.allclose
   arange = _jax_arange
   argmax = _jax_argmax
@@ -1059,6 +1131,39 @@ elif _BACKEND == config.Backend.TENSORFLOW:
 
   xla_windowed_adaptive_nuts = _tf_xla_windowed_adaptive_nuts
 
+  def _tf_adstock_process(
+      media: "_tf.Tensor", weights: "_tf.Tensor", n_times_output: int
+  ) -> "_tf.Tensor":
+    """TensorFlow implementation for adstock_process using loop/einsum.
+
+    This function applies an adstock process to media spend data. It achieves
+    this by creating a windowed view of the `media` tensor and then using
+    `tf.einsum` to efficiently compute the weighted sum based on the provided
+    `weights`. The `weights` tensor defines the decay effect over a specific
+    `window_size`. The output is truncated to `n_times_output` periods.
+
+    Args:
+      media: Input media tensor. Expected shape is `(..., num_geos,
+        num_times_in, num_channels)`. The `...` represents optional batch
+        dimensions.
+      weights: Adstock weights tensor. Expected shape is `(..., num_channels,
+        window_size)`. The batch dimensions must be broadcast-compatible with
+        those in `media`.
+      n_times_output: The number of time periods to output. This should be less
+        than or equal to `num_times_in - window_size + 1`.
+
+    Returns:
+      A tensor of shape `(..., num_geos, n_times_output, num_channels)`
+      representing the adstocked media.
+    """
+
+    window_size = weights.shape[-1]
+    window_list = [
+        media[..., i : i + n_times_output, :] for i in range(window_size)
+    ]
+    windowed = tf_backend.stack(window_list)
+    return tf_backend.einsum("...cw,w...gtc->...gtc", weights, windowed)
+
   tfd = tfp.distributions
   bijectors = tfp.bijectors
   experimental = tfp.experimental
@@ -1067,6 +1172,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
 
   # Standardized Public API
   absolute = _ops.math.abs
+  adstock_process = _tf_adstock_process
   allclose = _ops.experimental.numpy.allclose
   arange = _tf_arange
   argmax = _tf_argmax
```
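Both backends now expose `adstock_process` and compute the same windowed weighted sum over a `window_size` adstock kernel; only the execution strategy differs (a grouped 1-D `conv_general_dilated` in JAX versus stack-and-`einsum` in TensorFlow). As a mental model, or as a quick oracle when testing either path, a plain NumPy sketch of that definition without batch dimensions could look like the following (the function name and example values are illustrative, not part of the backend API):

```python
import numpy as np

def adstock_reference(
    media: np.ndarray, weights: np.ndarray, n_times_output: int
) -> np.ndarray:
  """Naive windowed weighted sum that both backend paths are meant to compute.

  Shapes (no batch dimensions, for clarity):
    media:   (n_geos, n_times_in, n_channels)
    weights: (n_channels, window_size)
  Returns: (n_geos, n_times_output, n_channels)
  """
  window_size = weights.shape[-1]
  n_geos, _, n_channels = media.shape
  out = np.zeros((n_geos, n_times_output, n_channels))
  for t in range(n_times_output):
    window = media[:, t : t + window_size, :]  # (n_geos, window_size, n_channels)
    # Same contraction as the TF einsum "...cw,w...gtc->...gtc".
    out[:, t, :] = np.einsum("cw,gwc->gc", weights, window)
  return out

# Example: 2 geos, 5 input periods, 1 channel, a 3-period window.
media = np.arange(10, dtype=float).reshape(2, 5, 1)
weights = np.array([[0.5, 0.3, 0.2]])
print(adstock_reference(media, weights, n_times_output=3).shape)  # (2, 3, 1)
```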
meridian/constants.py
CHANGED
```diff
@@ -392,6 +392,7 @@ ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
     ETA_RF,
     ETA_OM,
     ETA_ORF,
+    TAU_G,
 )
 
 MEDIA_PARAMETERS = (
@@ -755,6 +756,7 @@ STROKE_DASH = (4, 2)
 POINT_SIZE = 80
 INDEPENDENT = 'independent'
 RESPONSE_CURVE_STEP_SIZE = 0.01
+OUTLIER_CLIP_FACTOR = 1.2
 
 
 # Font names.
```
meridian/data/input_data.py
CHANGED
```diff
@@ -20,13 +20,13 @@ The `InputData` class is used to store all the input data to the model.
 from collections import abc
 from collections.abc import Sequence
 import dataclasses
-import datetime as dt
 import functools
 import warnings
 
 from meridian import constants
 from meridian.data import arg_builder
 from meridian.data import time_coordinates as tc
+from meridian.data import validator
 import numpy as np
 import xarray as xr
 
@@ -298,6 +298,7 @@ class InputData:
     self._validate_time_formats()
     self._validate_times()
     self._validate_geos()
+    self._validate_no_negative_values()
 
   def _convert_geos_to_strings(self):
     """Converts geo coordinates to strings in all relevant DataArrays."""
@@ -542,17 +543,36 @@
           f" `{constants.REVENUE}` or `{constants.NON_REVENUE}`."
       )
 
-    if (self.kpi.values < 0).any():
-      raise ValueError("KPI values must be non-negative.")
-
     if (
         self.revenue_per_kpi is not None
-        and (self.revenue_per_kpi.values
+        and (self.revenue_per_kpi.values == 0).all()
     ):
       raise ValueError(
-          "Revenue per KPI values
+          "All Revenue per KPI values are 0, which can break the ROI"
+          " computation. If this is not a data error, please consider setting"
+          " revenue_per_kpi to None or follow the instructions at"
+          " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-default#default-total-paid-media-contribution-prior."
       )
 
+  def _validate_no_negative_values(self) -> None:
+    """Validates no negative values for applicable fields."""
+
+    fields_to_loggable_name = {
+        constants.MEDIA_SPEND: "Media Spend",
+        constants.RF_SPEND: "RF Spend",
+        constants.REACH: "Reach",
+        constants.FREQUENCY: "Frequency",
+        constants.ORGANIC_REACH: "Organic Reach",
+        constants.ORGANIC_FREQUENCY: "Organic Frequency",
+        constants.REVENUE_PER_KPI: "Revenue per KPI",
+        constants.KPI: "KPI",
+    }
+
+    for field, loggable_field in fields_to_loggable_name.items():
+      da = getattr(self, field)
+      if da is not None and (da.values < 0).any():
+        raise ValueError(f"{loggable_field} values must be non-negative.")
+
   def _validate_names(self):
     """Verifies that the names of the data arrays are correct."""
     # Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
@@ -762,52 +782,10 @@
 
   def _validate_time_formats(self):
     """Validates the time coordinate format for all variables."""
-
-
-
-
-    self._validate_time_coord_format(self.media_spend)
-    self._validate_time_coord_format(self.reach)
-    self._validate_time_coord_format(self.frequency)
-    self._validate_time_coord_format(self.rf_spend)
-    self._validate_time_coord_format(self.organic_media)
-    self._validate_time_coord_format(self.organic_reach)
-    self._validate_time_coord_format(self.organic_frequency)
-    self._validate_time_coord_format(self.non_media_treatments)
-
-  def _validate_time_coord_format(self, array: xr.DataArray | None):
-    """Validates the `time` dimensions format of the selected DataArray.
-
-    The `time` dimension of the selected array must have labels that are
-    formatted in the Meridian conventional `"yyyy-mm-dd"` format.
-
-    Args:
-      array: An optional DataArray to validate.
-    """
-    if array is None:
-      return
-
-    time_values = array.coords.get(constants.TIME, None)
-    if time_values is not None:
-      for time in time_values:
-        try:
-          _ = dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
-        except (TypeError, ValueError) as exc:
-          raise ValueError(
-              f"Invalid time label: {time.item()}. Expected format:"
-              f" {constants.DATE_FORMAT}"
-          ) from exc
-
-    media_time_values = array.coords.get(constants.MEDIA_TIME, None)
-    if media_time_values is not None:
-      for time in media_time_values:
-        try:
-          _ = dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
-        except (TypeError, ValueError) as exc:
-          raise ValueError(
-              f"Invalid media_time label: {time.item()}. Expected format:"
-              f" {constants.DATE_FORMAT}"
-          ) from exc
+    for field in dataclasses.fields(self):
+      attr = getattr(self, field.name)
+      if field.name != constants.POPULATION and isinstance(attr, xr.DataArray):
+        validator.validate_time_coord_format(attr)
 
   def _check_unique_names(self, dim: str, array: xr.DataArray | None):
     """Checks if a DataArray contains unique names on the specified dimension."""
```
meridian/data/input_data_builder.py
CHANGED

```diff
@@ -21,11 +21,11 @@ validation logic and an overall final validation logic before a valid
 
 import abc
 from collections.abc import Sequence
-import datetime
 import warnings
 from meridian import constants
 from meridian.data import input_data
 from meridian.data import time_coordinates as tc
+from meridian.data import validator
 import natsort
 import numpy as np
 import xarray as xr
@@ -676,14 +676,7 @@ class InputDataBuilder(abc.ABC):
 
     # Assume that the time coordinate labels are date-formatted strings.
     # We don't currently support other, arbitrary object types in the builder.
-
-    try:
-      _ = datetime.datetime.strptime(time, constants.DATE_FORMAT)
-    except ValueError as exc:
-      raise ValueError(
-          f"Invalid time label: '{time}'. Expected format:"
-          f" '{constants.DATE_FORMAT}'"
-      ) from exc
+    validator.validate_time_coord_format(da)
 
     if len(da.coords[constants.GEO].values.tolist()) == 1:
       da = da.assign_coords(
```