google-meridian 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.0.6.dist-info → google_meridian-1.0.8.dist-info}/METADATA +11 -10
- {google_meridian-1.0.6.dist-info → google_meridian-1.0.8.dist-info}/RECORD +18 -18
- {google_meridian-1.0.6.dist-info → google_meridian-1.0.8.dist-info}/WHEEL +1 -1
- meridian/__init__.py +1 -1
- meridian/analysis/analyzer.py +383 -320
- meridian/analysis/optimizer.py +531 -269
- meridian/analysis/summarizer.py +21 -3
- meridian/analysis/summary_text.py +20 -1
- meridian/analysis/templates/chart.html.jinja +1 -0
- meridian/analysis/test_utils.py +47 -99
- meridian/analysis/visualizer.py +407 -83
- meridian/constants.py +31 -0
- meridian/data/input_data.py +49 -5
- meridian/data/load.py +10 -7
- meridian/model/model.py +5 -4
- meridian/model/posterior_sampler.py +15 -5
- {google_meridian-1.0.6.dist-info → google_meridian-1.0.8.dist-info/licenses}/LICENSE +0 -0
- {google_meridian-1.0.6.dist-info → google_meridian-1.0.8.dist-info}/top_level.txt +0 -0
meridian/constants.py
CHANGED
|
@@ -95,6 +95,35 @@ POSSIBLE_INPUT_DATA_ARRAY_NAMES = (
|
|
|
95
95
|
+ MEDIA_INPUT_DATA_ARRAY_NAMES
|
|
96
96
|
+ RF_INPUT_DATA_ARRAY_NAMES
|
|
97
97
|
)
|
|
98
|
+
PAID_DATA = (
|
|
99
|
+
MEDIA,
|
|
100
|
+
REACH,
|
|
101
|
+
FREQUENCY,
|
|
102
|
+
REVENUE_PER_KPI,
|
|
103
|
+
)
|
|
104
|
+
NON_PAID_DATA = (
|
|
105
|
+
ORGANIC_MEDIA,
|
|
106
|
+
ORGANIC_REACH,
|
|
107
|
+
ORGANIC_FREQUENCY,
|
|
108
|
+
NON_MEDIA_TREATMENTS,
|
|
109
|
+
)
|
|
110
|
+
SPEND_DATA = (
|
|
111
|
+
MEDIA_SPEND,
|
|
112
|
+
RF_SPEND,
|
|
113
|
+
)
|
|
114
|
+
PERFORMANCE_DATA = PAID_DATA + SPEND_DATA
|
|
115
|
+
IMPRESSIONS_DATA = (
|
|
116
|
+
MEDIA,
|
|
117
|
+
REACH,
|
|
118
|
+
FREQUENCY,
|
|
119
|
+
) + NON_PAID_DATA
|
|
120
|
+
RF_DATA = (
|
|
121
|
+
REACH,
|
|
122
|
+
FREQUENCY,
|
|
123
|
+
RF_SPEND,
|
|
124
|
+
REVENUE_PER_KPI,
|
|
125
|
+
)
|
|
126
|
+
NON_REVENUE_DATA = IMPRESSIONS_DATA + (CONTROLS,)
|
|
98
127
|
|
|
99
128
|
# Scaled input data variables.
|
|
100
129
|
MEDIA_SCALED = 'media_scaled'
|
|
@@ -543,6 +572,7 @@ TARGET_ROI = 'target_roi'
|
|
|
543
572
|
TARGET_MROI = 'target_mroi'
|
|
544
573
|
SPEND_CONSTRAINT_DEFAULT_FIXED_BUDGET = 0.3
|
|
545
574
|
SPEND_CONSTRAINT_DEFAULT_FLEXIBLE_BUDGET = 1.0
|
|
575
|
+
SPEND_CONSTRAINT_DEFAULT = 1.0
|
|
546
576
|
|
|
547
577
|
|
|
548
578
|
# Plot constants.
|
|
@@ -591,3 +621,4 @@ CARD_STATS = 'stats'
|
|
|
591
621
|
|
|
592
622
|
# VegaLite common params.
|
|
593
623
|
VEGALITE_FACET_DEFAULT_WIDTH = 400
|
|
624
|
+
VEGALITE_FACET_LARGE_WIDTH = 500
|
meridian/data/input_data.py
CHANGED
|
@@ -401,6 +401,7 @@ class InputData:
|
|
|
401
401
|
)
|
|
402
402
|
|
|
403
403
|
def _validate_kpi(self):
|
|
404
|
+
"""Validates the KPI data."""
|
|
404
405
|
if (
|
|
405
406
|
self.kpi_type != constants.REVENUE
|
|
406
407
|
and self.kpi_type != constants.NON_REVENUE
|
|
@@ -413,6 +414,14 @@ class InputData:
|
|
|
413
414
|
if (self.kpi.values < 0).any():
|
|
414
415
|
raise ValueError("KPI values must be non-negative.")
|
|
415
416
|
|
|
417
|
+
if (
|
|
418
|
+
self.revenue_per_kpi is not None
|
|
419
|
+
and (self.revenue_per_kpi.values <= 0).all()
|
|
420
|
+
):
|
|
421
|
+
raise ValueError(
|
|
422
|
+
"Revenue per KPI values must not be all zero or negative."
|
|
423
|
+
)
|
|
424
|
+
|
|
416
425
|
def _validate_names(self):
|
|
417
426
|
"""Verifies that the names of the data arrays are correct."""
|
|
418
427
|
arrays = [
|
|
@@ -534,15 +543,50 @@ class InputData:
|
|
|
534
543
|
def _validate_media_channels(self):
|
|
535
544
|
"""Verifies Meridian media channel names invariants.
|
|
536
545
|
|
|
537
|
-
In the input data,
|
|
538
|
-
`rf_channel`
|
|
546
|
+
In the input data, channel names across `media_channel`,
|
|
547
|
+
`rf_channel`, `organic_media_channel`, `organic_rf_channel`,
|
|
548
|
+
`non_media_channel` must be unique.
|
|
539
549
|
"""
|
|
540
550
|
all_channels = self.get_all_channels()
|
|
541
551
|
if len(np.unique(all_channels)) != all_channels.size:
|
|
542
|
-
|
|
543
|
-
"
|
|
544
|
-
"
|
|
552
|
+
error_msg = (
|
|
553
|
+
"Channel names across `media_channel`, `rf_channel`,"
|
|
554
|
+
" `organic_media_channel`, `organic_rf_channel`, and"
|
|
555
|
+
" `non_media_channel` must be unique."
|
|
545
556
|
)
|
|
557
|
+
# For each channel, store all occurrences of the channel in particular
|
|
558
|
+
# channel type.
|
|
559
|
+
from_channel_to_type = {}
|
|
560
|
+
for channel in all_channels:
|
|
561
|
+
if channel not in from_channel_to_type:
|
|
562
|
+
from_channel_to_type[channel] = []
|
|
563
|
+
|
|
564
|
+
# pytype: disable=attribute-error
|
|
565
|
+
if self.media_channel is not None:
|
|
566
|
+
for channel in self.media_channel.values:
|
|
567
|
+
from_channel_to_type[channel].append(constants.MEDIA_CHANNEL)
|
|
568
|
+
if self.rf_channel is not None:
|
|
569
|
+
for channel in self.rf_channel.values:
|
|
570
|
+
from_channel_to_type[channel].append(constants.RF_CHANNEL)
|
|
571
|
+
if self.organic_media_channel is not None:
|
|
572
|
+
for channel in self.organic_media_channel.values:
|
|
573
|
+
from_channel_to_type[channel].append(constants.ORGANIC_MEDIA_CHANNEL)
|
|
574
|
+
if self.organic_rf_channel is not None:
|
|
575
|
+
for channel in self.organic_rf_channel.values:
|
|
576
|
+
from_channel_to_type[channel].append(constants.ORGANIC_RF_CHANNEL)
|
|
577
|
+
if self.non_media_channel is not None:
|
|
578
|
+
for channel in self.non_media_channel.values:
|
|
579
|
+
from_channel_to_type[channel].append(constants.NON_MEDIA_CHANNEL)
|
|
580
|
+
# pytype: enable=attribute-error
|
|
581
|
+
|
|
582
|
+
for channel, types in from_channel_to_type.items():
|
|
583
|
+
if len(types) > 1:
|
|
584
|
+
error_msg += (
|
|
585
|
+
f" Channel `{channel}` is present in multiple channel types:"
|
|
586
|
+
f" {types}."
|
|
587
|
+
)
|
|
588
|
+
|
|
589
|
+
raise ValueError(error_msg)
|
|
546
590
|
|
|
547
591
|
def _validate_times(self):
|
|
548
592
|
"""Validates time coordinate values."""
|
meridian/data/load.py
CHANGED
|
@@ -950,7 +950,7 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
950
950
|
raise ValueError('NA values found in the organic_frequency columns.')
|
|
951
951
|
|
|
952
952
|
# Determine columns in which NAs are expected in the lagged-media period.
|
|
953
|
-
|
|
953
|
+
not_lagged_columns = []
|
|
954
954
|
coords = [
|
|
955
955
|
constants.KPI,
|
|
956
956
|
constants.CONTROLS,
|
|
@@ -967,12 +967,12 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
967
967
|
for coord in coords:
|
|
968
968
|
columns = getattr(self.coord_to_columns, coord)
|
|
969
969
|
columns = [columns] if isinstance(columns, str) else columns
|
|
970
|
-
|
|
970
|
+
not_lagged_columns.extend(columns)
|
|
971
971
|
|
|
972
972
|
# Dates with at least one non-NA value in columns different from media,
|
|
973
973
|
# reach, frequency, organic_media, organic_reach, and organic_frequency.
|
|
974
974
|
time_column_name = self.coord_to_columns.time
|
|
975
|
-
no_na_period = self.df[(~self.df[
|
|
975
|
+
no_na_period = self.df[(~self.df[not_lagged_columns].isna()).any(axis=1)][
|
|
976
976
|
time_column_name
|
|
977
977
|
].unique()
|
|
978
978
|
|
|
@@ -999,13 +999,16 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
999
999
|
# organic_frequency.
|
|
1000
1000
|
not_lagged_data = self.df.loc[
|
|
1001
1001
|
self.df[time_column_name].isin(no_na_period),
|
|
1002
|
-
|
|
1002
|
+
not_lagged_columns,
|
|
1003
1003
|
]
|
|
1004
1004
|
if not_lagged_data.isna().any(axis=None):
|
|
1005
|
+
incorrect_columns = []
|
|
1006
|
+
for column in not_lagged_columns:
|
|
1007
|
+
if not_lagged_data[column].isna().any(axis=None):
|
|
1008
|
+
incorrect_columns.append(column)
|
|
1005
1009
|
raise ValueError(
|
|
1006
|
-
'NA values found in
|
|
1007
|
-
|
|
1008
|
-
' non-media columns).'
|
|
1010
|
+
f'NA values found in columns {incorrect_columns} within the modeling'
|
|
1011
|
+
' time window (time periods where the KPI is modeled).'
|
|
1009
1012
|
)
|
|
1010
1013
|
|
|
1011
1014
|
def load(self) -> input_data.InputData:
|
meridian/model/model.py
CHANGED
|
@@ -1030,7 +1030,7 @@ class Meridian:
|
|
|
1030
1030
|
max_energy_diff: float = 500.0,
|
|
1031
1031
|
unrolled_leapfrog_steps: int = 1,
|
|
1032
1032
|
parallel_iterations: int = 10,
|
|
1033
|
-
seed: Sequence[int] | None = None,
|
|
1033
|
+
seed: Sequence[int] | int | None = None,
|
|
1034
1034
|
**pins,
|
|
1035
1035
|
):
|
|
1036
1036
|
"""Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions.
|
|
@@ -1080,9 +1080,10 @@ class Meridian:
|
|
|
1080
1080
|
trajectory length implied by `max_tree_depth`. Defaults to `1`.
|
|
1081
1081
|
parallel_iterations: Number of iterations allowed to run in parallel. Must
|
|
1082
1082
|
be a positive integer. For more information, see `tf.while_loop`.
|
|
1083
|
-
seed:
|
|
1084
|
-
|
|
1085
|
-
|
|
1083
|
+
seed: An `int32[2]` Tensor or a Python list or tuple of 2 `int`s, which
|
|
1084
|
+
will be treated as stateless seeds; or a Python `int` or `None`, which
|
|
1085
|
+
will be treated as stateful seeds. See [tfp.random.sanitize_seed]
|
|
1086
|
+
(https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed).
|
|
1086
1087
|
**pins: These are used to condition the provided joint distribution, and
|
|
1087
1088
|
are passed directly to `joint_dist.experimental_pin(**pins)`.
|
|
1088
1089
|
|
|
meridian/model/posterior_sampler.py
CHANGED
|
@@ -393,7 +393,7 @@ class PosteriorMCMCSampler:
|
|
|
393
393
|
max_energy_diff: float = 500.0,
|
|
394
394
|
unrolled_leapfrog_steps: int = 1,
|
|
395
395
|
parallel_iterations: int = 10,
|
|
396
|
-
seed: Sequence[int] | None = None,
|
|
396
|
+
seed: Sequence[int] | int | None = None,
|
|
397
397
|
**pins,
|
|
398
398
|
) -> az.InferenceData:
|
|
399
399
|
"""Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions.
|
|
@@ -441,9 +441,10 @@ class PosteriorMCMCSampler:
|
|
|
441
441
|
trajectory length implied by `max_tree_depth`. Defaults to `1`.
|
|
442
442
|
parallel_iterations: Number of iterations allowed to run in parallel. Must
|
|
443
443
|
be a positive integer. For more information, see `tf.while_loop`.
|
|
444
|
-
seed:
|
|
445
|
-
|
|
446
|
-
|
|
444
|
+
seed: An `int32[2]` Tensor or a Python list or tuple of 2 `int`s, which
|
|
445
|
+
will be treated as stateless seeds; or a Python `int` or `None`, which
|
|
446
|
+
will be treated as stateful seeds. See [tfp.random.sanitize_seed]
|
|
447
|
+
(https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed).
|
|
447
448
|
**pins: These are used to condition the provided joint distribution, and
|
|
448
449
|
are passed directly to `joint_dist.experimental_pin(**pins)`.
|
|
449
450
|
|
|
@@ -457,7 +458,14 @@ class PosteriorMCMCSampler:
|
|
|
457
458
|
[ResourceExhaustedError when running Meridian.sample_posterior]
|
|
458
459
|
(https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
|
|
459
460
|
"""
|
|
460
|
-
seed
|
|
461
|
+
if seed is not None and isinstance(seed, Sequence) and len(seed) != 2:
|
|
462
|
+
raise ValueError(
|
|
463
|
+
"Invalid seed: Must be either a single integer (stateful seed) or a"
|
|
464
|
+
" pair of two integers (stateless seed). See"
|
|
465
|
+
" [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
|
|
466
|
+
" for details."
|
|
467
|
+
)
|
|
468
|
+
seed = tfp.random.sanitize_seed(seed) if seed is not None else None
|
|
461
469
|
n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
|
|
462
470
|
total_chains = np.sum(n_chains_list)
|
|
463
471
|
|
|
@@ -486,6 +494,8 @@ class PosteriorMCMCSampler:
|
|
|
486
494
|
" integers as `n_chains` to sample chains serially (see"
|
|
487
495
|
" https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error)"
|
|
488
496
|
) from error
|
|
497
|
+
if seed is not None:
|
|
498
|
+
seed += 1
|
|
489
499
|
states.append(mcmc.all_states._asdict())
|
|
490
500
|
traces.append(mcmc.trace)
|
|
491
501
|
|
|
File without changes
|
|
File without changes
|