google-meridian 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,633 @@
1
+ # Copyright 2024 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module for sampling prior distributions in a Meridian model."""
16
+
17
+ from collections.abc import Mapping
18
+ from typing import TYPE_CHECKING
19
+
20
+ import arviz as az
21
+ from meridian import constants
22
+ import tensorflow as tf
23
+ import tensorflow_probability as tfp
24
+
25
+ if TYPE_CHECKING:
26
+ from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
27
+
28
+
29
# Public API of this module: a wildcard import exports only the sampler class.
__all__ = [
    "PriorDistributionSampler",
]
32
+
33
+
34
def _get_tau_g(
    tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int
) -> tfp.distributions.Distribution:
  """Builds the `tau_g` distribution from `tau_g_excl_baseline`.

  A slice of zeros is inserted at position `baseline_geo_idx` along the last
  axis of `tau_g_excl_baseline`. The baseline geo's value is therefore fixed
  to zero, and every other geo keeps its value from `tau_g_excl_baseline`.

  Args:
    tau_g_excl_baseline: A tensor of shape `[..., n_geos - 1]` for the
      user-defined dimensions of the `tau_g` parameter distribution.
    baseline_geo_idx: The index of the baseline geo to be set to zero.

  Returns:
    A `tfp.distributions.Deterministic` distribution named `tau_g`. Its value
    has shape `[..., n_geos]`, with zero at position `baseline_geo_idx` and
    values matching `tau_g_excl_baseline` elsewhere.
  """
  # Build the zeros' shape from the runtime shape rather than the static one,
  # so that partially unknown leading dimensions are also supported. The
  # inserted slice always has size 1 along the geo (last) axis. This also
  # covers rank-1 input, where the leading shape is empty.
  zeros_shape = tf.concat([tf.shape(tau_g_excl_baseline)[:-1], [1]], axis=0)
  baseline_zeros = tf.zeros(zeros_shape, dtype=tau_g_excl_baseline.dtype)
  tau_g = tf.concat(
      [
          tau_g_excl_baseline[..., :baseline_geo_idx],
          baseline_zeros,
          tau_g_excl_baseline[..., baseline_geo_idx:],
      ],
      axis=-1,
  )
  return tfp.distributions.Deterministic(tau_g, name="tau_g")
