google-meridian 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +14 -11
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +50 -46
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +1 -1
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/checks.py +118 -116
- meridian/analysis/review/constants.py +3 -3
- meridian/analysis/review/results.py +131 -68
- meridian/analysis/review/reviewer.py +8 -23
- meridian/analysis/summarizer.py +6 -1
- meridian/analysis/test_utils.py +2898 -2538
- meridian/analysis/visualizer.py +28 -9
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +1 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +25 -41
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +134 -0
- meridian/model/eda/constants.py +334 -4
- meridian/model/eda/eda_engine.py +724 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/model.py +159 -110
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/linkingapi/constants.py +1 -1
- scenarioplanner/mmm_ui_proto_generator.py +1 -0
- schema/processors/marketing_processor.py +11 -10
- schema/processors/model_processor.py +4 -1
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +12 -3
- schema/utils/__init__.py +1 -0
- schema/utils/proto_enum_converter.py +127 -0
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
meridian/model/prior_sampler.py
CHANGED
|
@@ -15,12 +15,16 @@
|
|
|
15
15
|
"""Module for sampling prior distributions in a Meridian model."""
|
|
16
16
|
|
|
17
17
|
from collections.abc import Mapping
|
|
18
|
-
|
|
18
|
+
import functools
|
|
19
|
+
from typing import Optional, TYPE_CHECKING
|
|
20
|
+
import warnings
|
|
19
21
|
|
|
20
|
-
import arviz as az
|
|
21
22
|
from meridian import backend
|
|
22
23
|
from meridian import constants
|
|
24
|
+
from meridian.model import context
|
|
25
|
+
from meridian.model import equations
|
|
23
26
|
|
|
27
|
+
# TODO: Break this circular dependency.
|
|
24
28
|
if TYPE_CHECKING:
|
|
25
29
|
from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
|
|
26
30
|
|
|
@@ -63,8 +67,34 @@ def _get_tau_g(
|
|
|
63
67
|
class PriorDistributionSampler:
|
|
64
68
|
"""A callable that samples from a model spec's prior distributions."""
|
|
65
69
|
|
|
66
|
-
|
|
67
|
-
|
|
70
|
+
# TODO: Deprecate direct injection of `model.Meridian`.
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
meridian: Optional["model.Meridian"] = None,
|
|
74
|
+
*,
|
|
75
|
+
model_context: context.ModelContext | None = None,
|
|
76
|
+
):
|
|
77
|
+
if meridian is not None:
|
|
78
|
+
warnings.warn(
|
|
79
|
+
"Initializing PriorDistributionSampler with a Meridian object is"
|
|
80
|
+
" deprecated and will be removed in a future version. Please use"
|
|
81
|
+
" `model_context` instead.",
|
|
82
|
+
DeprecationWarning,
|
|
83
|
+
stacklevel=2,
|
|
84
|
+
)
|
|
85
|
+
self._meridian = meridian
|
|
86
|
+
self._model_context = meridian.model_context
|
|
87
|
+
elif model_context is not None:
|
|
88
|
+
self._meridian = None
|
|
89
|
+
self._model_context = model_context
|
|
90
|
+
else:
|
|
91
|
+
raise ValueError(
|
|
92
|
+
"Either `meridian` or `model_context` must be provided."
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
@functools.cached_property
|
|
96
|
+
def _model_equations(self) -> equations.ModelEquations:
|
|
97
|
+
return equations.ModelEquations(self._model_context)
|
|
68
98
|
|
|
69
99
|
def _sample_media_priors(
|
|
70
100
|
self,
|
|
@@ -82,9 +112,9 @@ class PriorDistributionSampler:
|
|
|
82
112
|
n_media_channels]` or `[n_draws, n_media_channels]` containing the
|
|
83
113
|
samples.
|
|
84
114
|
"""
|
|
85
|
-
|
|
115
|
+
ctx = self._model_context
|
|
86
116
|
|
|
87
|
-
prior =
|
|
117
|
+
prior = ctx.prior_broadcast
|
|
88
118
|
sample_shape = [1, n_draws]
|
|
89
119
|
|
|
90
120
|
media_vars = {
|
|
@@ -103,11 +133,11 @@ class PriorDistributionSampler:
|
|
|
103
133
|
}
|
|
104
134
|
beta_gm_dev = backend.tfd.Sample(
|
|
105
135
|
backend.tfd.Normal(0, 1),
|
|
106
|
-
[
|
|
136
|
+
[ctx.n_geos, ctx.n_media_channels],
|
|
107
137
|
name=constants.BETA_GM_DEV,
|
|
108
138
|
).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
|
|
109
139
|
|
|
110
|
-
prior_type =
|
|
140
|
+
prior_type = ctx.model_spec.effective_media_prior_type
|
|
111
141
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
112
142
|
media_vars[constants.BETA_M] = prior.beta_m.sample(
|
|
113
143
|
sample_shape=sample_shape, seed=rng_handler.get_next_seed()
|
|
@@ -131,24 +161,22 @@ class PriorDistributionSampler:
|
|
|
131
161
|
else:
|
|
132
162
|
raise ValueError(f"Unsupported prior type: {prior_type}")
|
|
133
163
|
incremental_outcome_m = (
|
|
134
|
-
treatment_parameter_m *
|
|
164
|
+
treatment_parameter_m * ctx.media_tensors.prior_denominator
|
|
135
165
|
)
|
|
136
|
-
media_transformed =
|
|
137
|
-
media=
|
|
166
|
+
media_transformed = self._model_equations.adstock_hill_media(
|
|
167
|
+
media=ctx.media_tensors.media_scaled,
|
|
138
168
|
alpha=media_vars[constants.ALPHA_M],
|
|
139
169
|
ec=media_vars[constants.EC_M],
|
|
140
170
|
slope=media_vars[constants.SLOPE_M],
|
|
141
|
-
decay_functions=
|
|
171
|
+
decay_functions=ctx.adstock_decay_spec.media,
|
|
142
172
|
)
|
|
143
|
-
linear_predictor_counterfactual_difference = (
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
slope_m=media_vars[constants.SLOPE_M],
|
|
149
|
-
)
|
|
173
|
+
linear_predictor_counterfactual_difference = self._model_equations.linear_predictor_counterfactual_difference_media(
|
|
174
|
+
media_transformed=media_transformed,
|
|
175
|
+
alpha_m=media_vars[constants.ALPHA_M],
|
|
176
|
+
ec_m=media_vars[constants.EC_M],
|
|
177
|
+
slope_m=media_vars[constants.SLOPE_M],
|
|
150
178
|
)
|
|
151
|
-
beta_m_value =
|
|
179
|
+
beta_m_value = self._model_equations.calculate_beta_x(
|
|
152
180
|
is_non_media=False,
|
|
153
181
|
incremental_outcome_x=incremental_outcome_m,
|
|
154
182
|
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
@@ -165,7 +193,7 @@ class PriorDistributionSampler:
|
|
|
165
193
|
)
|
|
166
194
|
beta_gm_value = (
|
|
167
195
|
beta_eta_combined
|
|
168
|
-
if
|
|
196
|
+
if ctx.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
169
197
|
else backend.exp(beta_eta_combined)
|
|
170
198
|
)
|
|
171
199
|
media_vars[constants.BETA_GM] = backend.tfd.Deterministic(
|
|
@@ -190,9 +218,9 @@ class PriorDistributionSampler:
|
|
|
190
218
|
`[n_draws, n_geos, n_rf_channels]` or `[n_draws, n_rf_channels]`
|
|
191
219
|
containing the samples.
|
|
192
220
|
"""
|
|
193
|
-
|
|
221
|
+
ctx = self._model_context
|
|
194
222
|
|
|
195
|
-
prior =
|
|
223
|
+
prior = ctx.prior_broadcast
|
|
196
224
|
sample_shape = [1, n_draws]
|
|
197
225
|
|
|
198
226
|
rf_vars = {
|
|
@@ -211,11 +239,11 @@ class PriorDistributionSampler:
|
|
|
211
239
|
}
|
|
212
240
|
beta_grf_dev = backend.tfd.Sample(
|
|
213
241
|
backend.tfd.Normal(0, 1),
|
|
214
|
-
[
|
|
242
|
+
[ctx.n_geos, ctx.n_rf_channels],
|
|
215
243
|
name=constants.BETA_GRF_DEV,
|
|
216
244
|
).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
|
|
217
245
|
|
|
218
|
-
prior_type =
|
|
246
|
+
prior_type = ctx.model_spec.effective_rf_prior_type
|
|
219
247
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
220
248
|
rf_vars[constants.BETA_RF] = prior.beta_rf.sample(
|
|
221
249
|
sample_shape=sample_shape, seed=rng_handler.get_next_seed()
|
|
@@ -239,25 +267,25 @@ class PriorDistributionSampler:
|
|
|
239
267
|
else:
|
|
240
268
|
raise ValueError(f"Unsupported prior type: {prior_type}")
|
|
241
269
|
incremental_outcome_rf = (
|
|
242
|
-
treatment_parameter_rf *
|
|
270
|
+
treatment_parameter_rf * ctx.rf_tensors.prior_denominator
|
|
243
271
|
)
|
|
244
|
-
rf_transformed =
|
|
245
|
-
reach=
|
|
246
|
-
frequency=
|
|
272
|
+
rf_transformed = self._model_equations.adstock_hill_rf(
|
|
273
|
+
reach=ctx.rf_tensors.reach_scaled,
|
|
274
|
+
frequency=ctx.rf_tensors.frequency,
|
|
247
275
|
alpha=rf_vars[constants.ALPHA_RF],
|
|
248
276
|
ec=rf_vars[constants.EC_RF],
|
|
249
277
|
slope=rf_vars[constants.SLOPE_RF],
|
|
250
|
-
decay_functions=
|
|
278
|
+
decay_functions=ctx.adstock_decay_spec.rf,
|
|
251
279
|
)
|
|
252
280
|
linear_predictor_counterfactual_difference = (
|
|
253
|
-
|
|
281
|
+
self._model_equations.linear_predictor_counterfactual_difference_rf(
|
|
254
282
|
rf_transformed=rf_transformed,
|
|
255
283
|
alpha_rf=rf_vars[constants.ALPHA_RF],
|
|
256
284
|
ec_rf=rf_vars[constants.EC_RF],
|
|
257
285
|
slope_rf=rf_vars[constants.SLOPE_RF],
|
|
258
286
|
)
|
|
259
287
|
)
|
|
260
|
-
beta_rf_value =
|
|
288
|
+
beta_rf_value = self._model_equations.calculate_beta_x(
|
|
261
289
|
is_non_media=False,
|
|
262
290
|
incremental_outcome_x=incremental_outcome_rf,
|
|
263
291
|
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
@@ -275,7 +303,7 @@ class PriorDistributionSampler:
|
|
|
275
303
|
)
|
|
276
304
|
beta_grf_value = (
|
|
277
305
|
beta_eta_combined
|
|
278
|
-
if
|
|
306
|
+
if ctx.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
279
307
|
else backend.exp(beta_eta_combined)
|
|
280
308
|
)
|
|
281
309
|
rf_vars[constants.BETA_GRF] = backend.tfd.Deterministic(
|
|
@@ -300,9 +328,9 @@ class PriorDistributionSampler:
|
|
|
300
328
|
`[n_draws, n_geos, n_organic_media_channels]` or
|
|
301
329
|
`[n_draws, n_organic_media_channels]` containing the samples.
|
|
302
330
|
"""
|
|
303
|
-
|
|
331
|
+
ctx = self._model_context
|
|
304
332
|
|
|
305
|
-
prior =
|
|
333
|
+
prior = ctx.prior_broadcast
|
|
306
334
|
sample_shape = [1, n_draws]
|
|
307
335
|
|
|
308
336
|
organic_media_vars = {
|
|
@@ -321,11 +349,11 @@ class PriorDistributionSampler:
|
|
|
321
349
|
}
|
|
322
350
|
beta_gom_dev = backend.tfd.Sample(
|
|
323
351
|
backend.tfd.Normal(0, 1),
|
|
324
|
-
[
|
|
352
|
+
[ctx.n_geos, ctx.n_organic_media_channels],
|
|
325
353
|
name=constants.BETA_GOM_DEV,
|
|
326
354
|
).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
|
|
327
355
|
|
|
328
|
-
prior_type =
|
|
356
|
+
prior_type = ctx.model_spec.organic_media_prior_type
|
|
329
357
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
330
358
|
organic_media_vars[constants.BETA_OM] = prior.beta_om.sample(
|
|
331
359
|
sample_shape=sample_shape, seed=rng_handler.get_next_seed()
|
|
@@ -337,16 +365,16 @@ class PriorDistributionSampler:
|
|
|
337
365
|
)
|
|
338
366
|
)
|
|
339
367
|
incremental_outcome_om = (
|
|
340
|
-
organic_media_vars[constants.CONTRIBUTION_OM] *
|
|
368
|
+
organic_media_vars[constants.CONTRIBUTION_OM] * ctx.total_outcome
|
|
341
369
|
)
|
|
342
|
-
organic_media_transformed =
|
|
343
|
-
media=
|
|
370
|
+
organic_media_transformed = self._model_equations.adstock_hill_media(
|
|
371
|
+
media=ctx.organic_media_tensors.organic_media_scaled,
|
|
344
372
|
alpha=organic_media_vars[constants.ALPHA_OM],
|
|
345
373
|
ec=organic_media_vars[constants.EC_OM],
|
|
346
374
|
slope=organic_media_vars[constants.SLOPE_OM],
|
|
347
|
-
decay_functions=
|
|
375
|
+
decay_functions=ctx.adstock_decay_spec.organic_media,
|
|
348
376
|
)
|
|
349
|
-
beta_om_value =
|
|
377
|
+
beta_om_value = self._model_equations.calculate_beta_x(
|
|
350
378
|
is_non_media=False,
|
|
351
379
|
incremental_outcome_x=incremental_outcome_om,
|
|
352
380
|
linear_predictor_counterfactual_difference=organic_media_transformed,
|
|
@@ -367,7 +395,7 @@ class PriorDistributionSampler:
|
|
|
367
395
|
)
|
|
368
396
|
beta_gom_value = (
|
|
369
397
|
beta_eta_combined
|
|
370
|
-
if
|
|
398
|
+
if ctx.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
371
399
|
else backend.exp(beta_eta_combined)
|
|
372
400
|
)
|
|
373
401
|
organic_media_vars[constants.BETA_GOM] = backend.tfd.Deterministic(
|
|
@@ -392,9 +420,9 @@ class PriorDistributionSampler:
|
|
|
392
420
|
`[n_draws, n_geos, n_organic_rf_channels]` or
|
|
393
421
|
`[n_draws, n_organic_rf_channels]` containing the samples.
|
|
394
422
|
"""
|
|
395
|
-
|
|
423
|
+
ctx = self._model_context
|
|
396
424
|
|
|
397
|
-
prior =
|
|
425
|
+
prior = ctx.prior_broadcast
|
|
398
426
|
sample_shape = [1, n_draws]
|
|
399
427
|
|
|
400
428
|
organic_rf_vars = {
|
|
@@ -413,11 +441,11 @@ class PriorDistributionSampler:
|
|
|
413
441
|
}
|
|
414
442
|
beta_gorf_dev = backend.tfd.Sample(
|
|
415
443
|
backend.tfd.Normal(0, 1),
|
|
416
|
-
[
|
|
444
|
+
[ctx.n_geos, ctx.n_organic_rf_channels],
|
|
417
445
|
name=constants.BETA_GORF_DEV,
|
|
418
446
|
).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
|
|
419
447
|
|
|
420
|
-
prior_type =
|
|
448
|
+
prior_type = ctx.model_spec.organic_media_prior_type
|
|
421
449
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
422
450
|
organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(
|
|
423
451
|
sample_shape=sample_shape, seed=rng_handler.get_next_seed()
|
|
@@ -429,17 +457,17 @@ class PriorDistributionSampler:
|
|
|
429
457
|
)
|
|
430
458
|
)
|
|
431
459
|
incremental_outcome_orf = (
|
|
432
|
-
organic_rf_vars[constants.CONTRIBUTION_ORF] *
|
|
460
|
+
organic_rf_vars[constants.CONTRIBUTION_ORF] * ctx.total_outcome
|
|
433
461
|
)
|
|
434
|
-
organic_rf_transformed =
|
|
435
|
-
reach=
|
|
436
|
-
frequency=
|
|
462
|
+
organic_rf_transformed = self._model_equations.adstock_hill_rf(
|
|
463
|
+
reach=ctx.organic_rf_tensors.organic_reach_scaled,
|
|
464
|
+
frequency=ctx.organic_rf_tensors.organic_frequency,
|
|
437
465
|
alpha=organic_rf_vars[constants.ALPHA_ORF],
|
|
438
466
|
ec=organic_rf_vars[constants.EC_ORF],
|
|
439
467
|
slope=organic_rf_vars[constants.SLOPE_ORF],
|
|
440
|
-
decay_functions=
|
|
468
|
+
decay_functions=ctx.adstock_decay_spec.organic_rf,
|
|
441
469
|
)
|
|
442
|
-
beta_orf_value =
|
|
470
|
+
beta_orf_value = self._model_equations.calculate_beta_x(
|
|
443
471
|
is_non_media=False,
|
|
444
472
|
incremental_outcome_x=incremental_outcome_orf,
|
|
445
473
|
linear_predictor_counterfactual_difference=organic_rf_transformed,
|
|
@@ -460,7 +488,7 @@ class PriorDistributionSampler:
|
|
|
460
488
|
)
|
|
461
489
|
beta_gorf_value = (
|
|
462
490
|
beta_eta_combined
|
|
463
|
-
if
|
|
491
|
+
if ctx.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
464
492
|
else backend.exp(beta_eta_combined)
|
|
465
493
|
)
|
|
466
494
|
organic_rf_vars[constants.BETA_GORF] = backend.tfd.Deterministic(
|
|
@@ -485,9 +513,9 @@ class PriorDistributionSampler:
|
|
|
485
513
|
`[n_draws, n_geos, n_non_media_channels]` or
|
|
486
514
|
`[n_draws, n_non_media_channels]` containing the samples.
|
|
487
515
|
"""
|
|
488
|
-
|
|
516
|
+
ctx = self._model_context
|
|
489
517
|
|
|
490
|
-
prior =
|
|
518
|
+
prior = ctx.prior_broadcast
|
|
491
519
|
sample_shape = [1, n_draws]
|
|
492
520
|
|
|
493
521
|
non_media_treatments_vars = {
|
|
@@ -497,10 +525,10 @@ class PriorDistributionSampler:
|
|
|
497
525
|
}
|
|
498
526
|
gamma_gn_dev = backend.tfd.Sample(
|
|
499
527
|
backend.tfd.Normal(0, 1),
|
|
500
|
-
[
|
|
528
|
+
[ctx.n_geos, ctx.n_non_media_channels],
|
|
501
529
|
name=constants.GAMMA_GN_DEV,
|
|
502
530
|
).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
|
|
503
|
-
prior_type =
|
|
531
|
+
prior_type = ctx.model_spec.non_media_treatments_prior_type
|
|
504
532
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
505
533
|
non_media_treatments_vars[constants.GAMMA_N] = prior.gamma_n.sample(
|
|
506
534
|
sample_shape=sample_shape, seed=rng_handler.get_next_seed()
|
|
@@ -513,15 +541,15 @@ class PriorDistributionSampler:
|
|
|
513
541
|
)
|
|
514
542
|
incremental_outcome_n = (
|
|
515
543
|
non_media_treatments_vars[constants.CONTRIBUTION_N]
|
|
516
|
-
*
|
|
544
|
+
* ctx.total_outcome
|
|
517
545
|
)
|
|
518
|
-
baseline_scaled =
|
|
519
|
-
|
|
546
|
+
baseline_scaled = ctx.non_media_transformer.forward( # pytype: disable=attribute-error
|
|
547
|
+
self._model_equations.compute_non_media_treatments_baseline()
|
|
520
548
|
)
|
|
521
549
|
linear_predictor_counterfactual_difference = (
|
|
522
|
-
|
|
550
|
+
ctx.non_media_treatments_normalized - baseline_scaled
|
|
523
551
|
)
|
|
524
|
-
gamma_n_value =
|
|
552
|
+
gamma_n_value = self._model_equations.calculate_beta_x(
|
|
525
553
|
is_non_media=True,
|
|
526
554
|
incremental_outcome_x=incremental_outcome_n,
|
|
527
555
|
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
@@ -541,13 +569,23 @@ class PriorDistributionSampler:
|
|
|
541
569
|
).sample(seed=rng_handler.get_next_seed())
|
|
542
570
|
return non_media_treatments_vars
|
|
543
571
|
|
|
544
|
-
def
|
|
572
|
+
def __call__(
|
|
545
573
|
self,
|
|
546
574
|
n_draws: int,
|
|
547
575
|
seed: int | None = None,
|
|
548
576
|
) -> Mapping[str, backend.Tensor]:
|
|
549
|
-
"""
|
|
550
|
-
|
|
577
|
+
"""Draws samples from prior distributions.
|
|
578
|
+
|
|
579
|
+
Args:
|
|
580
|
+
n_draws: Number of samples drawn from the prior distribution.
|
|
581
|
+
seed: Used to set the seed for reproducible results. For more information,
|
|
582
|
+
see [PRNGS and seeds]
|
|
583
|
+
(https://github.com/tensorflow/probability/blob/main/PRNGS.md).
|
|
584
|
+
|
|
585
|
+
Returns:
|
|
586
|
+
A mapping of prior parameter names to tensors of the samples.
|
|
587
|
+
"""
|
|
588
|
+
ctx = self._model_context
|
|
551
589
|
|
|
552
590
|
# For stateful sampling, the random seed must be set to ensure that any
|
|
553
591
|
# random numbers that are generated are deterministic.
|
|
@@ -556,7 +594,7 @@ class PriorDistributionSampler:
|
|
|
556
594
|
|
|
557
595
|
rng_handler = backend.RNGHandler(seed)
|
|
558
596
|
|
|
559
|
-
prior =
|
|
597
|
+
prior = ctx.prior_broadcast
|
|
560
598
|
# `sample_shape` is prepended to the shape of each BatchBroadcast in `prior`
|
|
561
599
|
# when it is sampled.
|
|
562
600
|
sample_shape = [1, n_draws]
|
|
@@ -574,7 +612,7 @@ class PriorDistributionSampler:
|
|
|
574
612
|
constants.TAU_G: (
|
|
575
613
|
_get_tau_g(
|
|
576
614
|
tau_g_excl_baseline=tau_g_excl_baseline,
|
|
577
|
-
baseline_geo_idx=
|
|
615
|
+
baseline_geo_idx=ctx.baseline_geo_idx,
|
|
578
616
|
).sample(seed=rng_handler.get_next_seed())
|
|
579
617
|
),
|
|
580
618
|
}
|
|
@@ -583,14 +621,14 @@ class PriorDistributionSampler:
|
|
|
583
621
|
backend.einsum(
|
|
584
622
|
"...k,kt->...t",
|
|
585
623
|
base_vars[constants.KNOT_VALUES],
|
|
586
|
-
backend.to_tensor(
|
|
624
|
+
backend.to_tensor(ctx.knot_info.weights),
|
|
587
625
|
),
|
|
588
626
|
name=constants.MU_T,
|
|
589
627
|
).sample(seed=rng_handler.get_next_seed())
|
|
590
628
|
|
|
591
629
|
# Omit gamma_c, xi_c, and gamma_gc parameters from sampled distributions if
|
|
592
630
|
# there are no control variables in the model.
|
|
593
|
-
if
|
|
631
|
+
if ctx.n_controls:
|
|
594
632
|
base_vars |= {
|
|
595
633
|
constants.GAMMA_C: prior.gamma_c.sample(
|
|
596
634
|
sample_shape=sample_shape, seed=rng_handler.get_next_seed()
|
|
@@ -602,7 +640,7 @@ class PriorDistributionSampler:
|
|
|
602
640
|
|
|
603
641
|
gamma_gc_dev = backend.tfd.Sample(
|
|
604
642
|
backend.tfd.Normal(0, 1),
|
|
605
|
-
[
|
|
643
|
+
[ctx.n_geos, ctx.n_controls],
|
|
606
644
|
name=constants.GAMMA_GC_DEV,
|
|
607
645
|
).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
|
|
608
646
|
base_vars[constants.GAMMA_GC] = backend.tfd.Deterministic(
|
|
@@ -613,27 +651,27 @@ class PriorDistributionSampler:
|
|
|
613
651
|
|
|
614
652
|
media_vars = (
|
|
615
653
|
self._sample_media_priors(n_draws, rng_handler)
|
|
616
|
-
if
|
|
654
|
+
if ctx.media_tensors.media is not None
|
|
617
655
|
else {}
|
|
618
656
|
)
|
|
619
657
|
rf_vars = (
|
|
620
658
|
self._sample_rf_priors(n_draws, rng_handler)
|
|
621
|
-
if
|
|
659
|
+
if ctx.rf_tensors.reach is not None
|
|
622
660
|
else {}
|
|
623
661
|
)
|
|
624
662
|
organic_media_vars = (
|
|
625
663
|
self._sample_organic_media_priors(n_draws, rng_handler)
|
|
626
|
-
if
|
|
664
|
+
if ctx.organic_media_tensors.organic_media is not None
|
|
627
665
|
else {}
|
|
628
666
|
)
|
|
629
667
|
organic_rf_vars = (
|
|
630
668
|
self._sample_organic_rf_priors(n_draws, rng_handler)
|
|
631
|
-
if
|
|
669
|
+
if ctx.organic_rf_tensors.organic_reach is not None
|
|
632
670
|
else {}
|
|
633
671
|
)
|
|
634
672
|
non_media_treatments_vars = (
|
|
635
673
|
self._sample_non_media_treatments_priors(n_draws, rng_handler)
|
|
636
|
-
if
|
|
674
|
+
if ctx.non_media_treatments_normalized is not None
|
|
637
675
|
else {}
|
|
638
676
|
)
|
|
639
677
|
|
|
@@ -645,21 +683,3 @@ class PriorDistributionSampler:
|
|
|
645
683
|
| organic_rf_vars
|
|
646
684
|
| non_media_treatments_vars
|
|
647
685
|
)
|
|
648
|
-
|
|
649
|
-
def __call__(self, n_draws: int, seed: int | None = None) -> None:
|
|
650
|
-
"""Draws samples from prior distributions.
|
|
651
|
-
|
|
652
|
-
Args:
|
|
653
|
-
n_draws: Number of samples drawn from the prior distribution.
|
|
654
|
-
seed: Used to set the seed for reproducible results. For more information,
|
|
655
|
-
see [PRNGS and seeds]
|
|
656
|
-
(https://github.com/tensorflow/probability/blob/main/PRNGS.md).
|
|
657
|
-
"""
|
|
658
|
-
prior_draws = self._sample_prior(n_draws=n_draws, seed=seed)
|
|
659
|
-
# Create Arviz InferenceData for prior draws.
|
|
660
|
-
prior_coords = self._meridian.create_inference_data_coords(1, n_draws)
|
|
661
|
-
prior_dims = self._meridian.create_inference_data_dims()
|
|
662
|
-
prior_inference_data = az.convert_to_inference_data(
|
|
663
|
-
prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR
|
|
664
|
-
)
|
|
665
|
-
self._meridian.inference_data.extend(prior_inference_data, join="right")
|
meridian/model/spec.py
CHANGED
|
@@ -14,11 +14,10 @@
|
|
|
14
14
|
|
|
15
15
|
"""Defines model specification parameters for Meridian."""
|
|
16
16
|
|
|
17
|
-
from collections.abc import Mapping
|
|
17
|
+
from collections.abc import Collection, Mapping
|
|
18
18
|
import dataclasses
|
|
19
19
|
from typing import Sequence
|
|
20
20
|
import warnings
|
|
21
|
-
|
|
22
21
|
from meridian import constants
|
|
23
22
|
from meridian.model import prior_distribution
|
|
24
23
|
import numpy as np
|
|
@@ -166,17 +165,17 @@ class ModelSpec:
|
|
|
166
165
|
given non_media treatments channel). If `None`, the minimum value is used
|
|
167
166
|
as baseline for each non-media treatments channel. This attribute is used
|
|
168
167
|
as the default value for the corresponding argument to `Analyzer` methods.
|
|
169
|
-
knots: An optional integer or
|
|
170
|
-
estimate time effects. When `knots` is a
|
|
171
|
-
locations are provided by that list. Zero corresponds to a knot
|
|
172
|
-
first time period, one corresponds to a knot at the second time
|
|
173
|
-
..., and `(n_times - 1)` corresponds to a knot at the last time
|
|
174
|
-
Typically, we recommend including knots at `0` and `(n_times -
|
|
175
|
-
this is not required. When `knots` is an integer, then there are
|
|
176
|
-
with locations equally spaced across the time periods, (including
|
|
177
|
-
zero and `(n_times - 1)`. When `knots` is` 1`, there is a single
|
|
178
|
-
regression coefficient used for all time periods. If `knots` is set
|
|
179
|
-
`None`, then the numbers of knots used is equal to the number of time
|
|
168
|
+
knots: An optional integer or collection of integers indicating the knots
|
|
169
|
+
used to estimate time effects. When `knots` is a collection of integers,
|
|
170
|
+
the knot locations are provided by that collection. Zero corresponds to a knot
|
|
171
|
+
at the first time period, one corresponds to a knot at the second time
|
|
172
|
+
period, ..., and `(n_times - 1)` corresponds to a knot at the last time
|
|
173
|
+
period. Typically, we recommend including knots at `0` and `(n_times -
|
|
174
|
+
1)`, but this is not required. When `knots` is an integer, then there are
|
|
175
|
+
knots with locations equally spaced across the time periods (including
|
|
176
|
+
knots at zero and `(n_times - 1)`). When `knots` is `1`, there is a single
|
|
177
|
+
common regression coefficient used for all time periods. If `knots` is set
|
|
178
|
+
to `None`, then the number of knots used is equal to the number of time
|
|
180
179
|
periods in the case of a geo model. This is equivalent to each time period
|
|
181
180
|
having its own regression coefficient. If `knots` is set to `None` in the
|
|
182
181
|
case of a national model, then the number of knots used is `1`. Default:
|
|
@@ -235,7 +234,7 @@ class ModelSpec:
|
|
|
235
234
|
constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION
|
|
236
235
|
)
|
|
237
236
|
non_media_baseline_values: Sequence[float | str] | None = None
|
|
238
|
-
knots: int |
|
|
237
|
+
knots: int | Collection[int] | None = None
|
|
239
238
|
baseline_geo: int | str | None = None
|
|
240
239
|
holdout_id: np.ndarray | None = None
|
|
241
240
|
control_population_scaling_id: np.ndarray | None = None
|
|
@@ -321,6 +320,12 @@ class ModelSpec:
|
|
|
321
320
|
prior_type_name="rf_prior_type",
|
|
322
321
|
)
|
|
323
322
|
|
|
323
|
+
if isinstance(self.knots, Collection):
|
|
324
|
+
knots_list = list(self.knots)
|
|
325
|
+
if not all(isinstance(x, (int, np.integer)) for x in knots_list):
|
|
326
|
+
raise ValueError("`knots` must be a sequence of integers.")
|
|
327
|
+
object.__setattr__(self, "knots", [int(x) for x in knots_list])
|
|
328
|
+
|
|
324
329
|
# Validate knots.
|
|
325
330
|
if isinstance(self.knots, list) and not self.knots:
|
|
326
331
|
raise ValueError("The `knots` parameter cannot be an empty list.")
|
|
@@ -330,6 +335,10 @@ class ModelSpec:
|
|
|
330
335
|
raise ValueError(
|
|
331
336
|
"The `knots` parameter cannot be set when `enable_aks` is True."
|
|
332
337
|
)
|
|
338
|
+
if not (self.knots is None or isinstance(self.knots, (int, list))):
|
|
339
|
+
raise ValueError(
|
|
340
|
+
f"Unsupported type for `knots` parameter: {type(self.knots)}."
|
|
341
|
+
)
|
|
333
342
|
|
|
334
343
|
@property
|
|
335
344
|
def effective_media_prior_type(self) -> str:
|
|
@@ -18,13 +18,15 @@ limitations under the License.
|
|
|
18
18
|
<card-title>
|
|
19
19
|
{{ title }}
|
|
20
20
|
</card-title>
|
|
21
|
-
|
|
22
|
-
<card-insights
|
|
23
|
-
<
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
21
|
+
{% if insights %}
|
|
22
|
+
<card-insights>
|
|
23
|
+
<card-insights-icon>
|
|
24
|
+
<img src="https://www.gstatic.com/images/icons/material/system/svg/insights_24px.svg"
|
|
25
|
+
alt="insights icon" />
|
|
26
|
+
</card-insights-icon>
|
|
27
|
+
{{ insights }} {# `insights` is a pre-rendered HTML snippet. #}
|
|
28
|
+
</card-insights>
|
|
29
|
+
{% endif %}
|
|
28
30
|
{% if stats %}
|
|
29
31
|
<stats-section>
|
|
30
32
|
{% for item in stats %} {# Each `item` is a pre-rendered HTML snippet. #}
|
|
@@ -15,6 +15,7 @@ limitations under the License.
|
|
|
15
15
|
#}
|
|
16
16
|
|
|
17
17
|
<chart>
|
|
18
|
+
{% include "findings.html.jinja" %}
|
|
18
19
|
<chart-embed id="{{ id }}"></chart-embed>
|
|
19
20
|
{% if description %}
|
|
20
21
|
<chart-description>
|
|
@@ -25,13 +26,7 @@ limitations under the License.
|
|
|
25
26
|
|
|
26
27
|
<script type="text/javascript">
|
|
27
28
|
(() => {
|
|
28
|
-
const opt = {
|
|
29
|
-
mode: 'vega-lite',
|
|
30
|
-
width: 'container',
|
|
31
|
-
autosize: { type: 'fit', contains: 'padding' }
|
|
32
|
-
};
|
|
33
29
|
const spec = JSON.parse({{ chart_json|tojson }});
|
|
34
|
-
const chartDiv = document.getElementById('{{ id }}');
|
|
35
30
|
vegaEmbed('#{{ id }}', spec).catch(console.error);
|
|
36
31
|
})();
|
|
37
32
|
</script>
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
{#
|
|
2
|
+
Copyright 2026 Google LLC
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
#}
|
|
16
|
+
|
|
17
|
+
<finding class="{{ finding_class }}">
|
|
18
|
+
{{ text }}
|
|
19
|
+
</finding>
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
{#
|
|
2
|
+
Copyright 2025 Google LLC
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
#}
|
|
16
|
+
|
|
17
|
+
{% macro render_findings(finding_type, findings, icon_alt, icon_url) %}
|
|
18
|
+
{% if findings %}
|
|
19
|
+
{% for finding in findings %}
|
|
20
|
+
<{{ finding_type }}s>
|
|
21
|
+
<{{ finding_type }}s-icon>
|
|
22
|
+
<img src="{{ icon_url }}"
|
|
23
|
+
alt="{{ icon_alt }}" />
|
|
24
|
+
</{{ finding_type }}s-icon>
|
|
25
|
+
<p class="{{ finding_type }}-text">{{ finding }}</p>
|
|
26
|
+
</{{ finding_type }}s>
|
|
27
|
+
{% endfor %}
|
|
28
|
+
{% endif %}
|
|
29
|
+
{% endmacro %}
|
|
30
|
+
|
|
31
|
+
{{ render_findings('error', errors, 'error icon', 'https://www.gstatic.com/images/icons/material/system/svg/error_outline_24px.svg') }}
|
|
32
|
+
{{ render_findings('warning', warnings, 'warning icon', 'https://www.gstatic.com/images/icons/material/system/svg/warning_amber_24px.svg') }}
|
|
33
|
+
{{ render_findings('info', infos, 'info icon', 'https://www.gstatic.com/images/icons/material/system/svg/info_outline_24px.svg') }}
|