google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,18 +19,22 @@ used by the Meridian model object.
19
19
  """
20
20
 
21
21
  from __future__ import annotations
22
- from collections.abc import MutableMapping
22
+
23
+ from collections.abc import MutableMapping, Sequence
23
24
  import dataclasses
24
25
  from typing import Any
25
26
  import warnings
27
+
28
+ from meridian import backend
26
29
  from meridian import constants
27
30
  import numpy as np
28
- import tensorflow as tf
29
- import tensorflow_probability as tfp
30
31
 
31
32
 
32
33
  __all__ = [
34
+ 'IndependentMultivariateDistribution',
33
35
  'PriorDistribution',
36
+ 'distributions_are_equal',
37
+ 'lognormal_dist_from_mean_std',
34
38
  ]
35
39
 
36
40
 
@@ -172,14 +176,14 @@ class PriorDistribution:
172
176
  xi_n: Prior distribution on the hierarchical standard deviation of
173
177
  `gamma_gn` which is the coefficient on non-media channel `n` for geo `g`.
174
178
  Hierarchy is defined over geos. Default distribution is `HalfNormal(5.0)`.
175
- alpha_m: Prior distribution on the `geometric decay` Adstock parameter for
179
+ alpha_m: Prior distribution on the Adstock decay parameter for media input.
180
+ Default distribution is `Uniform(0.0, 1.0)`.
181
+ alpha_rf: Prior distribution on the Adstock decay parameter for RF input.
182
+ Default distribution is `Uniform(0.0, 1.0)`.
183
+ alpha_om: Prior distribution on the Adstock decay parameter for organic
176
184
  media input. Default distribution is `Uniform(0.0, 1.0)`.
177
- alpha_rf: Prior distribution on the `geometric decay` Adstock parameter for
178
- RF input. Default distribution is `Uniform(0.0, 1.0)`.
179
- alpha_om: Prior distribution on the `geometric decay` Adstock parameter for
180
- organic media input. Default distribution is `Uniform(0.0, 1.0)`.
181
- alpha_orf: Prior distribution on the `geometric decay` Adstock parameter for
182
- organic RF input. Default distribution is `Uniform(0.0, 1.0)`.
185
+ alpha_orf: Prior distribution on the Adstock decay parameter for organic RF
186
+ input. Default distribution is `Uniform(0.0, 1.0)`.
183
187
  ec_m: Prior distribution on the `half-saturation` Hill parameter for media
184
188
  input. Default distribution is `TruncatedNormal(0.8, 0.8, 0.1, 10)`.
185
189
  ec_rf: Prior distribution on the `half-saturation` Hill parameter for RF
@@ -203,9 +207,9 @@ class PriorDistribution:
203
207
  distribution is `HalfNormal(5.0)`.
204
208
  roi_m: Prior distribution on the ROI of each media channel. This parameter
205
209
  is only used when `paid_media_prior_type` is `'roi'`, in which case
206
- `beta_m` is calculated as a deterministic function of `roi_rf`,
207
- `alpha_rf`, `ec_rf`, `slope_rf`, and the spend associated with each media
208
- channel. Default distribution is `LogNormal(0.2, 0.9)`. When `kpi_type` is
210
+ `beta_m` is calculated as a deterministic function of `roi_m`, `alpha_m`,
211
+ `ec_m`, `slope_m`, and the spend associated with each media channel.
212
+ Default distribution is `LogNormal(0.2, 0.9)`. When `kpi_type` is
209
213
  `'non_revenue'` and `revenue_per_kpi` is not provided, ROI is interpreted
210
214
  as incremental KPI units per monetary unit spent. In this case, the
211
215
  default value for `roi_m` and `roi_rf` will be ignored and a common ROI
@@ -270,196 +274,202 @@ class PriorDistribution:
270
274
  `TruncatedNormal(0.0, 0.1, -1.0, 1.0)`.
271
275
  """
272
276
 