63
+
64
+
65
class PriorDistributionSampler:
  """A callable that samples from a model spec's prior distributions.

  Calling an instance draws `n_draws` samples of the model parameters from the
  broadcast priors of the wrapped Meridian model. The draws are returned as the
  `prior` group of an ArviZ `InferenceData` object.

  Parameter groups whose input data is absent are skipped. For example, no
  reach & frequency parameters are drawn when the model has no reach data.
  """

  def __init__(self, meridian: "model.Meridian"):
    """Initializes the sampler.

    Args:
      meridian: The Meridian model whose broadcast priors, input tensors and
        adstock/Hill transformations are used for sampling.
    """
    self._meridian = meridian

  @staticmethod
  def _make_sample_kwargs(
      n_draws: int, seed: int | None
  ) -> dict[str, object]:
    """Returns the `sample()` keyword arguments shared by all prior draws.

    The sample shape is `[1, n_draws]`, so every sampled tensor has leading
    dimensions `[1, n_draws]`.
    """
    return {constants.SAMPLE_SHAPE: [1, n_draws], constants.SEED: seed}

  @staticmethod
  def _sample_standard_normal(
      shape: list[int], name: str, sample_kwargs: Mapping[str, object]
  ) -> tf.Tensor:
    """Samples i.i.d. `N(0, 1)` deviations with trailing dimensions `shape`."""
    return tfp.distributions.Sample(
        tfp.distributions.Normal(0, 1), shape, name=name
    ).sample(**sample_kwargs)

  @staticmethod
  def _deterministic_sample(value: tf.Tensor, name: str) -> tf.Tensor:
    """Wraps `value` in a named `Deterministic` distribution and samples it."""
    return tfp.distributions.Deterministic(value, name=name).sample()

  @staticmethod
  def _geo_level(
      mean: tf.Tensor, scale: tf.Tensor, dev: tf.Tensor
  ) -> tf.Tensor:
    """Returns `mean + scale * dev`, broadcasting channel params over geos.

    A geo axis is inserted before the trailing channel axis of `mean` and
    `scale`. They then broadcast against `dev`, whose last two axes are
    `[n_geos, n_channels]`.
    """
    return mean[..., tf.newaxis, :] + scale[..., tf.newaxis, :] * dev

  def _geo_media_effect(
      self, mean: tf.Tensor, scale: tf.Tensor, dev: tf.Tensor
  ) -> tf.Tensor:
    """Returns geo-level media effects under the model's effect distribution.

    With normal media effects this is `mean + scale * dev`. Otherwise (the
    log-normal case) it is `exp(mean + scale * dev)`.
    """
    combined = self._geo_level(mean, scale, dev)
    if self._meridian.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
      return combined
    return tf.math.exp(combined)

  def _get_revenue_per_kpi(self) -> tf.Tensor:
    """Returns the model's `revenue_per_kpi`, or ones if the input has none.

    The fallback is a float32 tensor of ones with shape `[n_geos, n_times]`.
    Multiplying by it therefore leaves contributions unscaled.
    """
    mmm = self._meridian
    revenue_per_kpi = mmm.revenue_per_kpi
    if mmm.input_data.revenue_per_kpi is None:
      revenue_per_kpi = tf.ones([mmm.n_geos, mmm.n_times], dtype=tf.float32)
    return revenue_per_kpi

  def _geo_channel_contribution(self, transformed_delta: tf.Tensor) -> tf.Tensor:
    """Converts differences in transformed media into per-geo contributions.

    Each geo/time cell of `transformed_delta` is weighted by the geo's
    population, the KPI transformer's `population_scaled_stdev` and the revenue
    per KPI. The weighted values are then summed over time.

    Args:
      transformed_delta: A tensor of shape `[..., n_geos, n_times, n_channels]`.

    Returns:
      A tensor of shape `[..., n_geos, n_channels]`.
    """
    mmm = self._meridian
    revenue_per_kpi = self._get_revenue_per_kpi()
    return tf.einsum(
        "...gtm,g,,gt->...gm",
        transformed_delta,
        mmm.population,
        mmm.kpi_transformer.population_scaled_stdev,
        revenue_per_kpi,
    )

  def _solve_channel_beta(
      self,
      inc_revenue: tf.Tensor,
      eta: tf.Tensor,
      beta_g_dev: tf.Tensor,
      contrib_g: tf.Tensor,
  ) -> tf.Tensor:
    """Solves for the channel-level `beta` that hits a target revenue.

    Finds `beta` such that the geo sum of `beta_g * contrib_g` equals
    `inc_revenue`. Here `beta_g = beta + eta * beta_g_dev` for normal media
    effects, and `beta_g = exp(beta + eta * beta_g_dev)` otherwise.

    Args:
      inc_revenue: Target incremental revenue, shaped `[..., n_channels]`.
      eta: Geo-level scale parameter, shaped `[..., n_channels]`.
      beta_g_dev: Standard-normal geo deviations, `[..., n_geos, n_channels]`.
      contrib_g: Per-geo contributions, `[..., n_geos, n_channels]`.

    Returns:
      The channel-level `beta`, shaped `[..., n_channels]`.
    """
    if self._meridian.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
      contrib = tf.einsum("...gm->...m", contrib_g)
      random_effect = tf.einsum(
          "...m,...gm,...gm->...m", eta, beta_g_dev, contrib_g
      )
      return (inc_revenue - random_effect) / contrib
    # For log_normal, beta and eta are not mean & std. The parameterization is
    # beta_g ~ exp(beta + eta * N(0, 1)), so exp(beta) factors out of the sum.
    random_effect = tf.einsum(
        "...gm,...gm->...m",
        tf.math.exp(beta_g_dev * eta[..., tf.newaxis, :]),
        contrib_g,
    )
    return tf.math.log(inc_revenue) - tf.math.log(random_effect)

  def get_roi_prior_beta_m_value(
      self,
      alpha_m: tf.Tensor,
      beta_gm_dev: tf.Tensor,
      ec_m: tf.Tensor,
      eta_m: tf.Tensor,
      roi_or_mroi_m: tf.Tensor,
      slope_m: tf.Tensor,
      media_transformed: tf.Tensor,
  ) -> tf.Tensor:
    """Returns a tensor to be used in `beta_m`.

    `beta_m` is derived from a sampled ROI or mROI rather than drawn directly.
    It is chosen so that the incremental revenue implied by the geo-level
    effects (built from `beta_m`, `eta_m` and `beta_gm_dev`) equals
    `roi_or_mroi_m` times the incremental spend.

    Args:
      alpha_m: Sampled `alpha` parameter passed to `adstock_hill_media`.
      beta_gm_dev: Sampled standard-normal geo deviations of `beta_gm`.
      ec_m: Sampled `ec` parameter passed to `adstock_hill_media`.
      eta_m: Sampled geo-level scale of `beta_gm`.
      roi_or_mroi_m: Sampled ROI or mROI, depending on the paid media prior
        type.
      slope_m: Sampled `slope` parameter passed to `adstock_hill_media`.
      media_transformed: Adstock/Hill-transformed scaled media, shaped
        `[..., n_geos, n_times, n_media_channels]`.

    Returns:
      The `beta_m` value, shaped `[..., n_media_channels]`.
    """
    mmm = self._meridian

    # The `roi_or_mroi_m` parameter represents either ROI or mROI. For reach &
    # frequency channels, marginal ROI priors are defined as "mROI by reach",
    # which is equivalent to ROI.
    media_spend = mmm.media_tensors.media_spend
    media_spend_counterfactual = mmm.media_tensors.media_spend_counterfactual
    media_counterfactual_scaled = mmm.media_tensors.media_counterfactual_scaled
    # If we got here, then we should already have media tensors derived from
    # non-None InputData.media data.
    assert media_spend is not None
    assert media_spend_counterfactual is not None
    assert media_counterfactual_scaled is not None

    # Use absolute value here because this difference will be negative for
    # marginal ROI priors.
    inc_revenue_m = roi_or_mroi_m * tf.reduce_sum(
        tf.abs(media_spend - media_spend_counterfactual),
        range(media_spend.ndim - 1),
    )

    skip_counterfactual_transform = (
        mmm.model_spec.roi_calibration_period is None
        and mmm.model_spec.paid_media_prior_type
        == constants.PAID_MEDIA_PRIOR_TYPE_ROI
    )
    if skip_counterfactual_transform:
      # We can skip the adstock/hill computation step in this case.
      media_counterfactual_transformed = tf.zeros_like(media_transformed)
    else:
      media_counterfactual_transformed = mmm.adstock_hill_media(
          media=media_counterfactual_scaled,
          alpha=alpha_m,
          ec=ec_m,
          slope=slope_m,
      )

    # Use absolute value here because this difference will be negative for
    # marginal ROI priors.
    media_contrib_gm = self._geo_channel_contribution(
        tf.abs(media_transformed - media_counterfactual_transformed)
    )
    return self._solve_channel_beta(
        inc_revenue_m, eta_m, beta_gm_dev, media_contrib_gm
    )

  def get_roi_prior_beta_rf_value(
      self,
      alpha_rf: tf.Tensor,
      beta_grf_dev: tf.Tensor,
      ec_rf: tf.Tensor,
      eta_rf: tf.Tensor,
      roi_or_mroi_rf: tf.Tensor,
      slope_rf: tf.Tensor,
      rf_transformed: tf.Tensor,
  ) -> tf.Tensor:
    """Returns a tensor to be used in `beta_rf`.

    This is the reach & frequency counterpart of
    `get_roi_prior_beta_m_value`. `beta_rf` is chosen so that the incremental
    revenue implied by the geo-level effects equals `roi_or_mroi_rf` times
    the incremental RF spend.

    Args:
      alpha_rf: Sampled `alpha` parameter passed to `adstock_hill_rf`.
      beta_grf_dev: Sampled standard-normal geo deviations of `beta_grf`.
      ec_rf: Sampled `ec` parameter passed to `adstock_hill_rf`.
      eta_rf: Sampled geo-level scale of `beta_grf`.
      roi_or_mroi_rf: Sampled ROI or mROI, depending on the paid media prior
        type.
      slope_rf: Sampled `slope` parameter passed to `adstock_hill_rf`.
      rf_transformed: Adstock/Hill-transformed reach & frequency, shaped
        `[..., n_geos, n_times, n_rf_channels]`.

    Returns:
      The `beta_rf` value, shaped `[..., n_rf_channels]`.
    """
    mmm = self._meridian

    rf_spend = mmm.rf_tensors.rf_spend
    rf_spend_counterfactual = mmm.rf_tensors.rf_spend_counterfactual
    reach_counterfactual_scaled = mmm.rf_tensors.reach_counterfactual_scaled
    frequency = mmm.rf_tensors.frequency
    # If we got here, then we should already have RF media tensors derived from
    # non-None InputData.reach data.
    assert rf_spend is not None
    assert rf_spend_counterfactual is not None
    assert reach_counterfactual_scaled is not None
    assert frequency is not None

    # Unlike the paid media case, no absolute value is taken here. Marginal
    # ROI priors for reach & frequency channels are defined as "mROI by
    # reach", which is equivalent to ROI.
    inc_revenue_rf = roi_or_mroi_rf * tf.reduce_sum(
        rf_spend - rf_spend_counterfactual,
        range(rf_spend.ndim - 1),
    )
    if mmm.model_spec.rf_roi_calibration_period is None:
      rf_counterfactual_transformed = tf.zeros_like(rf_transformed)
    else:
      rf_counterfactual_transformed = mmm.adstock_hill_rf(
          reach=reach_counterfactual_scaled,
          frequency=frequency,
          alpha=alpha_rf,
          ec=ec_rf,
          slope=slope_rf,
      )

    media_contrib_grf = self._geo_channel_contribution(
        rf_transformed - rf_counterfactual_transformed
    )
    return self._solve_channel_beta(
        inc_revenue_rf, eta_rf, beta_grf_dev, media_contrib_grf
    )

  def _sample_media_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws samples from the prior distributions of the media variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of media parameter names to tensors of shape `[1, n_draws,
      n_geos, n_media_channels]` or `[1, n_draws, n_media_channels]` containing
      the samples.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._make_sample_kwargs(n_draws, seed)

    # The order of the `sample` calls is kept fixed. Under stateful sampling
    # (see `_sample_prior`), reordering them would presumably change which
    # random numbers each parameter receives.
    media_vars = {
        constants.ALPHA_M: prior.alpha_m.sample(**sample_kwargs),
        constants.EC_M: prior.ec_m.sample(**sample_kwargs),
        constants.ETA_M: prior.eta_m.sample(**sample_kwargs),
        constants.SLOPE_M: prior.slope_m.sample(**sample_kwargs),
    }
    beta_gm_dev = self._sample_standard_normal(
        [mmm.n_geos, mmm.n_media_channels],
        constants.BETA_GM_DEV,
        sample_kwargs,
    )
    media_transformed = mmm.adstock_hill_media(
        media=mmm.media_tensors.media_scaled,
        alpha=media_vars[constants.ALPHA_M],
        ec=media_vars[constants.EC_M],
        slope=media_vars[constants.SLOPE_M],
    )

    # For ROI/mROI prior types, beta_m is derived from the ROI/mROI draw.
    # Otherwise beta_m has its own prior.
    prior_type = mmm.model_spec.paid_media_prior_type
    if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
      calibration = (constants.ROI_M, prior.roi_m)
    elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI:
      calibration = (constants.MROI_M, prior.mroi_m)
    else:
      calibration = None

    if calibration is None:
      media_vars[constants.BETA_M] = prior.beta_m.sample(**sample_kwargs)
    else:
      roi_key, roi_prior = calibration
      roi_or_mroi_m = roi_prior.sample(**sample_kwargs)
      beta_m_value = self.get_roi_prior_beta_m_value(
          beta_gm_dev=beta_gm_dev,
          media_transformed=media_transformed,
          roi_or_mroi_m=roi_or_mroi_m,
          **media_vars,
      )
      media_vars[roi_key] = roi_or_mroi_m
      media_vars[constants.BETA_M] = self._deterministic_sample(
          beta_m_value, constants.BETA_M
      )

    media_vars[constants.BETA_GM] = self._deterministic_sample(
        self._geo_media_effect(
            media_vars[constants.BETA_M],
            media_vars[constants.ETA_M],
            beta_gm_dev,
        ),
        constants.BETA_GM,
    )
    return media_vars

  def _sample_rf_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws samples from the prior distributions of the RF variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of RF parameter names to tensors of shape `[1, n_draws,
      n_geos, n_rf_channels]` or `[1, n_draws, n_rf_channels]` containing the
      samples.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._make_sample_kwargs(n_draws, seed)

    # The order of the `sample` calls is kept fixed (see
    # `_sample_media_priors`).
    rf_vars = {
        constants.ALPHA_RF: prior.alpha_rf.sample(**sample_kwargs),
        constants.EC_RF: prior.ec_rf.sample(**sample_kwargs),
        constants.ETA_RF: prior.eta_rf.sample(**sample_kwargs),
        constants.SLOPE_RF: prior.slope_rf.sample(**sample_kwargs),
    }
    beta_grf_dev = self._sample_standard_normal(
        [mmm.n_geos, mmm.n_rf_channels],
        constants.BETA_GRF_DEV,
        sample_kwargs,
    )
    rf_transformed = mmm.adstock_hill_rf(
        reach=mmm.rf_tensors.reach_scaled,
        frequency=mmm.rf_tensors.frequency,
        alpha=rf_vars[constants.ALPHA_RF],
        ec=rf_vars[constants.EC_RF],
        slope=rf_vars[constants.SLOPE_RF],
    )

    # For ROI/mROI prior types, beta_rf is derived from the ROI/mROI draw.
    # Otherwise beta_rf has its own prior.
    prior_type = mmm.model_spec.paid_media_prior_type
    if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
      calibration = (constants.ROI_RF, prior.roi_rf)
    elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI:
      calibration = (constants.MROI_RF, prior.mroi_rf)
    else:
      calibration = None

    if calibration is None:
      rf_vars[constants.BETA_RF] = prior.beta_rf.sample(**sample_kwargs)
    else:
      roi_key, roi_prior = calibration
      roi_or_mroi_rf = roi_prior.sample(**sample_kwargs)
      beta_rf_value = self.get_roi_prior_beta_rf_value(
          beta_grf_dev=beta_grf_dev,
          rf_transformed=rf_transformed,
          roi_or_mroi_rf=roi_or_mroi_rf,
          **rf_vars,
      )
      rf_vars[roi_key] = roi_or_mroi_rf
      rf_vars[constants.BETA_RF] = self._deterministic_sample(
          beta_rf_value, constants.BETA_RF
      )

    rf_vars[constants.BETA_GRF] = self._deterministic_sample(
        self._geo_media_effect(
            rf_vars[constants.BETA_RF],
            rf_vars[constants.ETA_RF],
            beta_grf_dev,
        ),
        constants.BETA_GRF,
    )
    return rf_vars

  def _sample_organic_media_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws samples from the prior distributions of organic media variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of organic media parameter names to tensors of shape `[1,
      n_draws, n_geos, n_organic_media_channels]` or `[1, n_draws,
      n_organic_media_channels]` containing the samples.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._make_sample_kwargs(n_draws, seed)

    organic_media_vars = {
        constants.ALPHA_OM: prior.alpha_om.sample(**sample_kwargs),
        constants.EC_OM: prior.ec_om.sample(**sample_kwargs),
        constants.ETA_OM: prior.eta_om.sample(**sample_kwargs),
        constants.SLOPE_OM: prior.slope_om.sample(**sample_kwargs),
    }
    beta_gom_dev = self._sample_standard_normal(
        [mmm.n_geos, mmm.n_organic_media_channels],
        constants.BETA_GOM_DEV,
        sample_kwargs,
    )
    # Organic channels have no spend, so beta_om always has its own prior.
    organic_media_vars[constants.BETA_OM] = prior.beta_om.sample(
        **sample_kwargs
    )
    organic_media_vars[constants.BETA_GOM] = self._deterministic_sample(
        self._geo_media_effect(
            organic_media_vars[constants.BETA_OM],
            organic_media_vars[constants.ETA_OM],
            beta_gom_dev,
        ),
        constants.BETA_GOM,
    )
    return organic_media_vars

  def _sample_organic_rf_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws samples from the prior distributions of the organic RF variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of organic RF parameter names to tensors of shape `[1,
      n_draws, n_geos, n_organic_rf_channels]` or `[1, n_draws,
      n_organic_rf_channels]` containing the samples.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._make_sample_kwargs(n_draws, seed)

    organic_rf_vars = {
        constants.ALPHA_ORF: prior.alpha_orf.sample(**sample_kwargs),
        constants.EC_ORF: prior.ec_orf.sample(**sample_kwargs),
        constants.ETA_ORF: prior.eta_orf.sample(**sample_kwargs),
        constants.SLOPE_ORF: prior.slope_orf.sample(**sample_kwargs),
    }
    beta_gorf_dev = self._sample_standard_normal(
        [mmm.n_geos, mmm.n_organic_rf_channels],
        constants.BETA_GORF_DEV,
        sample_kwargs,
    )
    organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(**sample_kwargs)
    organic_rf_vars[constants.BETA_GORF] = self._deterministic_sample(
        self._geo_media_effect(
            organic_rf_vars[constants.BETA_ORF],
            organic_rf_vars[constants.ETA_ORF],
            beta_gorf_dev,
        ),
        constants.BETA_GORF,
    )
    return organic_rf_vars

  def _sample_non_media_treatments_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws from the prior distributions of the non-media treatment variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of non-media treatment parameter names to tensors of shape
      `[1, n_draws, n_geos, n_non_media_channels]` or `[1, n_draws,
      n_non_media_channels]` containing the samples.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._make_sample_kwargs(n_draws, seed)

    non_media_treatments_vars = {
        constants.GAMMA_N: prior.gamma_n.sample(**sample_kwargs),
        constants.XI_N: prior.xi_n.sample(**sample_kwargs),
    }
    gamma_gn_dev = self._sample_standard_normal(
        [mmm.n_geos, mmm.n_non_media_channels],
        constants.GAMMA_GN_DEV,
        sample_kwargs,
    )
    # Non-media treatment effects are always linear in the deviation (no exp).
    non_media_treatments_vars[constants.GAMMA_GN] = self._deterministic_sample(
        self._geo_level(
            non_media_treatments_vars[constants.GAMMA_N],
            non_media_treatments_vars[constants.XI_N],
            gamma_gn_dev,
        ),
        constants.GAMMA_GN,
    )
    return non_media_treatments_vars

  def _sample_base_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> dict[str, tf.Tensor]:
    """Draws the parameters present in every model.

    These are knots/trend, controls, noise and geo effects.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Seed forwarded to every `sample` call.

    Returns:
      A dict of base parameter names to sampled tensors with leading sample
      dimensions `[1, n_draws]`.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._make_sample_kwargs(n_draws, seed)

    # `tau_g_excl_baseline` is drawn first to keep the original draw order.
    tau_g_excl_baseline = prior.tau_g_excl_baseline.sample(**sample_kwargs)
    base_vars = {
        constants.KNOT_VALUES: prior.knot_values.sample(**sample_kwargs),
        constants.GAMMA_C: prior.gamma_c.sample(**sample_kwargs),
        constants.XI_C: prior.xi_c.sample(**sample_kwargs),
        constants.SIGMA: prior.sigma.sample(**sample_kwargs),
        constants.TAU_G: _get_tau_g(
            tau_g_excl_baseline=tau_g_excl_baseline,
            baseline_geo_idx=mmm.baseline_geo_idx,
        ).sample(),
    }
    # mu_t is the knot values projected onto time via the knot weights.
    base_vars[constants.MU_T] = self._deterministic_sample(
        tf.einsum(
            "...k,kt->...t",
            base_vars[constants.KNOT_VALUES],
            tf.convert_to_tensor(mmm.knot_info.weights),
        ),
        constants.MU_T,
    )

    gamma_gc_dev = self._sample_standard_normal(
        [mmm.n_geos, mmm.n_controls],
        constants.GAMMA_GC_DEV,
        sample_kwargs,
    )
    base_vars[constants.GAMMA_GC] = self._deterministic_sample(
        self._geo_level(
            base_vars[constants.GAMMA_C],
            base_vars[constants.XI_C],
            gamma_gc_dev,
        ),
        constants.GAMMA_GC,
    )
    return base_vars

  def _sample_prior(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Returns a mapping of prior parameters to tensors of the samples.

    Base parameters are always drawn. The media, RF, organic media, organic
    RF and non-media treatment groups are drawn, in that order, only when the
    corresponding input tensors are present.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: If not `None`, the global random state is reset before sampling
        and `seed` is forwarded to every `sample` call.

    Returns:
      A dict of all sampled parameters.
    """
    mmm = self._meridian

    # For stateful sampling, the random seed must be set to ensure that any
    # random numbers that are generated are deterministic. The global seed is
    # reset to the constant 1 (not to `seed`); `seed` itself is forwarded to
    # every `sample` call. Note that `tf.keras.utils.set_random_seed` is a
    # process-wide side effect (per its docs it also seeds Python's `random`
    # and NumPy).
    if seed is not None:
      tf.keras.utils.set_random_seed(1)

    prior_draws = self._sample_base_priors(n_draws, seed)
    if mmm.media_tensors.media is not None:
      prior_draws |= self._sample_media_priors(n_draws, seed)
    if mmm.rf_tensors.reach is not None:
      prior_draws |= self._sample_rf_priors(n_draws, seed)
    if mmm.organic_media_tensors.organic_media is not None:
      prior_draws |= self._sample_organic_media_priors(n_draws, seed)
    if mmm.organic_rf_tensors.organic_reach is not None:
      prior_draws |= self._sample_organic_rf_priors(n_draws, seed)
    if mmm.non_media_treatments_scaled is not None:
      prior_draws |= self._sample_non_media_treatments_priors(n_draws, seed)
    return prior_draws

  def __call__(self, n_draws: int, seed: int | None = None) -> az.InferenceData:
    """Draws samples from prior distributions.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      An Arviz `InferenceData` object containing prior samples only.
    """
    prior_draws = self._sample_prior(n_draws, seed=seed)
    # Create Arviz InferenceData for prior draws.
    prior_coords = self._meridian.create_inference_data_coords(1, n_draws)
    prior_dims = self._meridian.create_inference_data_dims()
    return az.convert_to_inference_data(
        prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR
    )
+ )