google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,10 +18,9 @@ from collections.abc import Mapping, Sequence
18
18
  from typing import TYPE_CHECKING
19
19
 
20
20
  import arviz as az
21
+ from meridian import backend
21
22
  from meridian import constants
22
23
  import numpy as np
23
- import tensorflow as tf
24
- import tensorflow_probability as tfp
25
24
 
26
25
  if TYPE_CHECKING:
27
26
  from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
@@ -43,8 +42,8 @@ class MCMCOOMError(Exception):
43
42
 
44
43
 
45
44
  def _get_tau_g(
46
- tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int
47
- ) -> tfp.distributions.Distribution:
45
+ tau_g_excl_baseline: backend.Tensor, baseline_geo_idx: int
46
+ ) -> backend.tfd.Distribution:
48
47
  """Computes `tau_g` from `tau_g_excl_baseline`.
49
48
 
50
49
  This function computes `tau_g` by inserting a column of zeros at the
@@ -62,21 +61,21 @@ def _get_tau_g(
62
61
  """
63
62
  rank = len(tau_g_excl_baseline.shape)
64
63
  shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1
65
- tau_g = tf.concat(
64
+ tau_g = backend.concatenate(
66
65
  [
67
66
  tau_g_excl_baseline[..., :baseline_geo_idx],
68
- tf.zeros(shape, dtype=tau_g_excl_baseline.dtype),
67
+ backend.zeros(shape, dtype=tau_g_excl_baseline.dtype),
69
68
  tau_g_excl_baseline[..., baseline_geo_idx:],
70
69
  ],
71
70
  axis=rank - 1,
72
71
  )
73
- return tfp.distributions.Deterministic(tau_g, name="tau_g")
72
+ return backend.tfd.Deterministic(tau_g, name="tau_g")
74
73
 
75
74
 
76
- @tf.function(autograph=False, jit_compile=True)
75
+ @backend.function(autograph=False, jit_compile=True)
77
76
  def _xla_windowed_adaptive_nuts(**kwargs):
78
77
  """XLA wrapper for windowed_adaptive_nuts."""
79
- return tfp.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
78
+ return backend.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
80
79
 
81
80
 
82
81
  class PosteriorMCMCSampler:
@@ -89,7 +88,7 @@ class PosteriorMCMCSampler:
89
88
  def model(self) -> "model.Meridian":
90
89
  return self._meridian
91
90
 
92
- def _get_joint_dist_unpinned(self) -> tfp.distributions.Distribution:
91
+ def _get_joint_dist_unpinned(self) -> backend.tfd.Distribution:
93
92
  """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
94
93
  mmm = self.model
95
94
  mmm.populate_cached_properties()
@@ -120,13 +119,13 @@ class PosteriorMCMCSampler:
120
119
  adstock_hill_rf_fn = mmm.adstock_hill_rf
121
120
  total_outcome = mmm.total_outcome
122
121
 
123
- @tfp.distributions.JointDistributionCoroutineAutoBatched
122
+ @backend.tfd.JointDistributionCoroutineAutoBatched
124
123
  def joint_dist_unpinned():
125
124
  # Sample directly from prior.
126
125
  knot_values = yield prior_broadcast.knot_values
127
126
  sigma = yield prior_broadcast.sigma
128
127
 
129
- tau_g_excl_baseline = yield tfp.distributions.Sample(
128
+ tau_g_excl_baseline = yield backend.tfd.Sample(
130
129
  prior_broadcast.tau_g_excl_baseline,
131
130
  name=constants.TAU_G_EXCL_BASELINE,
132
131
  )
@@ -134,27 +133,27 @@ class PosteriorMCMCSampler:
134
133
  tau_g_excl_baseline=tau_g_excl_baseline,
135
134
  baseline_geo_idx=baseline_geo_idx,
136
135
  )
137
- mu_t = yield tfp.distributions.Deterministic(
138
- tf.einsum(
136
+ mu_t = yield backend.tfd.Deterministic(
137
+ backend.einsum(
139
138
  "k,kt->t",
140
139
  knot_values,
141
- tf.convert_to_tensor(knot_info.weights),
140
+ backend.to_tensor(knot_info.weights),
142
141
  ),
143
142
  name=constants.MU_T,
144
143
  )
145
144
 
146
- tau_gt = tau_g[:, tf.newaxis] + mu_t
147
- combined_media_transformed = tf.zeros(
148
- shape=(n_geos, n_times, 0), dtype=tf.float32
145
+ tau_gt = tau_g[:, backend.newaxis] + mu_t
146
+ combined_media_transformed = backend.zeros(
147
+ shape=(n_geos, n_times, 0), dtype=backend.float32
149
148
  )
150
- combined_beta = tf.zeros(shape=(n_geos, 0), dtype=tf.float32)
149
+ combined_beta = backend.zeros(shape=(n_geos, 0), dtype=backend.float32)
151
150
  if media_tensors.media is not None:
152
151
  alpha_m = yield prior_broadcast.alpha_m
153
152
  ec_m = yield prior_broadcast.ec_m
154
153
  eta_m = yield prior_broadcast.eta_m
155
154
  slope_m = yield prior_broadcast.slope_m
156
- beta_gm_dev = yield tfp.distributions.Sample(
157
- tfp.distributions.Normal(0, 1),
155
+ beta_gm_dev = yield backend.tfd.Sample(
156
+ backend.tfd.Normal(0, 1),
158
157
  [n_geos, n_media_channels],
159
158
  name=constants.BETA_GM_DEV,
160
159
  )
@@ -163,6 +162,7 @@ class PosteriorMCMCSampler:
163
162
  alpha=alpha_m,
164
163
  ec=ec_m,
165
164
  slope=slope_m,
165
+ decay_functions=mmm.adstock_decay_spec.media,
166
166
  )
167
167
  prior_type = mmm.model_spec.effective_media_prior_type
168
168
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
@@ -194,7 +194,7 @@ class PosteriorMCMCSampler:
194
194
  eta_x=eta_m,
195
195
  beta_gx_dev=beta_gm_dev,
196
196
  )
197
- beta_m = yield tfp.distributions.Deterministic(
197
+ beta_m = yield backend.tfd.Deterministic(
198
198
  beta_m_value, name=constants.BETA_M
199
199
  )
200
200
 
@@ -202,23 +202,23 @@ class PosteriorMCMCSampler:
202
202
  beta_gm_value = (
203
203
  beta_eta_combined
204
204
  if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
205
- else tf.math.exp(beta_eta_combined)
205
+ else backend.exp(beta_eta_combined)
206
206
  )
207
- beta_gm = yield tfp.distributions.Deterministic(
207
+ beta_gm = yield backend.tfd.Deterministic(
208
208
  beta_gm_value, name=constants.BETA_GM
209
209
  )
210
- combined_media_transformed = tf.concat(
210
+ combined_media_transformed = backend.concatenate(
211
211
  [combined_media_transformed, media_transformed], axis=-1
212
212
  )
213
- combined_beta = tf.concat([combined_beta, beta_gm], axis=-1)
213
+ combined_beta = backend.concatenate([combined_beta, beta_gm], axis=-1)
214
214
 
215
215
  if rf_tensors.reach is not None:
216
216
  alpha_rf = yield prior_broadcast.alpha_rf
217
217
  ec_rf = yield prior_broadcast.ec_rf
218
218
  eta_rf = yield prior_broadcast.eta_rf
219
219
  slope_rf = yield prior_broadcast.slope_rf
220
- beta_grf_dev = yield tfp.distributions.Sample(
221
- tfp.distributions.Normal(0, 1),
220
+ beta_grf_dev = yield backend.tfd.Sample(
221
+ backend.tfd.Normal(0, 1),
222
222
  [n_geos, n_rf_channels],
223
223
  name=constants.BETA_GRF_DEV,
224
224
  )
@@ -228,6 +228,7 @@ class PosteriorMCMCSampler:
228
228
  alpha=alpha_rf,
229
229
  ec=ec_rf,
230
230
  slope=slope_rf,
231
+ decay_functions=mmm.adstock_decay_spec.rf,
231
232
  )
232
233
 
233
234
  prior_type = mmm.model_spec.effective_rf_prior_type
@@ -260,7 +261,7 @@ class PosteriorMCMCSampler:
260
261
  eta_x=eta_rf,
261
262
  beta_gx_dev=beta_grf_dev,
262
263
  )
263
- beta_rf = yield tfp.distributions.Deterministic(
264
+ beta_rf = yield backend.tfd.Deterministic(
264
265
  beta_rf_value, name=constants.BETA_RF
265
266
  )
266
267
 
@@ -268,23 +269,23 @@ class PosteriorMCMCSampler:
268
269
  beta_grf_value = (
269
270
  beta_eta_combined
270
271
  if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
271
- else tf.math.exp(beta_eta_combined)
272
+ else backend.exp(beta_eta_combined)
272
273
  )
273
- beta_grf = yield tfp.distributions.Deterministic(
274
+ beta_grf = yield backend.tfd.Deterministic(
274
275
  beta_grf_value, name=constants.BETA_GRF
275
276
  )
276
- combined_media_transformed = tf.concat(
277
+ combined_media_transformed = backend.concatenate(
277
278
  [combined_media_transformed, rf_transformed], axis=-1
278
279
  )
279
- combined_beta = tf.concat([combined_beta, beta_grf], axis=-1)
280
+ combined_beta = backend.concatenate([combined_beta, beta_grf], axis=-1)
280
281
 
281
282
  if organic_media_tensors.organic_media is not None:
282
283
  alpha_om = yield prior_broadcast.alpha_om
283
284
  ec_om = yield prior_broadcast.ec_om
284
285
  eta_om = yield prior_broadcast.eta_om
285
286
  slope_om = yield prior_broadcast.slope_om
286
- beta_gom_dev = yield tfp.distributions.Sample(
287
- tfp.distributions.Normal(0, 1),
287
+ beta_gom_dev = yield backend.tfd.Sample(
288
+ backend.tfd.Normal(0, 1),
288
289
  [n_geos, n_organic_media_channels],
289
290
  name=constants.BETA_GOM_DEV,
290
291
  )
@@ -293,6 +294,7 @@ class PosteriorMCMCSampler:
293
294
  alpha=alpha_om,
294
295
  ec=ec_om,
295
296
  slope=slope_om,
297
+ decay_functions=mmm.adstock_decay_spec.organic_media,
296
298
  )
297
299
  prior_type = mmm.model_spec.organic_media_prior_type
298
300
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
@@ -307,7 +309,7 @@ class PosteriorMCMCSampler:
307
309
  eta_x=eta_om,
308
310
  beta_gx_dev=beta_gom_dev,
309
311
  )
310
- beta_om = yield tfp.distributions.Deterministic(
312
+ beta_om = yield backend.tfd.Deterministic(
311
313
  beta_om_value, name=constants.BETA_OM
312
314
  )
313
315
  else:
@@ -317,23 +319,23 @@ class PosteriorMCMCSampler:
317
319
  beta_gom_value = (
318
320
  beta_eta_combined
319
321
  if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
320
- else tf.math.exp(beta_eta_combined)
322
+ else backend.exp(beta_eta_combined)
321
323
  )
322
- beta_gom = yield tfp.distributions.Deterministic(
324
+ beta_gom = yield backend.tfd.Deterministic(
323
325
  beta_gom_value, name=constants.BETA_GOM
324
326
  )
325
- combined_media_transformed = tf.concat(
327
+ combined_media_transformed = backend.concatenate(
326
328
  [combined_media_transformed, organic_media_transformed], axis=-1
327
329
  )
328
- combined_beta = tf.concat([combined_beta, beta_gom], axis=-1)
330
+ combined_beta = backend.concatenate([combined_beta, beta_gom], axis=-1)
329
331
 
330
332
  if organic_rf_tensors.organic_reach is not None:
331
333
  alpha_orf = yield prior_broadcast.alpha_orf
332
334
  ec_orf = yield prior_broadcast.ec_orf
333
335
  eta_orf = yield prior_broadcast.eta_orf
334
336
  slope_orf = yield prior_broadcast.slope_orf
335
- beta_gorf_dev = yield tfp.distributions.Sample(
336
- tfp.distributions.Normal(0, 1),
337
+ beta_gorf_dev = yield backend.tfd.Sample(
338
+ backend.tfd.Normal(0, 1),
337
339
  [n_geos, n_organic_rf_channels],
338
340
  name=constants.BETA_GORF_DEV,
339
341
  )
@@ -343,6 +345,7 @@ class PosteriorMCMCSampler:
343
345
  alpha=alpha_orf,
344
346
  ec=ec_orf,
345
347
  slope=slope_orf,
348
+ decay_functions=mmm.adstock_decay_spec.organic_rf,
346
349
  )
347
350
 
348
351
  prior_type = mmm.model_spec.organic_rf_prior_type
@@ -358,7 +361,7 @@ class PosteriorMCMCSampler:
358
361
  eta_x=eta_orf,
359
362
  beta_gx_dev=beta_gorf_dev,
360
363
  )
361
- beta_orf = yield tfp.distributions.Deterministic(
364
+ beta_orf = yield backend.tfd.Deterministic(
362
365
  beta_orf_value, name=constants.BETA_ORF
363
366
  )
364
367
  else:
@@ -368,18 +371,20 @@ class PosteriorMCMCSampler:
368
371
  beta_gorf_value = (
369
372
  beta_eta_combined
370
373
  if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
371
- else tf.math.exp(beta_eta_combined)
374
+ else backend.exp(beta_eta_combined)
372
375
  )
373
- beta_gorf = yield tfp.distributions.Deterministic(
376
+ beta_gorf = yield backend.tfd.Deterministic(
374
377
  beta_gorf_value, name=constants.BETA_GORF
375
378
  )
376
- combined_media_transformed = tf.concat(
379
+ combined_media_transformed = backend.concatenate(
377
380
  [combined_media_transformed, organic_rf_transformed], axis=-1
378
381
  )
379
- combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1)
382
+ combined_beta = backend.concatenate([combined_beta, beta_gorf], axis=-1)
380
383
 
381
- sigma_gt = tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos]))
382
- y_pred_combined_media = tau_gt + tf.einsum(
384
+ sigma_gt = backend.transpose(
385
+ backend.broadcast_to(sigma, [n_times, n_geos])
386
+ )
387
+ y_pred_combined_media = tau_gt + backend.einsum(
383
388
  "gtm,gm->gt", combined_media_transformed, combined_beta
384
389
  )
385
390
  # Omit gamma_c, xi_c, and gamma_gc from joint distribution output if
@@ -387,22 +392,22 @@ class PosteriorMCMCSampler:
387
392
  if n_controls:
388
393
  gamma_c = yield prior_broadcast.gamma_c
389
394
  xi_c = yield prior_broadcast.xi_c
390
- gamma_gc_dev = yield tfp.distributions.Sample(
391
- tfp.distributions.Normal(0, 1),
395
+ gamma_gc_dev = yield backend.tfd.Sample(
396
+ backend.tfd.Normal(0, 1),
392
397
  [n_geos, n_controls],
393
398
  name=constants.GAMMA_GC_DEV,
394
399
  )
395
- gamma_gc = yield tfp.distributions.Deterministic(
400
+ gamma_gc = yield backend.tfd.Deterministic(
396
401
  gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
397
402
  )
398
- y_pred_combined_media += tf.einsum(
403
+ y_pred_combined_media += backend.einsum(
399
404
  "gtc,gc->gt", controls_scaled, gamma_gc
400
405
  )
401
406
 
402
407
  if mmm.non_media_treatments is not None:
403
408
  xi_n = yield prior_broadcast.xi_n
404
- gamma_gn_dev = yield tfp.distributions.Sample(
405
- tfp.distributions.Normal(0, 1),
409
+ gamma_gn_dev = yield backend.tfd.Sample(
410
+ backend.tfd.Normal(0, 1),
406
411
  [n_geos, n_non_media_channels],
407
412
  name=constants.GAMMA_GN_DEV,
408
413
  )
@@ -425,15 +430,15 @@ class PosteriorMCMCSampler:
425
430
  eta_x=xi_n,
426
431
  beta_gx_dev=gamma_gn_dev,
427
432
  )
428
- gamma_n = yield tfp.distributions.Deterministic(
433
+ gamma_n = yield backend.tfd.Deterministic(
429
434
  gamma_n_value, name=constants.GAMMA_N
430
435
  )
431
436
  else:
432
437
  raise ValueError(f"Unsupported prior type: {prior_type}")
433
- gamma_gn = yield tfp.distributions.Deterministic(
438
+ gamma_gn = yield backend.tfd.Deterministic(
434
439
  gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
435
440
  )
436
- y_pred = y_pred_combined_media + tf.einsum(
441
+ y_pred = y_pred_combined_media + backend.einsum(
437
442
  "gtn,gn->gt", non_media_treatments_normalized, gamma_gn
438
443
  )
439
444
  else:
@@ -445,21 +450,19 @@ class PosteriorMCMCSampler:
445
450
  # deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the
446
451
  # sampled posterior parameter values.
447
452
  if holdout_id is not None:
448
- y_pred_holdout = tf.where(holdout_id, 0.0, y_pred)
449
- test_sd = tf.cast(1.0 / np.sqrt(2.0 * np.pi), tf.float32)
450
- sigma_gt_holdout = tf.where(holdout_id, test_sd, sigma_gt)
451
- yield tfp.distributions.Normal(
452
- y_pred_holdout, sigma_gt_holdout, name="y"
453
- )
453
+ y_pred_holdout = backend.where(holdout_id, 0.0, y_pred)
454
+ test_sd = backend.cast(1.0 / np.sqrt(2.0 * np.pi), backend.float32)
455
+ sigma_gt_holdout = backend.where(holdout_id, test_sd, sigma_gt)
456
+ yield backend.tfd.Normal(y_pred_holdout, sigma_gt_holdout, name="y")
454
457
  else:
455
- yield tfp.distributions.Normal(y_pred, sigma_gt, name="y")
458
+ yield backend.tfd.Normal(y_pred, sigma_gt, name="y")
456
459
 
457
460
  return joint_dist_unpinned
458
461
 
459
- def _get_joint_dist(self) -> tfp.distributions.Distribution:
462
+ def _get_joint_dist(self) -> backend.tfd.Distribution:
460
463
  mmm = self.model
461
464
  y = (
462
- tf.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
465
+ backend.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
463
466
  if mmm.holdout_id is not None
464
467
  else mmm.kpi_scaled
465
468
  )
@@ -471,7 +474,7 @@ class PosteriorMCMCSampler:
471
474
  n_adapt: int,
472
475
  n_burnin: int,
473
476
  n_keep: int,
474
- current_state: Mapping[str, tf.Tensor] | None = None,
477
+ current_state: Mapping[str, backend.Tensor] | None = None,
475
478
  init_step_size: int | None = None,
476
479
  dual_averaging_kwargs: Mapping[str, int] | None = None,
477
480
  max_tree_depth: int = 10,
@@ -528,7 +531,7 @@ class PosteriorMCMCSampler:
528
531
  be a positive integer. For more information, see `tf.while_loop`.
529
532
  seed: An `int32[2]` Tensor or a Python list or tuple of 2 `int`s, which
530
533
  will be treated as stateless seeds; or a Python `int` or `None`, which
531
- will be treated as stateful seeds. See [tfp.random.sanitize_seed]
534
+ will be converted into a stateless seed. See [tfp.random.sanitize_seed]
532
535
  (https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed).
533
536
  **pins: These are used to condition the provided joint distribution, and
534
537
  are passed directly to `joint_dist.experimental_pin(**pins)`.
@@ -547,7 +550,9 @@ class PosteriorMCMCSampler:
547
550
  " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
548
551
  " for details."
549
552
  )
550
- seed = tfp.random.sanitize_seed(seed) if seed is not None else None
553
+ if seed is not None and isinstance(seed, int):
554
+ seed = (seed, seed)
555
+ seed = backend.random.sanitize_seed(seed) if seed is not None else None
551
556
  n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
552
557
  total_chains = np.sum(n_chains_list)
553
558
 
@@ -570,7 +575,7 @@ class PosteriorMCMCSampler:
570
575
  seed=seed,
571
576
  **pins,
572
577
  )
573
- except tf.errors.ResourceExhaustedError as error:
578
+ except backend.errors.ResourceExhaustedError as error:
574
579
  raise MCMCOOMError(
575
580
  "ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
576
581
  " integers as `n_chains` to sample chains serially (see"
@@ -582,9 +587,11 @@ class PosteriorMCMCSampler:
582
587
  traces.append(mcmc.trace)
583
588
 
584
589
  mcmc_states = {
585
- k: tf.einsum(
590
+ k: backend.einsum(
586
591
  "ij...->ji...",
587
- tf.concat([state[k] for state in states], axis=1)[n_burnin:, ...],
592
+ backend.concatenate([state[k] for state in states], axis=1)[
593
+ n_burnin:, ...
594
+ ],
588
595
  )
589
596
  for k in states[0].keys()
590
597
  if k not in constants.UNSAVED_PARAMETERS
@@ -602,10 +609,10 @@ class PosteriorMCMCSampler:
602
609
  mcmc_trace = {}
603
610
  for k in traces[0].keys():
604
611
  if k not in constants.IGNORED_TRACE_METRICS:
605
- mcmc_trace[k] = tf.concat(
612
+ mcmc_trace[k] = backend.concatenate(
606
613
  [
607
- tf.broadcast_to(
608
- tf.transpose(trace[k][n_burnin:, ...]),
614
+ backend.broadcast_to(
615
+ backend.transpose(trace[k][n_burnin:, ...]),
609
616
  [n_chains_list[i], n_keep],
610
617
  )
611
618
  for i, trace in enumerate(traces)
@@ -645,7 +652,7 @@ class PosteriorMCMCSampler:
645
652
  # Tensorflow does not include a "draw" dimension on step size metric if same
646
653
  # step size is used for all chains. Step size must be broadcast to the
647
654
  # correct shape.
648
- sample_stats[constants.STEP_SIZE] = tf.broadcast_to(
655
+ sample_stats[constants.STEP_SIZE] = backend.broadcast_to(
649
656
  sample_stats[constants.STEP_SIZE], [total_chains, n_keep]
650
657
  )
651
658
  sample_stats_dims[constants.STEP_SIZE] = [constants.CHAIN, constants.DRAW]