273
- knot_values: tfp.distributions.Distribution = dataclasses.field(
274
- default_factory=lambda: tfp.distributions.Normal(
277
+ knot_values: backend.tfd.Distribution = dataclasses.field(
278
+ default_factory=lambda: backend.tfd.Normal(
275
279
  0.0, 5.0, name=constants.KNOT_VALUES
276
280
  ),
277
281
  )
278
- tau_g_excl_baseline: tfp.distributions.Distribution = dataclasses.field(
279
- default_factory=lambda: tfp.distributions.Normal(
282
+ tau_g_excl_baseline: backend.tfd.Distribution = dataclasses.field(
283
+ default_factory=lambda: backend.tfd.Normal(
280
284
  0.0, 5.0, name=constants.TAU_G_EXCL_BASELINE
281
285
  ),
282
286
  )
283
- beta_m: tfp.distributions.Distribution = dataclasses.field(
284
- default_factory=lambda: tfp.distributions.HalfNormal(
287
+ beta_m: backend.tfd.Distribution = dataclasses.field(
288
+ default_factory=lambda: backend.tfd.HalfNormal(
285
289
  5.0, name=constants.BETA_M
286
290
  ),
287
291
  )
288
- beta_rf: tfp.distributions.Distribution = dataclasses.field(
289
- default_factory=lambda: tfp.distributions.HalfNormal(
292
+ beta_rf: backend.tfd.Distribution = dataclasses.field(
293
+ default_factory=lambda: backend.tfd.HalfNormal(
290
294
  5.0, name=constants.BETA_RF
291
295
  ),
292
296
  )
293
- beta_om: tfp.distributions.Distribution = dataclasses.field(
294
- default_factory=lambda: tfp.distributions.HalfNormal(
297
+ beta_om: backend.tfd.Distribution = dataclasses.field(
298
+ default_factory=lambda: backend.tfd.HalfNormal(
295
299
  5.0, name=constants.BETA_OM
296
300
  ),
297
301
  )
298
- beta_orf: tfp.distributions.Distribution = dataclasses.field(
299
- default_factory=lambda: tfp.distributions.HalfNormal(
302
+ beta_orf: backend.tfd.Distribution = dataclasses.field(
303
+ default_factory=lambda: backend.tfd.HalfNormal(
300
304
  5.0, name=constants.BETA_ORF
301
305
  ),
302
306
  )
303
- eta_m: tfp.distributions.Distribution = dataclasses.field(
304
- default_factory=lambda: tfp.distributions.HalfNormal(
305
- 1.0, name=constants.ETA_M
306
- ),
307
+ eta_m: backend.tfd.Distribution = dataclasses.field(
308
+ default_factory=lambda: backend.tfd.HalfNormal(1.0, name=constants.ETA_M),
307
309
  )
308
- eta_rf: tfp.distributions.Distribution = dataclasses.field(
309
- default_factory=lambda: tfp.distributions.HalfNormal(
310
+ eta_rf: backend.tfd.Distribution = dataclasses.field(
311
+ default_factory=lambda: backend.tfd.HalfNormal(
310
312
  1.0, name=constants.ETA_RF
311
313
  ),
312
314
  )
313
- eta_om: tfp.distributions.Distribution = dataclasses.field(
314
- default_factory=lambda: tfp.distributions.HalfNormal(
315
+ eta_om: backend.tfd.Distribution = dataclasses.field(
316
+ default_factory=lambda: backend.tfd.HalfNormal(
315
317
  1.0, name=constants.ETA_OM
316
318
  ),
317
319
  )
318
- eta_orf: tfp.distributions.Distribution = dataclasses.field(
319
- default_factory=lambda: tfp.distributions.HalfNormal(
320
+ eta_orf: backend.tfd.Distribution = dataclasses.field(
321
+ default_factory=lambda: backend.tfd.HalfNormal(
320
322
  1.0, name=constants.ETA_ORF
321
323
  ),
322
324
  )
323
- gamma_c: tfp.distributions.Distribution = dataclasses.field(
324
- default_factory=lambda: tfp.distributions.Normal(
325
+ gamma_c: backend.tfd.Distribution = dataclasses.field(
326
+ default_factory=lambda: backend.tfd.Normal(
325
327
  0.0, 5.0, name=constants.GAMMA_C
326
328
  ),
327
329
  )
328
- gamma_n: tfp.distributions.Distribution = dataclasses.field(
329
- default_factory=lambda: tfp.distributions.Normal(
330
+ gamma_n: backend.tfd.Distribution = dataclasses.field(
331
+ default_factory=lambda: backend.tfd.Normal(
330
332
  0.0, 5.0, name=constants.GAMMA_N
331
333
  ),
332
334
  )
333
- xi_c: tfp.distributions.Distribution = dataclasses.field(
334
- default_factory=lambda: tfp.distributions.HalfNormal(
335
- 5.0, name=constants.XI_C
336
- ),
335
+ xi_c: backend.tfd.Distribution = dataclasses.field(
336
+ default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.XI_C),
337
337
  )
338
- xi_n: tfp.distributions.Distribution = dataclasses.field(
339
- default_factory=lambda: tfp.distributions.HalfNormal(
340
- 5.0, name=constants.XI_N
341
- ),
338
+ xi_n: backend.tfd.Distribution = dataclasses.field(
339
+ default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.XI_N),
342
340
  )
343
- alpha_m: tfp.distributions.Distribution = dataclasses.field(
344
- default_factory=lambda: tfp.distributions.Uniform(
341
+ alpha_m: backend.tfd.Distribution = dataclasses.field(
342
+ default_factory=lambda: backend.tfd.Uniform(
345
343
  0.0, 1.0, name=constants.ALPHA_M
346
344
  ),
347
345
  )
348
- alpha_rf: tfp.distributions.Distribution = dataclasses.field(
349
- default_factory=lambda: tfp.distributions.Uniform(
346
+ alpha_rf: backend.tfd.Distribution = dataclasses.field(
347
+ default_factory=lambda: backend.tfd.Uniform(
350
348
  0.0, 1.0, name=constants.ALPHA_RF
351
349
  ),
352
350
  )
353
- alpha_om: tfp.distributions.Distribution = dataclasses.field(
354
- default_factory=lambda: tfp.distributions.Uniform(
351
+ alpha_om: backend.tfd.Distribution = dataclasses.field(
352
+ default_factory=lambda: backend.tfd.Uniform(
355
353
  0.0, 1.0, name=constants.ALPHA_OM
356
354
  ),
357
355
  )
358
- alpha_orf: tfp.distributions.Distribution = dataclasses.field(
359
- default_factory=lambda: tfp.distributions.Uniform(
356
+ alpha_orf: backend.tfd.Distribution = dataclasses.field(
357
+ default_factory=lambda: backend.tfd.Uniform(
360
358
  0.0, 1.0, name=constants.ALPHA_ORF
361
359
  ),
362
360
  )
363
- ec_m: tfp.distributions.Distribution = dataclasses.field(
364
- default_factory=lambda: tfp.distributions.TruncatedNormal(
361
+ ec_m: backend.tfd.Distribution = dataclasses.field(
362
+ default_factory=lambda: backend.tfd.TruncatedNormal(
365
363
  0.8, 0.8, 0.1, 10, name=constants.EC_M
366
364
  ),
367
365
  )
368
- ec_rf: tfp.distributions.Distribution = dataclasses.field(
369
- default_factory=lambda: tfp.distributions.TransformedDistribution(
370
- tfp.distributions.LogNormal(0.7, 0.4),
371
- tfp.bijectors.Shift(0.1),
366
+ ec_rf: backend.tfd.Distribution = dataclasses.field(
367
+ default_factory=lambda: backend.tfd.TransformedDistribution(
368
+ backend.tfd.LogNormal(0.7, 0.4),
369
+ backend.bijectors.Shift(0.1),
372
370
  name=constants.EC_RF,
373
371
  ),
374
372
  )
375
- ec_om: tfp.distributions.Distribution = dataclasses.field(
376
- default_factory=lambda: tfp.distributions.TruncatedNormal(
373
+ ec_om: backend.tfd.Distribution = dataclasses.field(
374
+ default_factory=lambda: backend.tfd.TruncatedNormal(
377
375
  0.8, 0.8, 0.1, 10, name=constants.EC_OM
378
376
  ),
379
377
  )
380
- ec_orf: tfp.distributions.Distribution = dataclasses.field(
381
- default_factory=lambda: tfp.distributions.TransformedDistribution(
382
- tfp.distributions.LogNormal(0.7, 0.4),
383
- tfp.bijectors.Shift(0.1),
378
+ ec_orf: backend.tfd.Distribution = dataclasses.field(
379
+ default_factory=lambda: backend.tfd.TransformedDistribution(
380
+ backend.tfd.LogNormal(0.7, 0.4),
381
+ backend.bijectors.Shift(0.1),
384
382
  name=constants.EC_ORF,
385
383
  ),
386
384
  )
387
- slope_m: tfp.distributions.Distribution = dataclasses.field(
388
- default_factory=lambda: tfp.distributions.Deterministic(
385
+ slope_m: backend.tfd.Distribution = dataclasses.field(
386
+ default_factory=lambda: backend.tfd.Deterministic(
389
387
  1.0, name=constants.SLOPE_M
390
388
  ),
391
389
  )
392
- slope_rf: tfp.distributions.Distribution = dataclasses.field(
393
- default_factory=lambda: tfp.distributions.LogNormal(
390
+ slope_rf: backend.tfd.Distribution = dataclasses.field(
391
+ default_factory=lambda: backend.tfd.LogNormal(
394
392
  0.7, 0.4, name=constants.SLOPE_RF
395
393
  ),
396
394
  )
397
- slope_om: tfp.distributions.Distribution = dataclasses.field(
398
- default_factory=lambda: tfp.distributions.Deterministic(
395
+ slope_om: backend.tfd.Distribution = dataclasses.field(
396
+ default_factory=lambda: backend.tfd.Deterministic(
399
397
  1.0, name=constants.SLOPE_OM
400
398
  ),
401
399
  )
402
- slope_orf: tfp.distributions.Distribution = dataclasses.field(
403
- default_factory=lambda: tfp.distributions.LogNormal(
400
+ slope_orf: backend.tfd.Distribution = dataclasses.field(
401
+ default_factory=lambda: backend.tfd.LogNormal(
404
402
  0.7, 0.4, name=constants.SLOPE_ORF
405
403
  ),
406
404
  )
407
- sigma: tfp.distributions.Distribution = dataclasses.field(
408
- default_factory=lambda: tfp.distributions.HalfNormal(
409
- 5.0, name=constants.SIGMA
410
- ),
405
+ sigma: backend.tfd.Distribution = dataclasses.field(
406
+ default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.SIGMA),
411
407
  )
412
- roi_m: tfp.distributions.Distribution = dataclasses.field(
413
- default_factory=lambda: tfp.distributions.LogNormal(
408
+ roi_m: backend.tfd.Distribution = dataclasses.field(
409
+ default_factory=lambda: backend.tfd.LogNormal(
414
410
  0.2, 0.9, name=constants.ROI_M
415
411
  ),
416
412
  )
417
- roi_rf: tfp.distributions.Distribution = dataclasses.field(
418
- default_factory=lambda: tfp.distributions.LogNormal(
413
+ roi_rf: backend.tfd.Distribution = dataclasses.field(
414
+ default_factory=lambda: backend.tfd.LogNormal(
419
415
  0.2, 0.9, name=constants.ROI_RF
420
416
  ),
421
417
  )
422
- mroi_m: tfp.distributions.Distribution = dataclasses.field(
423
- default_factory=lambda: tfp.distributions.LogNormal(
418
+ mroi_m: backend.tfd.Distribution = dataclasses.field(
419
+ default_factory=lambda: backend.tfd.LogNormal(
424
420
  0.0, 0.5, name=constants.MROI_M
425
421
  ),
426
422
  )
427
- mroi_rf: tfp.distributions.Distribution = dataclasses.field(
428
- default_factory=lambda: tfp.distributions.LogNormal(
423
+ mroi_rf: backend.tfd.Distribution = dataclasses.field(
424
+ default_factory=lambda: backend.tfd.LogNormal(
429
425
  0.0, 0.5, name=constants.MROI_RF
430
426
  ),
431
427
  )
432
- contribution_m: tfp.distributions.Distribution = dataclasses.field(
433
- default_factory=lambda: tfp.distributions.Beta(
428
+ contribution_m: backend.tfd.Distribution = dataclasses.field(
429
+ default_factory=lambda: backend.tfd.Beta(
434
430
  1.0, 99.0, name=constants.CONTRIBUTION_M
435
431
  ),
436
432
  )
437
- contribution_rf: tfp.distributions.Distribution = dataclasses.field(
438
- default_factory=lambda: tfp.distributions.Beta(
433
+ contribution_rf: backend.tfd.Distribution = dataclasses.field(
434
+ default_factory=lambda: backend.tfd.Beta(
439
435
  1.0, 99.0, name=constants.CONTRIBUTION_RF
440
436
  ),
441
437
  )
442
- contribution_om: tfp.distributions.Distribution = dataclasses.field(
443
- default_factory=lambda: tfp.distributions.Beta(
438
+ contribution_om: backend.tfd.Distribution = dataclasses.field(
439
+ default_factory=lambda: backend.tfd.Beta(
444
440
  1.0, 99.0, name=constants.CONTRIBUTION_OM
445
441
  ),
446
442
  )
447
- contribution_orf: tfp.distributions.Distribution = dataclasses.field(
448
- default_factory=lambda: tfp.distributions.Beta(
443
+ contribution_orf: backend.tfd.Distribution = dataclasses.field(
444
+ default_factory=lambda: backend.tfd.Beta(
449
445
  1.0, 99.0, name=constants.CONTRIBUTION_ORF
450
446
  ),
451
447
  )
452
- contribution_n: tfp.distributions.Distribution = dataclasses.field(
453
- default_factory=lambda: tfp.distributions.TruncatedNormal(
448
+ contribution_n: backend.tfd.Distribution = dataclasses.field(
449
+ default_factory=lambda: backend.tfd.TruncatedNormal(
454
450
  loc=0.0, scale=0.1, low=-1.0, high=1.0, name=constants.CONTRIBUTION_N
455
451
  ),
456
452
  )
457
453
 
454
+ def __post_init__(self):
455
+ for param, bounds in _parameter_space_bounds.items():
456
+ prevent_deterministic_prior_at_bounds = (
457
+ _prevent_deterministic_prior_at_bounds[param]
458
+ if param in _prevent_deterministic_prior_at_bounds.keys()
459
+ else (False, False)
460
+ )
461
+ _validate_support(
462
+ param,
463
+ getattr(self, param),
464
+ bounds,
465
+ prevent_deterministic_prior_at_bounds,
466
+ )
467
+
458
468
  def __setstate__(self, state):
459
469
  # Override to support pickling.
460
470
  def _unpack_distribution_params(
461
471
  params: MutableMapping[str, Any],
462
- ) -> tfp.distributions.Distribution:
472
+ ) -> backend.tfd.Distribution:
463
473
  if constants.DISTRIBUTION in params:
464
474
  params[constants.DISTRIBUTION] = _unpack_distribution_params(
465
475
  params[constants.DISTRIBUTION]
@@ -478,7 +488,7 @@ class PriorDistribution:
478
488
  state = self.__dict__.copy()
479
489
 
480
490
  def _pack_distribution_params(
481
- dist: tfp.distributions.Distribution,
491
+ dist: backend.tfd.Distribution,
482
492
  ) -> MutableMapping[str, Any]:
483
493
  params = dist.parameters
484
494
  params[constants.DISTRIBUTION_TYPE] = type(dist)
@@ -493,11 +503,9 @@ class PriorDistribution:
493
503
 
494
504
  return state
495
505
 
496
- def has_deterministic_param(
497
- self, param: tfp.distributions.Distribution
498
- ) -> bool:
506
+ def has_deterministic_param(self, param: backend.tfd.Distribution) -> bool:
499
507
  return hasattr(self, param) and isinstance(
500
- getattr(self, param).distribution, tfp.distributions.Deterministic
508
+ getattr(self, param).distribution, backend.tfd.Deterministic
501
509
  )
502
510
 
503
511
  def broadcast(
@@ -550,7 +558,7 @@ class PriorDistribution:
550
558
  """
551
559
 
552
560
  def _validate_media_custom_priors(
553
- param: tfp.distributions.Distribution,
561
+ param: backend.tfd.Distribution,
554
562
  ) -> None:
555
563
  if (
556
564
  param.batch_shape.as_list()
@@ -573,7 +581,7 @@ class PriorDistribution:
573
581
  _validate_media_custom_priors(self.beta_m)
574
582
 
575
583
  def _validate_organic_media_custom_priors(
576
- param: tfp.distributions.Distribution,
584
+ param: backend.tfd.Distribution,
577
585
  ) -> None:
578
586
  if (
579
587
  param.batch_shape.as_list()
@@ -595,7 +603,7 @@ class PriorDistribution:
595
603
  _validate_organic_media_custom_priors(self.beta_om)
596
604
 
597
605
  def _validate_organic_rf_custom_priors(
598
- param: tfp.distributions.Distribution,
606
+ param: backend.tfd.Distribution,
599
607
  ) -> None:
600
608
  if (
601
609
  param.batch_shape.as_list()
@@ -617,7 +625,7 @@ class PriorDistribution:
617
625
  _validate_organic_rf_custom_priors(self.beta_orf)
618
626
 
619
627
  def _validate_rf_custom_priors(
620
- param: tfp.distributions.Distribution,
628
+ param: backend.tfd.Distribution,
621
629
  ) -> None:
622
630
  if param.batch_shape.as_list() and n_rf_channels != param.batch_shape[0]:
623
631
  raise ValueError(
@@ -637,7 +645,7 @@ class PriorDistribution:
637
645
  _validate_rf_custom_priors(self.beta_rf)
638
646
 
639
647
  def _validate_control_custom_priors(
640
- param: tfp.distributions.Distribution,
648
+ param: backend.tfd.Distribution,
641
649
  ) -> None:
642
650
  if param.batch_shape.as_list() and n_controls != param.batch_shape[0]:
643
651
  raise ValueError(
@@ -651,7 +659,7 @@ class PriorDistribution:
651
659
  _validate_control_custom_priors(self.xi_c)
652
660
 
653
661
  def _validate_non_media_custom_priors(
654
- param: tfp.distributions.Distribution,
662
+ param: backend.tfd.Distribution,
655
663
  ) -> None:
656
664
  if (
657
665
  param.batch_shape.as_list()
@@ -669,7 +677,7 @@ class PriorDistribution:
669
677
  _validate_non_media_custom_priors(self.gamma_n)
670
678
  _validate_non_media_custom_priors(self.xi_n)
671
679
 
672
- knot_values = tfp.distributions.BatchBroadcast(
680
+ knot_values = backend.tfd.BatchBroadcast(
673
681
  self.knot_values,
674
682
  n_knots,
675
683
  name=constants.KNOT_VALUES,
@@ -680,19 +688,19 @@ class PriorDistribution:
680
688
  )
681
689
  else:
682
690
  tau_g_converted = self.tau_g_excl_baseline
683
- tau_g_excl_baseline = tfp.distributions.BatchBroadcast(
691
+ tau_g_excl_baseline = backend.tfd.BatchBroadcast(
684
692
  tau_g_converted, n_geos - 1, name=constants.TAU_G_EXCL_BASELINE
685
693
  )
686
- beta_m = tfp.distributions.BatchBroadcast(
694
+ beta_m = backend.tfd.BatchBroadcast(
687
695
  self.beta_m, n_media_channels, name=constants.BETA_M
688
696
  )
689
- beta_rf = tfp.distributions.BatchBroadcast(
697
+ beta_rf = backend.tfd.BatchBroadcast(
690
698
  self.beta_rf, n_rf_channels, name=constants.BETA_RF
691
699
  )
692
- beta_om = tfp.distributions.BatchBroadcast(
700
+ beta_om = backend.tfd.BatchBroadcast(
693
701
  self.beta_om, n_organic_media_channels, name=constants.BETA_OM
694
702
  )
695
- beta_orf = tfp.distributions.BatchBroadcast(
703
+ beta_orf = backend.tfd.BatchBroadcast(
696
704
  self.beta_orf, n_organic_rf_channels, name=constants.BETA_ORF
697
705
  )
698
706
  if is_national:
@@ -705,67 +713,67 @@ class PriorDistribution:
705
713
  eta_rf_converted = self.eta_rf
706
714
  eta_om_converted = self.eta_om
707
715
  eta_orf_converted = self.eta_orf
708
- eta_m = tfp.distributions.BatchBroadcast(
716
+ eta_m = backend.tfd.BatchBroadcast(
709
717
  eta_m_converted, n_media_channels, name=constants.ETA_M
710
718
  )
711
- eta_rf = tfp.distributions.BatchBroadcast(
719
+ eta_rf = backend.tfd.BatchBroadcast(
712
720
  eta_rf_converted, n_rf_channels, name=constants.ETA_RF
713
721
  )
714
- eta_om = tfp.distributions.BatchBroadcast(
722
+ eta_om = backend.tfd.BatchBroadcast(
715
723
  eta_om_converted,
716
724
  n_organic_media_channels,
717
725
  name=constants.ETA_OM,
718
726
  )
719
- eta_orf = tfp.distributions.BatchBroadcast(
727
+ eta_orf = backend.tfd.BatchBroadcast(
720
728
  eta_orf_converted, n_organic_rf_channels, name=constants.ETA_ORF
721
729
  )
722
- gamma_c = tfp.distributions.BatchBroadcast(
730
+ gamma_c = backend.tfd.BatchBroadcast(
723
731
  self.gamma_c, n_controls, name=constants.GAMMA_C
724
732
  )
725
733
  if is_national:
726
734
  xi_c_converted = _convert_to_deterministic_0_distribution(self.xi_c)
727
735
  else:
728
736
  xi_c_converted = self.xi_c
729
- xi_c = tfp.distributions.BatchBroadcast(
737
+ xi_c = backend.tfd.BatchBroadcast(
730
738
  xi_c_converted, n_controls, name=constants.XI_C
731
739
  )
732
- gamma_n = tfp.distributions.BatchBroadcast(
740
+ gamma_n = backend.tfd.BatchBroadcast(
733
741
  self.gamma_n, n_non_media_channels, name=constants.GAMMA_N
734
742
  )
735
743
  if is_national:
736
744
  xi_n_converted = _convert_to_deterministic_0_distribution(self.xi_n)
737
745
  else:
738
746
  xi_n_converted = self.xi_n
739
- xi_n = tfp.distributions.BatchBroadcast(
747
+ xi_n = backend.tfd.BatchBroadcast(
740
748
  xi_n_converted, n_non_media_channels, name=constants.XI_N
741
749
  )
742
- alpha_m = tfp.distributions.BatchBroadcast(
750
+ alpha_m = backend.tfd.BatchBroadcast(
743
751
  self.alpha_m, n_media_channels, name=constants.ALPHA_M
744
752
  )
745
- alpha_rf = tfp.distributions.BatchBroadcast(
753
+ alpha_rf = backend.tfd.BatchBroadcast(
746
754
  self.alpha_rf, n_rf_channels, name=constants.ALPHA_RF
747
755
  )
748
- alpha_om = tfp.distributions.BatchBroadcast(
756
+ alpha_om = backend.tfd.BatchBroadcast(
749
757
  self.alpha_om, n_organic_media_channels, name=constants.ALPHA_OM
750
758
  )
751
- alpha_orf = tfp.distributions.BatchBroadcast(
759
+ alpha_orf = backend.tfd.BatchBroadcast(
752
760
  self.alpha_orf, n_organic_rf_channels, name=constants.ALPHA_ORF
753
761
  )
754
- ec_m = tfp.distributions.BatchBroadcast(
762
+ ec_m = backend.tfd.BatchBroadcast(
755
763
  self.ec_m, n_media_channels, name=constants.EC_M
756
764
  )
757
- ec_rf = tfp.distributions.BatchBroadcast(
765
+ ec_rf = backend.tfd.BatchBroadcast(
758
766
  self.ec_rf, n_rf_channels, name=constants.EC_RF
759
767
  )
760
- ec_om = tfp.distributions.BatchBroadcast(
768
+ ec_om = backend.tfd.BatchBroadcast(
761
769
  self.ec_om, n_organic_media_channels, name=constants.EC_OM
762
770
  )
763
- ec_orf = tfp.distributions.BatchBroadcast(
771
+ ec_orf = backend.tfd.BatchBroadcast(
764
772
  self.ec_orf, n_organic_rf_channels, name=constants.EC_ORF
765
773
  )
766
774
  if (
767
- not isinstance(self.slope_m, tfp.distributions.Deterministic)
768
- or (np.isscalar(self.slope_m.loc.numpy()) and self.slope_m.loc != 1.0)
775
+ not isinstance(self.slope_m, backend.tfd.Deterministic)
776
+ or (backend.rank(self.slope_m.loc) == 0 and self.slope_m.loc != 1.0)
769
777
  or (
770
778
  self.slope_m.batch_shape.as_list()
771
779
  and any(x != 1.0 for x in self.slope_m.loc)
@@ -776,15 +784,15 @@ class PriorDistribution:
776
784
  ' This may lead to poor MCMC convergence and budget optimization'
777
785
  ' may no longer produce a global optimum.'
778
786
  )
779
- slope_m = tfp.distributions.BatchBroadcast(
787
+ slope_m = backend.tfd.BatchBroadcast(
780
788
  self.slope_m, n_media_channels, name=constants.SLOPE_M
781
789
  )
782
- slope_rf = tfp.distributions.BatchBroadcast(
790
+ slope_rf = backend.tfd.BatchBroadcast(
783
791
  self.slope_rf, n_rf_channels, name=constants.SLOPE_RF
784
792
  )
785
793
  if (
786
- not isinstance(self.slope_om, tfp.distributions.Deterministic)
787
- or (np.isscalar(self.slope_om.loc.numpy()) and self.slope_om.loc != 1.0)
794
+ not isinstance(self.slope_om, backend.tfd.Deterministic)
795
+ or (backend.rank(self.slope_om.loc) == 0 and self.slope_om.loc != 1.0)
788
796
  or (
789
797
  self.slope_om.batch_shape.as_list()
790
798
  and any(x != 1.0 for x in self.slope_om.loc)
@@ -795,16 +803,16 @@ class PriorDistribution:
795
803
  ' This may lead to poor MCMC convergence and budget optimization'
796
804
  ' may no longer produce a global optimum.'
797
805
  )
798
- slope_om = tfp.distributions.BatchBroadcast(
806
+ slope_om = backend.tfd.BatchBroadcast(
799
807
  self.slope_om, n_organic_media_channels, name=constants.SLOPE_OM
800
808
  )
801
- slope_orf = tfp.distributions.BatchBroadcast(
809
+ slope_orf = backend.tfd.BatchBroadcast(
802
810
  self.slope_orf, n_organic_rf_channels, name=constants.SLOPE_ORF
803
811
  )
804
812
 
805
813
  # If `unique_sigma_for_each_geo == False`, then make a scalar batch.
806
814
  sigma_shape = n_geos if (n_geos > 1 and unique_sigma_for_each_geo) else []
807
- sigma = tfp.distributions.BatchBroadcast(
815
+ sigma = backend.tfd.BatchBroadcast(
808
816
  self.sigma, sigma_shape, name=constants.SIGMA
809
817
  )
810
818
 
@@ -818,37 +826,37 @@ class PriorDistribution:
818
826
  else:
819
827
  roi_m_converted = self.roi_m
820
828
  roi_rf_converted = self.roi_rf
821
- roi_m = tfp.distributions.BatchBroadcast(
829
+ roi_m = backend.tfd.BatchBroadcast(
822
830
  roi_m_converted, n_media_channels, name=constants.ROI_M
823
831
  )
824
- roi_rf = tfp.distributions.BatchBroadcast(
832
+ roi_rf = backend.tfd.BatchBroadcast(
825
833
  roi_rf_converted, n_rf_channels, name=constants.ROI_RF
826
834
  )
827
835
 
828
- mroi_m = tfp.distributions.BatchBroadcast(
836
+ mroi_m = backend.tfd.BatchBroadcast(
829
837
  self.mroi_m, n_media_channels, name=constants.MROI_M
830
838
  )
831
- mroi_rf = tfp.distributions.BatchBroadcast(
839
+ mroi_rf = backend.tfd.BatchBroadcast(
832
840
  self.mroi_rf, n_rf_channels, name=constants.MROI_RF
833
841
  )
834
842
 
835
- contribution_m = tfp.distributions.BatchBroadcast(
843
+ contribution_m = backend.tfd.BatchBroadcast(
836
844
  self.contribution_m, n_media_channels, name=constants.CONTRIBUTION_M
837
845
  )
838
- contribution_rf = tfp.distributions.BatchBroadcast(
846
+ contribution_rf = backend.tfd.BatchBroadcast(
839
847
  self.contribution_rf, n_rf_channels, name=constants.CONTRIBUTION_RF
840
848
  )
841
- contribution_om = tfp.distributions.BatchBroadcast(
849
+ contribution_om = backend.tfd.BatchBroadcast(
842
850
  self.contribution_om,
843
851
  n_organic_media_channels,
844
852
  name=constants.CONTRIBUTION_OM,
845
853
  )
846
- contribution_orf = tfp.distributions.BatchBroadcast(
854
+ contribution_orf = backend.tfd.BatchBroadcast(
847
855
  self.contribution_orf,
848
856
  n_organic_rf_channels,
849
857
  name=constants.CONTRIBUTION_ORF,
850
858
  )
851
- contribution_n = tfp.distributions.BatchBroadcast(
859
+ contribution_n = backend.tfd.BatchBroadcast(
852
860
  self.contribution_n, n_non_media_channels, name=constants.CONTRIBUTION_N
853
861
  )
854
862
 
@@ -892,9 +900,350 @@ class PriorDistribution:
892
900
  )
893
901
 
894
902
 
903
+ class IndependentMultivariateDistribution(backend.tfd.Distribution):
904
+ """Container for a joint distribution created from independent distributions.
905
+
906
+ This class is useful when one wants to define a joint distribution for a
907
+ Meridian prior, where the elements are not necessarily from the same
908
+ distribution family. For example, to define a distribution where
909
+ one element is Uniform and the second is triangular:
910
+
911
+ ```python
912
+ distributions = [
913
+ tfp.distributions.Uniform(0.0, 1.0),
914
+ tfp.distributions.Triangular(0.0, 1.0, 0.5)
915
+ ]
916
+ distribution = IndependentMultivariateDistribution(distributions)
917
+ ```
918
+
919
+ It is also possible to define a distribution where multiple elements come
920
+ from the same distribution family. For example, to define a distribution where
921
+ the three elements are LogNormal(0.2, 0.9), LogNormal(0, 0.5) and
922
+ Gamma(2, 2):
923
+
924
+ ```python
925
+ distributions = [
926
+ tfp.distributions.LogNormal([0.2, 0.0], [0.9, 0.5]),
927
+ tfp.distributions.Gamma(2.0, 2.0)
928
+ ]
929
+ distribution = IndependentMultivariateDistribution(distributions)
930
+ ```
931
+
932
+ This class cannot contain instances of `tfd.Deterministic`.
933
+ """
934
+
935
+ def __init__(
936
+ self,
937
+ distributions: Sequence[backend.tfd.Distribution],
938
+ validate_args: bool = False,
939
+ allow_nan_stats: bool = True,
940
+ name: str | None = None,
941
+ ):
942
+ """Initializes a batch of independent distributions from different families.
943
+
944
+ Args:
945
+ distributions: List of `tfd.Distribution` from which to construct a
946
+ multivariate distribution. The distributions must have scalar or one
947
+ dimensional batch shapes; the resulting batch shape will be the sum of
948
+ the underlying batch shapes.
949
+ validate_args: Python `bool`. When `True` distribution parameters are
950
+ checked for validity despite possibly degrading runtime performance.
951
+ When `False` invalid inputs may silently render incorrect outputs.
952
+ Default value is `False`.
953
+ allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean, mode,
954
+ variance) use the value "`NaN`" to indicate the result is undefined.
955
+ When `False`, an exception is raised if one or more of the statistic's
956
+ batch members are undefined. Default value is `True`.
957
+ name: Python `str` name prefixed to Ops created by this class. Default
958
+ value is 'IndependentMultivariate' followed by the names of the
959
+ underlying distributions.
960
+
961
+ Raises:
962
+ ValueError: If one or more distributions are instances of
963
+ `tfd.Deterministic` or dtypes differ between the
964
+ distributions.
965
+ """
966
+ parameters = dict(locals())
967
+
968
+ self._verify_distributions(distributions)
969
+
970
+ self._distributions = [
971
+ dist
972
+ if not dist.is_scalar_batch()
973
+ else backend.tfd.BatchBroadcast(dist, (1,))
974
+ for dist in distributions
975
+ ]
976
+
977
+ self._distribution_batch_shapes = self._get_distribution_batch_shapes()
978
+ self._distribution_batch_shape_tensors = backend.concatenate(
979
+ [dist.batch_shape_tensor() for dist in self._distributions],
980
+ axis=0,
981
+ )
982
+
983
+ dtype = self._verify_dtypes()
984
+
985
+ name = name or '-'.join(
986
+ [constants.INDEPENDENT_MULTIVARIATE] + [d.name for d in distributions]
987
+ )
988
+
989
+ super().__init__(
990
+ dtype=dtype,
991
+ reparameterization_type=backend.tfd.NOT_REPARAMETERIZED,
992
+ validate_args=validate_args,
993
+ allow_nan_stats=allow_nan_stats,
994
+ parameters=parameters,
995
+ name=name,
996
+ )
997
+
998
+ def _verify_distributions(
999
+ self, distributions: Sequence[backend.tfd.Distribution]
1000
+ ):
1001
+ """Check for deterministic distributions and raise an error if found."""
1002
+
1003
+ if any(
1004
+ isinstance(dist, backend.tfd.Deterministic) for dist in distributions
1005
+ ):
1006
+ raise ValueError(
1007
+ f'{self.__class__.__name__} cannot contain `Deterministic` '
1008
+ 'distributions. To implement a nearly deterministic element of this '
1009
+ 'distribution, we recommend using `backend.tfd.Uniform` with a '
1010
+ 'small range. For example to define a distribution that is nearly '
1011
+ '`Deterministic(1.0)`, use '
1012
+ '`tfp.distribution.Uniform(1.0 - 1e-9, 1.0 + 1e-9)`'
1013
+ )
1014
+
1015
+ def _verify_dtypes(self) -> str:
1016
+ dtypes = [dist.dtype for dist in self._distributions]
1017
+ if len(set(dtypes)) != 1:
1018
+ raise ValueError(
1019
+ f'All distributions must have the same dtype. Found: {dtypes}.'
1020
+ )
1021
+
1022
+ return backend.result_type(*dtypes)
1023
+
1024
+ def _event_shape(self):
1025
+ return backend.TensorShape([])
1026
+
1027
+ def _batch_shape_tensor(self):
1028
+ distribution_batch_shape_tensors = backend.concatenate(
1029
+ [dist.batch_shape_tensor() for dist in self._distributions],
1030
+ axis=0,
1031
+ )
1032
+ return backend.reduce_sum(distribution_batch_shape_tensors, keepdims=True)
1033
+
1034
+ def _batch_shape(self):
1035
+ return backend.TensorShape(sum(self._distribution_batch_shapes))
1036
+
1037
+ def _sample_n(self, n, seed=None):
1038
+ return backend.concatenate(
1039
+ [dist.sample(n, seed) for dist in self._distributions], axis=-1
1040
+ )
1041
+
1042
+ def _quantile(self, value):
1043
+ value = self._broadcast_value(value)
1044
+ split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
1045
+ quantiles = [
1046
+ dist.quantile(sv) for dist, sv in zip(self._distributions, split_value)
1047
+ ]
1048
+
1049
+ return backend.concatenate(quantiles, axis=-1)
1050
+
1051
+ def _log_prob(self, value):
1052
+ value = self._broadcast_value(value)
1053
+ split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
1054
+ log_probs = [
1055
+ dist.log_prob(sv) for dist, sv in zip(self._distributions, split_value)
1056
+ ]
1057
+
1058
+ return backend.concatenate(log_probs, axis=-1)
1059
+
1060
+ def _log_cdf(self, value):
1061
+ value = self._broadcast_value(value)
1062
+ split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
1063
+
1064
+ log_cdfs = [
1065
+ dist.log_cdf(sv) for dist, sv in zip(self._distributions, split_value)
1066
+ ]
1067
+
1068
+ return backend.concatenate(log_cdfs, axis=-1)
1069
+
1070
+ def _mean(self):
1071
+ return backend.concatenate(
1072
+ [dist.mean() for dist in self._distributions], axis=0
1073
+ )
1074
+
1075
+ def _variance(self):
1076
+ return backend.concatenate(
1077
+ [dist.variance() for dist in self._distributions], axis=0
1078
+ )
1079
+
1080
+ def _default_event_space_bijector(self):
1081
+ """Mapping from R^n to the event space of the wrapped distributions.
1082
+
1083
+ This is the blockwise concatenation of the underlying bijectors.
1084
+
1085
+ Returns:
1086
+ A `tfp.bijectors.Blockwise` object that concatenates the underlying
1087
+ bijectors.
1088
+ """
1089
+ bijectors = [
1090
+ d.experimental_default_event_space_bijector()
1091
+ for d in self._distributions
1092
+ ]
1093
+
1094
+ return backend.bijectors.Blockwise(
1095
+ bijectors,
1096
+ block_sizes=self._distribution_batch_shapes,
1097
+ )
1098
+
1099
+ def _broadcast_value(self, value: backend.Tensor) -> backend.Tensor:
1100
+ value = backend.to_tensor(value)
1101
+ broadcast_shape = backend.broadcast_dynamic_shape(
1102
+ value.shape, self.batch_shape_tensor()
1103
+ )
1104
+ return backend.broadcast_to(value, broadcast_shape)
1105
+
1106
+ def _get_distribution_batch_shapes(self) -> Sequence[int]:
1107
+ """Sequence of batch shapes of underlying distributions."""
1108
+
1109
+ batch_shapes = []
1110
+
1111
+ for dist in self._distributions:
1112
+ try:
1113
+ (dist_batch_shape,) = dist.batch_shape
1114
+ except ValueError as exc:
1115
+ raise ValueError(
1116
+ 'All distributions must be 0- or 1-dimensional.'
1117
+ f' Found {len(dist.batch_shape)}-dimensional distribution:'
1118
+ f' {dist.batch_shape}.'
1119
+ ) from exc
1120
+ else:
1121
+ batch_shapes.append(dist_batch_shape)
1122
+
1123
+ return batch_shapes
1124
+
1125
+
1126
+ def distributions_are_equal(
1127
+ a: backend.tfd.Distribution, b: backend.tfd.Distribution
1128
+ ) -> bool:
1129
+ """Determine if two distributions are equal."""
1130
+ if type(a) != type(b): # pylint: disable=unidiomatic-typecheck
1131
+ return False
1132
+
1133
+ a_params = a.parameters.copy()
1134
+ b_params = b.parameters.copy()
1135
+
1136
+ if constants.DISTRIBUTION in a_params and constants.DISTRIBUTION in b_params:
1137
+ if not distributions_are_equal(
1138
+ a_params[constants.DISTRIBUTION], b_params[constants.DISTRIBUTION]
1139
+ ):
1140
+ return False
1141
+ del a_params[constants.DISTRIBUTION]
1142
+ del b_params[constants.DISTRIBUTION]
1143
+
1144
+ if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params:
1145
+ return False
1146
+
1147
+ if a_params.keys() != b_params.keys():
1148
+ return False
1149
+
1150
+ for key in a_params.keys():
1151
+ if isinstance(
1152
+ a_params[key], (backend.Tensor, np.ndarray, float, int)
1153
+ ) and isinstance(b_params[key], (backend.Tensor, np.ndarray, float, int)):
1154
+ if not backend.allclose(a_params[key], b_params[key]):
1155
+ return False
1156
+ else:
1157
+ if a_params[key] != b_params[key]:
1158
+ return False
1159
+
1160
+ return True
1161
+
1162
+
1163
+ def lognormal_dist_from_mean_std(
1164
+ mean: float | Sequence[float], std: float | Sequence[float]
1165
+ ) -> backend.tfd.LogNormal:
1166
+ """Define a lognormal distribution from its mean and standard deviation.
1167
+
1168
+ This function parameterizes lognormal distributions by their mean and
1169
+ standard deviation.
1170
+
1171
+ Args:
1172
+ mean: A float or array-like object defining the distribution mean. Must be
1173
+ positive.
1174
+ std: A float or array-like object defining the distribution standard
1175
+ deviation. Must be non-negative.
1176
+
1177
+ Returns:
1178
+ A `backend.tfd.LogNormal` object with the input mean and standard deviation.
1179
+ """
1180
+
1181
+ mean = np.asarray(mean)
1182
+ std = np.asarray(std)
1183
+
1184
+ mu = np.log(mean) - 0.5 * np.log((std / mean) ** 2 + 1)
1185
+ sigma = np.sqrt(np.log((std / mean) ** 2 + 1))
1186
+
1187
+ return backend.tfd.LogNormal(mu, sigma)
1188
+
1189
+
1190
+ def lognormal_dist_from_range(
1191
+ low: float | Sequence[float],
1192
+ high: float | Sequence[float],
1193
+ mass_percent: float | Sequence[float] = 0.95,
1194
+ ) -> backend.tfd.LogNormal:
1195
+ """Define a LogNormal distribution from a specified range.
1196
+
1197
+ This function parameterizes lognormal distributions by the bounds of a range,
1198
+ so that the specified probability mass falls within the bounds defined by
1199
+ `low` and `high`. The probability mass is symmetric about the median. For
1200
+ example, to define a lognormal distribution with a 95% probability mass of
1201
+ (1, 10), use:
1202
+
1203
+ ```python
1204
+ lognormal = lognormal_dist_from_range(1.0, 10.0, mass_percent=0.95)
1205
+ ```
1206
+
1207
+ Args:
1208
+ low: Float or array-like denoting the lower bound of the range. Values must
1209
+ be positive.
1210
+ high: Float or array-like denoting the upper bound of range. Values must be
1211
+ positive and strictly greater than `low`.
1212
+ mass_percent: Float or array-like denoting the probability mass. Values must
1213
+ be between 0 and 1 (exclusive). Default: 0.95.
1214
+
1215
+ Returns:
1216
+ A `backend.tfd.LogNormal` object with the input percentage mass falling
1217
+ within the given range.
1218
+ """
1219
+ low = np.asarray(low)
1220
+ high = np.asarray(high)
1221
+ mass_percent = np.asarray(mass_percent)
1222
+
1223
+ if not ((0.0 < low).all() and (low < high).all()): # pytype: disable=attribute-error
1224
+ raise ValueError("'low' and 'high' values must be non-negative and satisfy "
1225
+ "high > low.")
1226
+
1227
+ if not ((0.0 < mass_percent).all() and (mass_percent < 1.0).all()): # pytype: disable=attribute-error
1228
+ raise ValueError(
1229
+ "'mass_percent' values must be between 0 and 1, exclusive."
1230
+ )
1231
+
1232
+ normal = backend.tfd.Normal(0, 1)
1233
+ mass_lower = 0.5 - (mass_percent / 2)
1234
+ mass_upper = 0.5 + (mass_percent / 2)
1235
+
1236
+ sigma = np.log(high / low) / (
1237
+ normal.quantile(mass_upper) - normal.quantile(mass_lower)
1238
+ )
1239
+ mu = np.log(high) - normal.quantile(mass_upper) * sigma
1240
+
1241
+ return backend.tfd.LogNormal(mu, sigma)
1242
+
1243
+
895
1244
  def _convert_to_deterministic_0_distribution(
896
- distribution: tfp.distributions.Distribution,
897
- ) -> tfp.distributions.Distribution:
1245
+ distribution: backend.tfd.Distribution,
1246
+ ) -> backend.tfd.Distribution:
898
1247
  """Converts the given distribution to a `Deterministic(0)` one.
899
1248
 
900
1249
  Args:
@@ -909,7 +1258,7 @@ def _convert_to_deterministic_0_distribution(
909
1258
  distribution.
910
1259
  """
911
1260
  if (
912
- not isinstance(distribution, tfp.distributions.Deterministic)
1261
+ not isinstance(distribution, backend.tfd.Deterministic)
913
1262
  or distribution.loc != 0
914
1263
  ):
915
1264
  warnings.warn(
@@ -917,7 +1266,7 @@ def _convert_to_deterministic_0_distribution(
917
1266
  f' for national models. {distribution.name} has been automatically set'
918
1267
  ' to Deterministic(0).'
919
1268
  )
920
- return tfp.distributions.Deterministic(loc=0, name=distribution.name)
1269
+ return backend.tfd.Deterministic(loc=0, name=distribution.name)
921
1270
  else:
922
1271
  return distribution
923
1272
 
@@ -928,7 +1277,7 @@ def _get_total_media_contribution_prior(
928
1277
  name: str,
929
1278
  p_mean: float = constants.P_MEAN,
930
1279
  p_sd: float = constants.P_SD,
931
- ) -> tfp.distributions.Distribution:
1280
+ ) -> backend.tfd.Distribution:
932
1281
  """Determines ROI priors based on total media contribution.
933
1282
 
934
1283
  Args:
@@ -945,34 +1294,120 @@ def _get_total_media_contribution_prior(
945
1294
  """
946
1295
  roi_mean = p_mean * kpi / np.sum(total_spend)
947
1296
  roi_sd = p_sd * kpi / np.sqrt(np.sum(np.power(total_spend, 2)))
948
- lognormal_sigma = tf.cast(
949
- np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1)), dtype=tf.float32
1297
+ lognormal_sigma = backend.cast(
1298
+ np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1)), dtype=backend.float32
950
1299
  )
951
- lognormal_mu = tf.cast(
952
- np.log(roi_mean * np.exp(-(lognormal_sigma**2) / 2)), dtype=tf.float32
1300
+ lognormal_mu = backend.cast(
1301
+ np.log(roi_mean * np.exp(-(lognormal_sigma**2) / 2)),
1302
+ dtype=backend.float32,
953
1303
  )
954
- return tfp.distributions.LogNormal(lognormal_mu, lognormal_sigma, name=name)
1304
+ return backend.tfd.LogNormal(lognormal_mu, lognormal_sigma, name=name)
955
1305
 
956
1306
 
957
- def distributions_are_equal(
958
- a: tfp.distributions.Distribution, b: tfp.distributions.Distribution
959
- ) -> bool:
960
- """Determine if two distributions are equal."""
961
- if type(a) != type(b): # pylint: disable=unidiomatic-typecheck
962
- return False
1307
+ def _validate_support(
1308
+ parameter_name: str,
1309
+ tfp_dist: backend.tfp.distributions.Distribution,
1310
+ bounds: tuple[float, float],
1311
+ prevent_deterministic_prior_at_bounds: tuple[bool, bool],
1312
+ ) -> None:
1313
+ """Validates that distribution support is within the parameter bounds.
963
1314
 
964
- a_params = a.parameters.copy()
965
- b_params = b.parameters.copy()
1315
+ Args:
1316
+ parameter_name: Name of the parameter.
1317
+ tfp_dist: The TFP distribution to validate.
1318
+ bounds: Tuple containing the min and max values of the parameter space.
1319
+ prevent_deterministic_prior_at_bounds: Tuple of two booleans indicating
1320
+ whether a deterministic prior is disallowed at the lower and upper bounds,
1321
+ respectively.
966
1322
 
967
- if constants.DISTRIBUTION in a_params and constants.DISTRIBUTION in b_params:
968
- if not distributions_are_equal(
969
- a_params[constants.DISTRIBUTION], b_params[constants.DISTRIBUTION]
970
- ):
971
- return False
972
- del a_params[constants.DISTRIBUTION]
973
- del b_params[constants.DISTRIBUTION]
1323
+ Raises:
1324
+ ValueError: If the distribution support is not within the parameter bounds.
1325
+ """
1326
+ # Note that `tfp.distributions.BatchBroadcast` objects have a `distribution`
1327
+ # attribute that points to a `tfp.distributions.Distribution` object.
1328
+ if isinstance(tfp_dist, backend.tfd.BatchBroadcast):
1329
+ tfp_dist = tfp_dist.distribution
1330
+ # Note that `tfp.distributions.Deterministic` does not have a `quantile`
1331
+ # method implemented, so the min and max values must be extracted from the
1332
+ # `loc` attribute instead.
1333
+ if isinstance(tfp_dist, backend.tfd.Deterministic):
1334
+ support_min_vals = tfp_dist.loc
1335
+ support_max_vals = tfp_dist.loc
1336
+ for i in (0, 1):
1337
+ if prevent_deterministic_prior_at_bounds[i] and np.any(
1338
+ tfp_dist.loc == bounds[i]
1339
+ ):
1340
+ raise ValueError(
1341
+ f'{parameter_name} was assigned a point mass (deterministic) prior'
1342
+ f' at {bounds[i]}, which is not allowed.'
1343
+ )
1344
+ else:
1345
+ try:
1346
+ support_min_vals = tfp_dist.quantile(0)
1347
+ support_max_vals = tfp_dist.quantile(1)
1348
+ except (AttributeError, NotImplementedError):
1349
+ warnings.warn(
1350
+ f'The prior distribution for {parameter_name} does not have a'
1351
+ ' `quantile` method implemented, so the support range validation'
1352
+ f' was skipped. Confirm that your prior for {parameter_name} is'
1353
+ ' appropriate.'
1354
+ )
1355
+ return
1356
+ if np.any(support_min_vals < bounds[0]):
1357
+ raise ValueError(
1358
+ f'{parameter_name} was assigned a prior distribution that allows values'
1359
+ f' less than the parameter minimum {bounds[0]}.'
1360
+ )
1361
+ if np.any(support_max_vals > bounds[1]):
1362
+ raise ValueError(
1363
+ f'{parameter_name} was assigned a prior distribution that allows values'
1364
+ f' greater than the parameter maximum {bounds[1]}.'
1365
+ )
974
1366
 
975
- if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params:
976
- return False
977
1367
 
978
- return a_params == b_params
1368
+ # Dictionary of parameters that have a limited parameter space. The tuple
1369
+ # contains the lower and upper bounds, respectively.
1370
+ _parameter_space_bounds = {
1371
+ 'eta_m': (0, np.inf),
1372
+ 'eta_rf': (0, np.inf),
1373
+ 'eta_om': (0, np.inf),
1374
+ 'eta_orf': (0, np.inf),
1375
+ 'xi_c': (0, np.inf),
1376
+ 'xi_n': (0, np.inf),
1377
+ 'alpha_m': (0, 1),
1378
+ 'alpha_rf': (0, 1),
1379
+ 'alpha_om': (0, 1),
1380
+ 'alpha_orf': (0, 1),
1381
+ 'ec_m': (0, np.inf),
1382
+ 'ec_rf': (0, np.inf),
1383
+ 'ec_om': (0, np.inf),
1384
+ 'ec_orf': (0, np.inf),
1385
+ 'slope_m': (0, np.inf),
1386
+ 'slope_rf': (0, np.inf),
1387
+ 'slope_om': (0, np.inf),
1388
+ 'slope_orf': (0, np.inf),
1389
+ 'sigma': (0, np.inf),
1390
+ }
1391
+
1392
+ # Dictionary of parameters that do not allow a deterministic prior at one or
1393
+ # more of the parameter space bounds. The boolean tuple indicates whether a
1394
+ # deterministic prior is allowed at the lower bound or upper bound,
1395
+ # respectively, where `True` means "not allowed". This check is specifically for
1396
+ # point mass at finite parameter space bounds, since point mass at infinity is
1397
+ # generally problematic for all parameters. Note that `sigma` should generally
1398
+ # not have point mass at zero, but this is not checked here because unit tests
1399
+ # require the ability to simulate data with `sigma` set to zero.
1400
+ _prevent_deterministic_prior_at_bounds = {
1401
+ 'alpha_m': (False, True),
1402
+ 'alpha_rf': (False, True),
1403
+ 'alpha_om': (False, True),
1404
+ 'alpha_orf': (False, True),
1405
+ 'ec_m': (True, False),
1406
+ 'ec_rf': (True, False),
1407
+ 'ec_om': (True, False),
1408
+ 'ec_orf': (True, False),
1409
+ 'slope_m': (True, False),
1410
+ 'slope_rf': (True, False),
1411
+ 'slope_om': (True, False),
1412
+ 'slope_orf': (True, False),
1413
+ }