google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,18 +19,21 @@ used by the Meridian model object.
19
19
  """
20
20
 
21
21
  from __future__ import annotations
22
- from collections.abc import MutableMapping
22
+ from collections.abc import MutableMapping, Sequence
23
23
  import dataclasses
24
24
  from typing import Any
25
25
  import warnings
26
+
27
+ from meridian import backend
26
28
  from meridian import constants
29
+
27
30
  import numpy as np
28
- import tensorflow as tf
29
- import tensorflow_probability as tfp
30
31
 
31
32
 
32
33
  __all__ = [
34
+ 'IndependentMultivariateDistribution',
33
35
  'PriorDistribution',
36
+ 'distributions_are_equal',
34
37
  ]
35
38
 
36
39
 
@@ -203,9 +206,9 @@ class PriorDistribution:
203
206
  distribution is `HalfNormal(5.0)`.
204
207
  roi_m: Prior distribution on the ROI of each media channel. This parameter
205
208
  is only used when `paid_media_prior_type` is `'roi'`, in which case
206
- `beta_m` is calculated as a deterministic function of `roi_rf`,
207
- `alpha_rf`, `ec_rf`, `slope_rf`, and the spend associated with each media
208
- channel. Default distribution is `LogNormal(0.2, 0.9)`. When `kpi_type` is
209
+ `beta_m` is calculated as a deterministic function of `roi_m`, `alpha_m`,
210
+ `ec_m`, `slope_m`, and the spend associated with each media channel.
211
+ Default distribution is `LogNormal(0.2, 0.9)`. When `kpi_type` is
209
212
  `'non_revenue'` and `revenue_per_kpi` is not provided, ROI is interpreted
210
213
  as incremental KPI units per monetary unit spent. In this case, the
211
214
  default value for `roi_m` and `roi_rf` will be ignored and a common ROI
@@ -270,196 +273,202 @@ class PriorDistribution:
270
273
  `TruncatedNormal(0.0, 0.1, -1.0, 1.0)`.
271
274
  """
272
275
 
273
- knot_values: tfp.distributions.Distribution = dataclasses.field(
274
- default_factory=lambda: tfp.distributions.Normal(
276
+ knot_values: backend.tfd.Distribution = dataclasses.field(
277
+ default_factory=lambda: backend.tfd.Normal(
275
278
  0.0, 5.0, name=constants.KNOT_VALUES
276
279
  ),
277
280
  )
278
- tau_g_excl_baseline: tfp.distributions.Distribution = dataclasses.field(
279
- default_factory=lambda: tfp.distributions.Normal(
281
+ tau_g_excl_baseline: backend.tfd.Distribution = dataclasses.field(
282
+ default_factory=lambda: backend.tfd.Normal(
280
283
  0.0, 5.0, name=constants.TAU_G_EXCL_BASELINE
281
284
  ),
282
285
  )
283
- beta_m: tfp.distributions.Distribution = dataclasses.field(
284
- default_factory=lambda: tfp.distributions.HalfNormal(
286
+ beta_m: backend.tfd.Distribution = dataclasses.field(
287
+ default_factory=lambda: backend.tfd.HalfNormal(
285
288
  5.0, name=constants.BETA_M
286
289
  ),
287
290
  )
288
- beta_rf: tfp.distributions.Distribution = dataclasses.field(
289
- default_factory=lambda: tfp.distributions.HalfNormal(
291
+ beta_rf: backend.tfd.Distribution = dataclasses.field(
292
+ default_factory=lambda: backend.tfd.HalfNormal(
290
293
  5.0, name=constants.BETA_RF
291
294
  ),
292
295
  )
293
- beta_om: tfp.distributions.Distribution = dataclasses.field(
294
- default_factory=lambda: tfp.distributions.HalfNormal(
296
+ beta_om: backend.tfd.Distribution = dataclasses.field(
297
+ default_factory=lambda: backend.tfd.HalfNormal(
295
298
  5.0, name=constants.BETA_OM
296
299
  ),
297
300
  )
298
- beta_orf: tfp.distributions.Distribution = dataclasses.field(
299
- default_factory=lambda: tfp.distributions.HalfNormal(
301
+ beta_orf: backend.tfd.Distribution = dataclasses.field(
302
+ default_factory=lambda: backend.tfd.HalfNormal(
300
303
  5.0, name=constants.BETA_ORF
301
304
  ),
302
305
  )
303
- eta_m: tfp.distributions.Distribution = dataclasses.field(
304
- default_factory=lambda: tfp.distributions.HalfNormal(
305
- 1.0, name=constants.ETA_M
306
- ),
306
+ eta_m: backend.tfd.Distribution = dataclasses.field(
307
+ default_factory=lambda: backend.tfd.HalfNormal(1.0, name=constants.ETA_M),
307
308
  )
308
- eta_rf: tfp.distributions.Distribution = dataclasses.field(
309
- default_factory=lambda: tfp.distributions.HalfNormal(
309
+ eta_rf: backend.tfd.Distribution = dataclasses.field(
310
+ default_factory=lambda: backend.tfd.HalfNormal(
310
311
  1.0, name=constants.ETA_RF
311
312
  ),
312
313
  )
313
- eta_om: tfp.distributions.Distribution = dataclasses.field(
314
- default_factory=lambda: tfp.distributions.HalfNormal(
314
+ eta_om: backend.tfd.Distribution = dataclasses.field(
315
+ default_factory=lambda: backend.tfd.HalfNormal(
315
316
  1.0, name=constants.ETA_OM
316
317
  ),
317
318
  )
318
- eta_orf: tfp.distributions.Distribution = dataclasses.field(
319
- default_factory=lambda: tfp.distributions.HalfNormal(
319
+ eta_orf: backend.tfd.Distribution = dataclasses.field(
320
+ default_factory=lambda: backend.tfd.HalfNormal(
320
321
  1.0, name=constants.ETA_ORF
321
322
  ),
322
323
  )
323
- gamma_c: tfp.distributions.Distribution = dataclasses.field(
324
- default_factory=lambda: tfp.distributions.Normal(
324
+ gamma_c: backend.tfd.Distribution = dataclasses.field(
325
+ default_factory=lambda: backend.tfd.Normal(
325
326
  0.0, 5.0, name=constants.GAMMA_C
326
327
  ),
327
328
  )
328
- gamma_n: tfp.distributions.Distribution = dataclasses.field(
329
- default_factory=lambda: tfp.distributions.Normal(
329
+ gamma_n: backend.tfd.Distribution = dataclasses.field(
330
+ default_factory=lambda: backend.tfd.Normal(
330
331
  0.0, 5.0, name=constants.GAMMA_N
331
332
  ),
332
333
  )
333
- xi_c: tfp.distributions.Distribution = dataclasses.field(
334
- default_factory=lambda: tfp.distributions.HalfNormal(
335
- 5.0, name=constants.XI_C
336
- ),
334
+ xi_c: backend.tfd.Distribution = dataclasses.field(
335
+ default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.XI_C),
337
336
  )
338
- xi_n: tfp.distributions.Distribution = dataclasses.field(
339
- default_factory=lambda: tfp.distributions.HalfNormal(
340
- 5.0, name=constants.XI_N
341
- ),
337
+ xi_n: backend.tfd.Distribution = dataclasses.field(
338
+ default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.XI_N),
342
339
  )
343
- alpha_m: tfp.distributions.Distribution = dataclasses.field(
344
- default_factory=lambda: tfp.distributions.Uniform(
340
+ alpha_m: backend.tfd.Distribution = dataclasses.field(
341
+ default_factory=lambda: backend.tfd.Uniform(
345
342
  0.0, 1.0, name=constants.ALPHA_M
346
343
  ),
347
344
  )
348
- alpha_rf: tfp.distributions.Distribution = dataclasses.field(
349
- default_factory=lambda: tfp.distributions.Uniform(
345
+ alpha_rf: backend.tfd.Distribution = dataclasses.field(
346
+ default_factory=lambda: backend.tfd.Uniform(
350
347
  0.0, 1.0, name=constants.ALPHA_RF
351
348
  ),
352
349
  )
353
- alpha_om: tfp.distributions.Distribution = dataclasses.field(
354
- default_factory=lambda: tfp.distributions.Uniform(
350
+ alpha_om: backend.tfd.Distribution = dataclasses.field(
351
+ default_factory=lambda: backend.tfd.Uniform(
355
352
  0.0, 1.0, name=constants.ALPHA_OM
356
353
  ),
357
354
  )
358
- alpha_orf: tfp.distributions.Distribution = dataclasses.field(
359
- default_factory=lambda: tfp.distributions.Uniform(
355
+ alpha_orf: backend.tfd.Distribution = dataclasses.field(
356
+ default_factory=lambda: backend.tfd.Uniform(
360
357
  0.0, 1.0, name=constants.ALPHA_ORF
361
358
  ),
362
359
  )
363
- ec_m: tfp.distributions.Distribution = dataclasses.field(
364
- default_factory=lambda: tfp.distributions.TruncatedNormal(
360
+ ec_m: backend.tfd.Distribution = dataclasses.field(
361
+ default_factory=lambda: backend.tfd.TruncatedNormal(
365
362
  0.8, 0.8, 0.1, 10, name=constants.EC_M
366
363
  ),
367
364
  )
368
- ec_rf: tfp.distributions.Distribution = dataclasses.field(
369
- default_factory=lambda: tfp.distributions.TransformedDistribution(
370
- tfp.distributions.LogNormal(0.7, 0.4),
371
- tfp.bijectors.Shift(0.1),
365
+ ec_rf: backend.tfd.Distribution = dataclasses.field(
366
+ default_factory=lambda: backend.tfd.TransformedDistribution(
367
+ backend.tfd.LogNormal(0.7, 0.4),
368
+ backend.bijectors.Shift(0.1),
372
369
  name=constants.EC_RF,
373
370
  ),
374
371
  )
375
- ec_om: tfp.distributions.Distribution = dataclasses.field(
376
- default_factory=lambda: tfp.distributions.TruncatedNormal(
372
+ ec_om: backend.tfd.Distribution = dataclasses.field(
373
+ default_factory=lambda: backend.tfd.TruncatedNormal(
377
374
  0.8, 0.8, 0.1, 10, name=constants.EC_OM
378
375
  ),
379
376
  )
380
- ec_orf: tfp.distributions.Distribution = dataclasses.field(
381
- default_factory=lambda: tfp.distributions.TransformedDistribution(
382
- tfp.distributions.LogNormal(0.7, 0.4),
383
- tfp.bijectors.Shift(0.1),
377
+ ec_orf: backend.tfd.Distribution = dataclasses.field(
378
+ default_factory=lambda: backend.tfd.TransformedDistribution(
379
+ backend.tfd.LogNormal(0.7, 0.4),
380
+ backend.bijectors.Shift(0.1),
384
381
  name=constants.EC_ORF,
385
382
  ),
386
383
  )
387
- slope_m: tfp.distributions.Distribution = dataclasses.field(
388
- default_factory=lambda: tfp.distributions.Deterministic(
384
+ slope_m: backend.tfd.Distribution = dataclasses.field(
385
+ default_factory=lambda: backend.tfd.Deterministic(
389
386
  1.0, name=constants.SLOPE_M
390
387
  ),
391
388
  )
392
- slope_rf: tfp.distributions.Distribution = dataclasses.field(
393
- default_factory=lambda: tfp.distributions.LogNormal(
389
+ slope_rf: backend.tfd.Distribution = dataclasses.field(
390
+ default_factory=lambda: backend.tfd.LogNormal(
394
391
  0.7, 0.4, name=constants.SLOPE_RF
395
392
  ),
396
393
  )
397
- slope_om: tfp.distributions.Distribution = dataclasses.field(
398
- default_factory=lambda: tfp.distributions.Deterministic(
394
+ slope_om: backend.tfd.Distribution = dataclasses.field(
395
+ default_factory=lambda: backend.tfd.Deterministic(
399
396
  1.0, name=constants.SLOPE_OM
400
397
  ),
401
398
  )
402
- slope_orf: tfp.distributions.Distribution = dataclasses.field(
403
- default_factory=lambda: tfp.distributions.LogNormal(
399
+ slope_orf: backend.tfd.Distribution = dataclasses.field(
400
+ default_factory=lambda: backend.tfd.LogNormal(
404
401
  0.7, 0.4, name=constants.SLOPE_ORF
405
402
  ),
406
403
  )
407
- sigma: tfp.distributions.Distribution = dataclasses.field(
408
- default_factory=lambda: tfp.distributions.HalfNormal(
409
- 5.0, name=constants.SIGMA
410
- ),
404
+ sigma: backend.tfd.Distribution = dataclasses.field(
405
+ default_factory=lambda: backend.tfd.HalfNormal(5.0, name=constants.SIGMA),
411
406
  )
412
- roi_m: tfp.distributions.Distribution = dataclasses.field(
413
- default_factory=lambda: tfp.distributions.LogNormal(
407
+ roi_m: backend.tfd.Distribution = dataclasses.field(
408
+ default_factory=lambda: backend.tfd.LogNormal(
414
409
  0.2, 0.9, name=constants.ROI_M
415
410
  ),
416
411
  )
417
- roi_rf: tfp.distributions.Distribution = dataclasses.field(
418
- default_factory=lambda: tfp.distributions.LogNormal(
412
+ roi_rf: backend.tfd.Distribution = dataclasses.field(
413
+ default_factory=lambda: backend.tfd.LogNormal(
419
414
  0.2, 0.9, name=constants.ROI_RF
420
415
  ),
421
416
  )
422
- mroi_m: tfp.distributions.Distribution = dataclasses.field(
423
- default_factory=lambda: tfp.distributions.LogNormal(
417
+ mroi_m: backend.tfd.Distribution = dataclasses.field(
418
+ default_factory=lambda: backend.tfd.LogNormal(
424
419
  0.0, 0.5, name=constants.MROI_M
425
420
  ),
426
421
  )
427
- mroi_rf: tfp.distributions.Distribution = dataclasses.field(
428
- default_factory=lambda: tfp.distributions.LogNormal(
422
+ mroi_rf: backend.tfd.Distribution = dataclasses.field(
423
+ default_factory=lambda: backend.tfd.LogNormal(
429
424
  0.0, 0.5, name=constants.MROI_RF
430
425
  ),
431
426
  )
432
- contribution_m: tfp.distributions.Distribution = dataclasses.field(
433
- default_factory=lambda: tfp.distributions.Beta(
427
+ contribution_m: backend.tfd.Distribution = dataclasses.field(
428
+ default_factory=lambda: backend.tfd.Beta(
434
429
  1.0, 99.0, name=constants.CONTRIBUTION_M
435
430
  ),
436
431
  )
437
- contribution_rf: tfp.distributions.Distribution = dataclasses.field(
438
- default_factory=lambda: tfp.distributions.Beta(
432
+ contribution_rf: backend.tfd.Distribution = dataclasses.field(
433
+ default_factory=lambda: backend.tfd.Beta(
439
434
  1.0, 99.0, name=constants.CONTRIBUTION_RF
440
435
  ),
441
436
  )
442
- contribution_om: tfp.distributions.Distribution = dataclasses.field(
443
- default_factory=lambda: tfp.distributions.Beta(
437
+ contribution_om: backend.tfd.Distribution = dataclasses.field(
438
+ default_factory=lambda: backend.tfd.Beta(
444
439
  1.0, 99.0, name=constants.CONTRIBUTION_OM
445
440
  ),
446
441
  )
447
- contribution_orf: tfp.distributions.Distribution = dataclasses.field(
448
- default_factory=lambda: tfp.distributions.Beta(
442
+ contribution_orf: backend.tfd.Distribution = dataclasses.field(
443
+ default_factory=lambda: backend.tfd.Beta(
449
444
  1.0, 99.0, name=constants.CONTRIBUTION_ORF
450
445
  ),
451
446
  )
452
- contribution_n: tfp.distributions.Distribution = dataclasses.field(
453
- default_factory=lambda: tfp.distributions.TruncatedNormal(
447
+ contribution_n: backend.tfd.Distribution = dataclasses.field(
448
+ default_factory=lambda: backend.tfd.TruncatedNormal(
454
449
  loc=0.0, scale=0.1, low=-1.0, high=1.0, name=constants.CONTRIBUTION_N
455
450
  ),
456
451
  )
457
452
 
453
+ def __post_init__(self):
454
+ for param, bounds in _parameter_space_bounds.items():
455
+ prevent_deterministic_prior_at_bounds = (
456
+ _prevent_deterministic_prior_at_bounds[param]
457
+ if param in _prevent_deterministic_prior_at_bounds.keys()
458
+ else (False, False)
459
+ )
460
+ _validate_support(
461
+ param,
462
+ getattr(self, param),
463
+ bounds,
464
+ prevent_deterministic_prior_at_bounds,
465
+ )
466
+
458
467
  def __setstate__(self, state):
459
468
  # Override to support pickling.
460
469
  def _unpack_distribution_params(
461
470
  params: MutableMapping[str, Any],
462
- ) -> tfp.distributions.Distribution:
471
+ ) -> backend.tfd.Distribution:
463
472
  if constants.DISTRIBUTION in params:
464
473
  params[constants.DISTRIBUTION] = _unpack_distribution_params(
465
474
  params[constants.DISTRIBUTION]
@@ -478,7 +487,7 @@ class PriorDistribution:
478
487
  state = self.__dict__.copy()
479
488
 
480
489
  def _pack_distribution_params(
481
- dist: tfp.distributions.Distribution,
490
+ dist: backend.tfd.Distribution,
482
491
  ) -> MutableMapping[str, Any]:
483
492
  params = dist.parameters
484
493
  params[constants.DISTRIBUTION_TYPE] = type(dist)
@@ -493,11 +502,9 @@ class PriorDistribution:
493
502
 
494
503
  return state
495
504
 
496
- def has_deterministic_param(
497
- self, param: tfp.distributions.Distribution
498
- ) -> bool:
505
+ def has_deterministic_param(self, param: backend.tfd.Distribution) -> bool:
499
506
  return hasattr(self, param) and isinstance(
500
- getattr(self, param).distribution, tfp.distributions.Deterministic
507
+ getattr(self, param).distribution, backend.tfd.Deterministic
501
508
  )
502
509
 
503
510
  def broadcast(
@@ -550,7 +557,7 @@ class PriorDistribution:
550
557
  """
551
558
 
552
559
  def _validate_media_custom_priors(
553
- param: tfp.distributions.Distribution,
560
+ param: backend.tfd.Distribution,
554
561
  ) -> None:
555
562
  if (
556
563
  param.batch_shape.as_list()
@@ -573,7 +580,7 @@ class PriorDistribution:
573
580
  _validate_media_custom_priors(self.beta_m)
574
581
 
575
582
  def _validate_organic_media_custom_priors(
576
- param: tfp.distributions.Distribution,
583
+ param: backend.tfd.Distribution,
577
584
  ) -> None:
578
585
  if (
579
586
  param.batch_shape.as_list()
@@ -595,7 +602,7 @@ class PriorDistribution:
595
602
  _validate_organic_media_custom_priors(self.beta_om)
596
603
 
597
604
  def _validate_organic_rf_custom_priors(
598
- param: tfp.distributions.Distribution,
605
+ param: backend.tfd.Distribution,
599
606
  ) -> None:
600
607
  if (
601
608
  param.batch_shape.as_list()
@@ -617,7 +624,7 @@ class PriorDistribution:
617
624
  _validate_organic_rf_custom_priors(self.beta_orf)
618
625
 
619
626
  def _validate_rf_custom_priors(
620
- param: tfp.distributions.Distribution,
627
+ param: backend.tfd.Distribution,
621
628
  ) -> None:
622
629
  if param.batch_shape.as_list() and n_rf_channels != param.batch_shape[0]:
623
630
  raise ValueError(
@@ -637,7 +644,7 @@ class PriorDistribution:
637
644
  _validate_rf_custom_priors(self.beta_rf)
638
645
 
639
646
  def _validate_control_custom_priors(
640
- param: tfp.distributions.Distribution,
647
+ param: backend.tfd.Distribution,
641
648
  ) -> None:
642
649
  if param.batch_shape.as_list() and n_controls != param.batch_shape[0]:
643
650
  raise ValueError(
@@ -651,7 +658,7 @@ class PriorDistribution:
651
658
  _validate_control_custom_priors(self.xi_c)
652
659
 
653
660
  def _validate_non_media_custom_priors(
654
- param: tfp.distributions.Distribution,
661
+ param: backend.tfd.Distribution,
655
662
  ) -> None:
656
663
  if (
657
664
  param.batch_shape.as_list()
@@ -669,7 +676,7 @@ class PriorDistribution:
669
676
  _validate_non_media_custom_priors(self.gamma_n)
670
677
  _validate_non_media_custom_priors(self.xi_n)
671
678
 
672
- knot_values = tfp.distributions.BatchBroadcast(
679
+ knot_values = backend.tfd.BatchBroadcast(
673
680
  self.knot_values,
674
681
  n_knots,
675
682
  name=constants.KNOT_VALUES,
@@ -680,19 +687,19 @@ class PriorDistribution:
680
687
  )
681
688
  else:
682
689
  tau_g_converted = self.tau_g_excl_baseline
683
- tau_g_excl_baseline = tfp.distributions.BatchBroadcast(
690
+ tau_g_excl_baseline = backend.tfd.BatchBroadcast(
684
691
  tau_g_converted, n_geos - 1, name=constants.TAU_G_EXCL_BASELINE
685
692
  )
686
- beta_m = tfp.distributions.BatchBroadcast(
693
+ beta_m = backend.tfd.BatchBroadcast(
687
694
  self.beta_m, n_media_channels, name=constants.BETA_M
688
695
  )
689
- beta_rf = tfp.distributions.BatchBroadcast(
696
+ beta_rf = backend.tfd.BatchBroadcast(
690
697
  self.beta_rf, n_rf_channels, name=constants.BETA_RF
691
698
  )
692
- beta_om = tfp.distributions.BatchBroadcast(
699
+ beta_om = backend.tfd.BatchBroadcast(
693
700
  self.beta_om, n_organic_media_channels, name=constants.BETA_OM
694
701
  )
695
- beta_orf = tfp.distributions.BatchBroadcast(
702
+ beta_orf = backend.tfd.BatchBroadcast(
696
703
  self.beta_orf, n_organic_rf_channels, name=constants.BETA_ORF
697
704
  )
698
705
  if is_national:
@@ -705,66 +712,66 @@ class PriorDistribution:
705
712
  eta_rf_converted = self.eta_rf
706
713
  eta_om_converted = self.eta_om
707
714
  eta_orf_converted = self.eta_orf
708
- eta_m = tfp.distributions.BatchBroadcast(
715
+ eta_m = backend.tfd.BatchBroadcast(
709
716
  eta_m_converted, n_media_channels, name=constants.ETA_M
710
717
  )
711
- eta_rf = tfp.distributions.BatchBroadcast(
718
+ eta_rf = backend.tfd.BatchBroadcast(
712
719
  eta_rf_converted, n_rf_channels, name=constants.ETA_RF
713
720
  )
714
- eta_om = tfp.distributions.BatchBroadcast(
721
+ eta_om = backend.tfd.BatchBroadcast(
715
722
  eta_om_converted,
716
723
  n_organic_media_channels,
717
724
  name=constants.ETA_OM,
718
725
  )
719
- eta_orf = tfp.distributions.BatchBroadcast(
726
+ eta_orf = backend.tfd.BatchBroadcast(
720
727
  eta_orf_converted, n_organic_rf_channels, name=constants.ETA_ORF
721
728
  )
722
- gamma_c = tfp.distributions.BatchBroadcast(
729
+ gamma_c = backend.tfd.BatchBroadcast(
723
730
  self.gamma_c, n_controls, name=constants.GAMMA_C
724
731
  )
725
732
  if is_national:
726
733
  xi_c_converted = _convert_to_deterministic_0_distribution(self.xi_c)
727
734
  else:
728
735
  xi_c_converted = self.xi_c
729
- xi_c = tfp.distributions.BatchBroadcast(
736
+ xi_c = backend.tfd.BatchBroadcast(
730
737
  xi_c_converted, n_controls, name=constants.XI_C
731
738
  )
732
- gamma_n = tfp.distributions.BatchBroadcast(
739
+ gamma_n = backend.tfd.BatchBroadcast(
733
740
  self.gamma_n, n_non_media_channels, name=constants.GAMMA_N
734
741
  )
735
742
  if is_national:
736
743
  xi_n_converted = _convert_to_deterministic_0_distribution(self.xi_n)
737
744
  else:
738
745
  xi_n_converted = self.xi_n
739
- xi_n = tfp.distributions.BatchBroadcast(
746
+ xi_n = backend.tfd.BatchBroadcast(
740
747
  xi_n_converted, n_non_media_channels, name=constants.XI_N
741
748
  )
742
- alpha_m = tfp.distributions.BatchBroadcast(
749
+ alpha_m = backend.tfd.BatchBroadcast(
743
750
  self.alpha_m, n_media_channels, name=constants.ALPHA_M
744
751
  )
745
- alpha_rf = tfp.distributions.BatchBroadcast(
752
+ alpha_rf = backend.tfd.BatchBroadcast(
746
753
  self.alpha_rf, n_rf_channels, name=constants.ALPHA_RF
747
754
  )
748
- alpha_om = tfp.distributions.BatchBroadcast(
755
+ alpha_om = backend.tfd.BatchBroadcast(
749
756
  self.alpha_om, n_organic_media_channels, name=constants.ALPHA_OM
750
757
  )
751
- alpha_orf = tfp.distributions.BatchBroadcast(
758
+ alpha_orf = backend.tfd.BatchBroadcast(
752
759
  self.alpha_orf, n_organic_rf_channels, name=constants.ALPHA_ORF
753
760
  )
754
- ec_m = tfp.distributions.BatchBroadcast(
761
+ ec_m = backend.tfd.BatchBroadcast(
755
762
  self.ec_m, n_media_channels, name=constants.EC_M
756
763
  )
757
- ec_rf = tfp.distributions.BatchBroadcast(
764
+ ec_rf = backend.tfd.BatchBroadcast(
758
765
  self.ec_rf, n_rf_channels, name=constants.EC_RF
759
766
  )
760
- ec_om = tfp.distributions.BatchBroadcast(
767
+ ec_om = backend.tfd.BatchBroadcast(
761
768
  self.ec_om, n_organic_media_channels, name=constants.EC_OM
762
769
  )
763
- ec_orf = tfp.distributions.BatchBroadcast(
770
+ ec_orf = backend.tfd.BatchBroadcast(
764
771
  self.ec_orf, n_organic_rf_channels, name=constants.EC_ORF
765
772
  )
766
773
  if (
767
- not isinstance(self.slope_m, tfp.distributions.Deterministic)
774
+ not isinstance(self.slope_m, backend.tfd.Deterministic)
768
775
  or (np.isscalar(self.slope_m.loc.numpy()) and self.slope_m.loc != 1.0)
769
776
  or (
770
777
  self.slope_m.batch_shape.as_list()
@@ -776,14 +783,14 @@ class PriorDistribution:
776
783
  ' This may lead to poor MCMC convergence and budget optimization'
777
784
  ' may no longer produce a global optimum.'
778
785
  )
779
- slope_m = tfp.distributions.BatchBroadcast(
786
+ slope_m = backend.tfd.BatchBroadcast(
780
787
  self.slope_m, n_media_channels, name=constants.SLOPE_M
781
788
  )
782
- slope_rf = tfp.distributions.BatchBroadcast(
789
+ slope_rf = backend.tfd.BatchBroadcast(
783
790
  self.slope_rf, n_rf_channels, name=constants.SLOPE_RF
784
791
  )
785
792
  if (
786
- not isinstance(self.slope_om, tfp.distributions.Deterministic)
793
+ not isinstance(self.slope_om, backend.tfd.Deterministic)
787
794
  or (np.isscalar(self.slope_om.loc.numpy()) and self.slope_om.loc != 1.0)
788
795
  or (
789
796
  self.slope_om.batch_shape.as_list()
@@ -795,16 +802,16 @@ class PriorDistribution:
795
802
  ' This may lead to poor MCMC convergence and budget optimization'
796
803
  ' may no longer produce a global optimum.'
797
804
  )
798
- slope_om = tfp.distributions.BatchBroadcast(
805
+ slope_om = backend.tfd.BatchBroadcast(
799
806
  self.slope_om, n_organic_media_channels, name=constants.SLOPE_OM
800
807
  )
801
- slope_orf = tfp.distributions.BatchBroadcast(
808
+ slope_orf = backend.tfd.BatchBroadcast(
802
809
  self.slope_orf, n_organic_rf_channels, name=constants.SLOPE_ORF
803
810
  )
804
811
 
805
812
  # If `unique_sigma_for_each_geo == False`, then make a scalar batch.
806
813
  sigma_shape = n_geos if (n_geos > 1 and unique_sigma_for_each_geo) else []
807
- sigma = tfp.distributions.BatchBroadcast(
814
+ sigma = backend.tfd.BatchBroadcast(
808
815
  self.sigma, sigma_shape, name=constants.SIGMA
809
816
  )
810
817
 
@@ -818,37 +825,37 @@ class PriorDistribution:
818
825
  else:
819
826
  roi_m_converted = self.roi_m
820
827
  roi_rf_converted = self.roi_rf
821
- roi_m = tfp.distributions.BatchBroadcast(
828
+ roi_m = backend.tfd.BatchBroadcast(
822
829
  roi_m_converted, n_media_channels, name=constants.ROI_M
823
830
  )
824
- roi_rf = tfp.distributions.BatchBroadcast(
831
+ roi_rf = backend.tfd.BatchBroadcast(
825
832
  roi_rf_converted, n_rf_channels, name=constants.ROI_RF
826
833
  )
827
834
 
828
- mroi_m = tfp.distributions.BatchBroadcast(
835
+ mroi_m = backend.tfd.BatchBroadcast(
829
836
  self.mroi_m, n_media_channels, name=constants.MROI_M
830
837
  )
831
- mroi_rf = tfp.distributions.BatchBroadcast(
838
+ mroi_rf = backend.tfd.BatchBroadcast(
832
839
  self.mroi_rf, n_rf_channels, name=constants.MROI_RF
833
840
  )
834
841
 
835
- contribution_m = tfp.distributions.BatchBroadcast(
842
+ contribution_m = backend.tfd.BatchBroadcast(
836
843
  self.contribution_m, n_media_channels, name=constants.CONTRIBUTION_M
837
844
  )
838
- contribution_rf = tfp.distributions.BatchBroadcast(
845
+ contribution_rf = backend.tfd.BatchBroadcast(
839
846
  self.contribution_rf, n_rf_channels, name=constants.CONTRIBUTION_RF
840
847
  )
841
- contribution_om = tfp.distributions.BatchBroadcast(
848
+ contribution_om = backend.tfd.BatchBroadcast(
842
849
  self.contribution_om,
843
850
  n_organic_media_channels,
844
851
  name=constants.CONTRIBUTION_OM,
845
852
  )
846
- contribution_orf = tfp.distributions.BatchBroadcast(
853
+ contribution_orf = backend.tfd.BatchBroadcast(
847
854
  self.contribution_orf,
848
855
  n_organic_rf_channels,
849
856
  name=constants.CONTRIBUTION_ORF,
850
857
  )
851
- contribution_n = tfp.distributions.BatchBroadcast(
858
+ contribution_n = backend.tfd.BatchBroadcast(
852
859
  self.contribution_n, n_non_media_channels, name=constants.CONTRIBUTION_N
853
860
  )
854
861
 
@@ -892,9 +899,283 @@ class PriorDistribution:
892
899
  )
893
900
 
894
901
 
902
+ class IndependentMultivariateDistribution(backend.tfd.Distribution):
903
+ """Container for a joint distribution created from independent distributions.
904
+
905
+ This class is useful when one wants to define a joint distribution for a
906
+ Meridian prior, where the elements are not necessarily from the same
907
+ distribution family. For example, to define a distribution where
908
+ one element is Uniform and the second is triangular:
909
+
910
+ ```python
911
+ distributions = [
912
+ tfp.distributions.Uniform(0.0, 1.0),
913
+ tfp.distributions.Triangular(0.0, 1.0, 0.5)
914
+ ]
915
+ distribution = IndependentMultivariateDistribution(distributions)
916
+ ```
917
+
918
+ It is also possible to define a distribution where multiple elements come
919
+ from the same distribution family. For example, to define a distribution where
920
+ the three elements are LogNormal(0.2, 0.9), LogNormal(0, 0.5) and
921
+ Gamma(2, 2):
922
+
923
+ ```python
924
+ distributions = [
925
+ tfp.distributions.LogNormal([0.2, 0.0], [0.9, 0.5]),
926
+ tfp.distributions.Gamma(2.0, 2.0)
927
+ ]
928
+ distribution = IndependentMultivariateDistribution(distributions)
929
+ ```
930
+
931
+ This class cannot contain instances of `tfd.Deterministic`.
932
+ """
933
+
934
+ def __init__(
935
+ self,
936
+ distributions: Sequence[backend.tfd.Distribution],
937
+ validate_args: bool = False,
938
+ allow_nan_stats: bool = True,
939
+ name: str | None = None,
940
+ ):
941
+ """Initializes a batch of independent distributions from different families.
942
+
943
+ Args:
944
+ distributions: List of `tfd.Distribution` from which to construct a
945
+ multivariate distribution. The distributions must have scalar or one
946
+ dimensional batch shapes; the resulting batch shape will be the sum of
947
+ the underlying batch shapes.
948
+ validate_args: Python `bool`. When `True` distribution parameters are
949
+ checked for validity despite possibly degrading runtime performance.
950
+ When `False` invalid inputs may silently render incorrect outputs.
951
+ Default value is `False`.
952
+ allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean, mode,
953
+ variance) use the value "`NaN`" to indicate the result is undefined.
954
+ When `False`, an exception is raised if one or more of the statistic's
955
+ batch members are undefined. Default value is `True`.
956
+ name: Python `str` name prefixed to Ops created by this class. Default
957
+ value is 'IndependentMultivariate' followed by the names of the
958
+ underlying distributions.
959
+
960
+ Raises:
961
+ ValueError: If one or more distributions are instances of
962
+ `tfd.Deterministic` or dtypes differ between the
963
+ distributions.
964
+ """
965
+ parameters = dict(locals())
966
+
967
+ self._verify_distributions(distributions)
968
+
969
+ self._distributions = [
970
+ dist
971
+ if not dist.is_scalar_batch()
972
+ else backend.tfd.BatchBroadcast(dist, (1,))
973
+ for dist in distributions
974
+ ]
975
+
976
+ self._distribution_batch_shapes = self._get_distribution_batch_shapes()
977
+ self._distribution_batch_shape_tensors = backend.concatenate(
978
+ [dist.batch_shape_tensor() for dist in self._distributions],
979
+ axis=0,
980
+ )
981
+
982
+ dtype = self._verify_dtypes()
983
+
984
+ name = name or '-'.join(
985
+ [constants.INDEPENDENT_MULTIVARIATE] + [d.name for d in distributions]
986
+ )
987
+
988
+ super().__init__(
989
+ dtype=dtype,
990
+ reparameterization_type=backend.tfd.NOT_REPARAMETERIZED,
991
+ validate_args=validate_args,
992
+ allow_nan_stats=allow_nan_stats,
993
+ parameters=parameters,
994
+ name=name,
995
+ )
996
+
997
+ def _verify_distributions(
998
+ self, distributions: Sequence[backend.tfd.Distribution]
999
+ ):
1000
+ """Check for deterministic distributions and raise an error if found."""
1001
+
1002
+ if any(
1003
+ isinstance(dist, backend.tfd.Deterministic)
1004
+ for dist in distributions
1005
+ ):
1006
+ raise ValueError(
1007
+ f'{self.__class__.__name__} cannot contain `Deterministic` '
1008
+ 'distributions. To implement a nearly deterministic element of this '
1009
+ 'distribution, we recommend using `backend.tfd.Uniform` with a '
1010
+ 'small range. For example to define a distribution that is nearly '
1011
+ '`Deterministic(1.0)`, use '
1012
+ '`tfp.distribution.Uniform(1.0 - 1e-9, 1.0 + 1e-9)`'
1013
+ )
1014
+
1015
+ def _verify_dtypes(self) -> str:
1016
+ dtypes = [dist.dtype for dist in self._distributions]
1017
+ if len(set(dtypes)) != 1:
1018
+ raise ValueError(
1019
+ f'All distributions must have the same dtype. Found: {dtypes}.'
1020
+ )
1021
+
1022
+ return backend.result_type(*dtypes)
1023
+
1024
+ def _event_shape(self):
1025
+ return backend.TensorShape([])
1026
+
1027
+ def _batch_shape_tensor(self):
1028
+ distribution_batch_shape_tensors = backend.concatenate(
1029
+ [dist.batch_shape_tensor() for dist in self._distributions],
1030
+ axis=0,
1031
+ )
1032
+ return backend.reduce_sum(
1033
+ distribution_batch_shape_tensors, keepdims=True
1034
+ )
1035
+
1036
+ def _batch_shape(self):
1037
+ return backend.TensorShape(sum(self._distribution_batch_shapes))
1038
+
1039
+ def _sample_n(self, n, seed=None):
1040
+ return backend.concatenate(
1041
+ [dist.sample(n, seed) for dist in self._distributions], axis=-1
1042
+ )
1043
+
1044
+ def _quantile(self, value):
1045
+ value = self._broadcast_value(value)
1046
+ split_value = backend.split(
1047
+ value,
1048
+ self._distribution_batch_shapes, axis=-1
1049
+ )
1050
+ quantiles = [
1051
+ dist.quantile(sv) for dist, sv in zip(self._distributions, split_value)
1052
+ ]
1053
+
1054
+ return backend.concatenate(quantiles, axis=-1)
1055
+
1056
+ def _log_prob(self, value):
1057
+ value = self._broadcast_value(value)
1058
+ split_value = backend.split(
1059
+ value,
1060
+ self._distribution_batch_shapes,
1061
+ axis=-1
1062
+ )
1063
+ log_probs = [
1064
+ dist.log_prob(sv) for dist, sv in zip(self._distributions, split_value)
1065
+ ]
1066
+
1067
+ return backend.concatenate(log_probs, axis=-1)
1068
+
1069
+ def _log_cdf(self, value):
1070
+ value = self._broadcast_value(value)
1071
+ split_value = backend.split(
1072
+ value,
1073
+ self._distribution_batch_shapes,
1074
+ axis=-1
1075
+ )
1076
+
1077
+ log_cdfs = [
1078
+ dist.log_cdf(sv) for dist, sv in zip(self._distributions, split_value)
1079
+ ]
1080
+
1081
+ return backend.concatenate(log_cdfs, axis=-1)
1082
+
1083
+ def _mean(self):
1084
+ return backend.concatenate(
1085
+ [dist.mean() for dist in self._distributions], axis=0
1086
+ )
1087
+
1088
+ def _variance(self):
1089
+ return backend.concatenate(
1090
+ [dist.variance() for dist in self._distributions], axis=0
1091
+ )
1092
+
1093
+ def _default_event_space_bijector(self):
1094
+ """Mapping from R^n to the event space of the wrapped distributions.
1095
+
1096
+ This is the blockwise concatenation of the underlying bijectors.
1097
+
1098
+ Returns:
1099
+ A `tfp.bijectors.Blockwise` object that concatenates the underlying
1100
+ bijectors.
1101
+ """
1102
+ bijectors = [
1103
+ d.experimental_default_event_space_bijector()
1104
+ for d in self._distributions
1105
+ ]
1106
+
1107
+ return backend.bijectors.Blockwise(
1108
+ bijectors,
1109
+ block_sizes=self._distribution_batch_shapes,
1110
+ )
1111
+
1112
+ def _broadcast_value(self, value: backend.Tensor) -> backend.Tensor:
1113
+ value = backend.to_tensor(value)
1114
+ broadcast_shape = backend.broadcast_dynamic_shape(
1115
+ value.shape, self.batch_shape_tensor()
1116
+ )
1117
+ return backend.broadcast_to(value, broadcast_shape)
1118
+
1119
+ def _get_distribution_batch_shapes(self) -> Sequence[int]:
1120
+ """Sequence of batch shapes of underlying distributions."""
1121
+
1122
+ batch_shapes = []
1123
+
1124
+ for dist in self._distributions:
1125
+ try:
1126
+ (dist_batch_shape,) = dist.batch_shape
1127
+ except ValueError as exc:
1128
+ raise ValueError(
1129
+ 'All distributions must be 0- or 1-dimensional.'
1130
+ f' Found {len(dist.batch_shape)}-dimensional distribution:'
1131
+ f' {dist.batch_shape}.'
1132
+ ) from exc
1133
+ else:
1134
+ batch_shapes.append(dist_batch_shape)
1135
+
1136
+ return batch_shapes
1137
+
1138
+
1139
+ def distributions_are_equal(
1140
+ a: backend.tfd.Distribution, b: backend.tfd.Distribution
1141
+ ) -> bool:
1142
+ """Determine if two distributions are equal."""
1143
+ if type(a) != type(b): # pylint: disable=unidiomatic-typecheck
1144
+ return False
1145
+
1146
+ a_params = a.parameters.copy()
1147
+ b_params = b.parameters.copy()
1148
+
1149
+ if constants.DISTRIBUTION in a_params and constants.DISTRIBUTION in b_params:
1150
+ if not distributions_are_equal(
1151
+ a_params[constants.DISTRIBUTION], b_params[constants.DISTRIBUTION]
1152
+ ):
1153
+ return False
1154
+ del a_params[constants.DISTRIBUTION]
1155
+ del b_params[constants.DISTRIBUTION]
1156
+
1157
+ if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params:
1158
+ return False
1159
+
1160
+ if a_params.keys() != b_params.keys():
1161
+ return False
1162
+
1163
+ for key in a_params.keys():
1164
+ if isinstance(
1165
+ a_params[key], (backend.Tensor, np.ndarray, float, int)
1166
+ ) and isinstance(b_params[key], (backend.Tensor, np.ndarray, float, int)):
1167
+ if not backend.allclose(a_params[key], b_params[key]):
1168
+ return False
1169
+ else:
1170
+ if a_params[key] != b_params[key]:
1171
+ return False
1172
+
1173
+ return True
1174
+
1175
+
895
1176
  def _convert_to_deterministic_0_distribution(
896
- distribution: tfp.distributions.Distribution,
897
- ) -> tfp.distributions.Distribution:
1177
+ distribution: backend.tfd.Distribution,
1178
+ ) -> backend.tfd.Distribution:
898
1179
  """Converts the given distribution to a `Deterministic(0)` one.
899
1180
 
900
1181
  Args:
@@ -909,7 +1190,7 @@ def _convert_to_deterministic_0_distribution(
909
1190
  distribution.
910
1191
  """
911
1192
  if (
912
- not isinstance(distribution, tfp.distributions.Deterministic)
1193
+ not isinstance(distribution, backend.tfd.Deterministic)
913
1194
  or distribution.loc != 0
914
1195
  ):
915
1196
  warnings.warn(
@@ -917,7 +1198,7 @@ def _convert_to_deterministic_0_distribution(
917
1198
  f' for national models. {distribution.name} has been automatically set'
918
1199
  ' to Deterministic(0).'
919
1200
  )
920
- return tfp.distributions.Deterministic(loc=0, name=distribution.name)
1201
+ return backend.tfd.Deterministic(loc=0, name=distribution.name)
921
1202
  else:
922
1203
  return distribution
923
1204
 
@@ -928,7 +1209,7 @@ def _get_total_media_contribution_prior(
928
1209
  name: str,
929
1210
  p_mean: float = constants.P_MEAN,
930
1211
  p_sd: float = constants.P_SD,
931
- ) -> tfp.distributions.Distribution:
1212
+ ) -> backend.tfd.Distribution:
932
1213
  """Determines ROI priors based on total media contribution.
933
1214
 
934
1215
  Args:
@@ -945,34 +1226,123 @@ def _get_total_media_contribution_prior(
945
1226
  """
946
1227
  roi_mean = p_mean * kpi / np.sum(total_spend)
947
1228
  roi_sd = p_sd * kpi / np.sqrt(np.sum(np.power(total_spend, 2)))
948
- lognormal_sigma = tf.cast(
949
- np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1)), dtype=tf.float32
1229
+ lognormal_sigma = backend.cast(
1230
+ np.sqrt(np.log(roi_sd**2 / roi_mean**2 + 1)), dtype=backend.float32
950
1231
  )
951
- lognormal_mu = tf.cast(
952
- np.log(roi_mean * np.exp(-(lognormal_sigma**2) / 2)), dtype=tf.float32
1232
+ lognormal_mu = backend.cast(
1233
+ np.log(roi_mean * np.exp(-(lognormal_sigma**2) / 2)),
1234
+ dtype=backend.float32,
953
1235
  )
954
- return tfp.distributions.LogNormal(lognormal_mu, lognormal_sigma, name=name)
1236
+ return backend.tfd.LogNormal(lognormal_mu, lognormal_sigma, name=name)
955
1237
 
956
1238
 
957
- def distributions_are_equal(
958
- a: tfp.distributions.Distribution, b: tfp.distributions.Distribution
959
- ) -> bool:
960
- """Determine if two distributions are equal."""
961
- if type(a) != type(b): # pylint: disable=unidiomatic-typecheck
962
- return False
1239
+ def _validate_support(
1240
+ parameter_name: str,
1241
+ tfp_dist: backend.tfp.distributions.Distribution,
1242
+ bounds: tuple[float, float],
1243
+ prevent_deterministic_prior_at_bounds: tuple[bool, bool],
1244
+ ) -> None:
1245
+ """Validates that distribution support is within the parameter bounds.
963
1246
 
964
- a_params = a.parameters.copy()
965
- b_params = b.parameters.copy()
1247
+ Args:
1248
+ parameter_name: Name of the parameter.
1249
+ tfp_dist: The TFP distribution to validate.
1250
+ bounds: Tuple containing the min and max values of the parameter space.
1251
+ prevent_deterministic_prior_at_bounds: Tuple of two booleans indicating
1252
+ whether a deterministic prior is disallowed at the lower and upper
1253
+ bounds, respectively; `True` means a point mass there is rejected.
966
1254
 
967
- if constants.DISTRIBUTION in a_params and constants.DISTRIBUTION in b_params:
968
- if not distributions_are_equal(
969
- a_params[constants.DISTRIBUTION], b_params[constants.DISTRIBUTION]
970
- ):
971
- return False
972
- del a_params[constants.DISTRIBUTION]
973
- del b_params[constants.DISTRIBUTION]
1255
+ Raises:
1256
+ ValueError: If the distribution support is not within the parameter bounds.
1257
+ """
1258
+ # Note that `tfp.distributions.BatchBroadcast` objects have a `distribution`
1259
+ # attribute that points to a `tfp.distributions.Distribution` object.
1260
+ if isinstance(tfp_dist, backend.tfp.distributions.BatchBroadcast):
1261
+ tfp_dist = tfp_dist.distribution
1262
+ # Note that `tfp.distributions.Deterministic` does not have a `quantile`
1263
+ # method implemented, so the min and max values must be extracted from the
1264
+ # `loc` attribute instead.
1265
+ if isinstance(
1266
+ tfp_dist,
1267
+ backend.tfp.python.distributions.deterministic.Deterministic
1268
+ ):
1269
+ support_min_vals = tfp_dist.loc
1270
+ support_max_vals = tfp_dist.loc
1271
+ for i in (0, 1):
1272
+ if (
1273
+ prevent_deterministic_prior_at_bounds[i]
1274
+ and np.any(tfp_dist.loc == bounds[i])
1275
+ ):
1276
+ raise ValueError(
1277
+ f'{parameter_name} was assigned a point mass (deterministic) prior'
1278
+ f' at {bounds[i]}, which is not allowed.'
1279
+ )
1280
+ else:
1281
+ try:
1282
+ support_min_vals = tfp_dist.quantile(0)
1283
+ support_max_vals = tfp_dist.quantile(1)
1284
+ except (AttributeError, NotImplementedError):
1285
+ warnings.warn(
1286
+ f'The prior distribution for {parameter_name} does not have a'
1287
+ f' `quantile` method implemented, so the support range validation'
1288
+ f' was skipped. Confirm that your prior for {parameter_name} is'
1289
+ f' appropriate.'
1290
+ )
1291
+ return
1292
+ if np.any(support_min_vals < bounds[0]):
1293
+ raise ValueError(
1294
+ f'{parameter_name} was assigned a prior distribution that allows values'
1295
+ f' less than the parameter minimum {bounds[0]}.'
1296
+ )
1297
+ if np.any(support_max_vals > bounds[1]):
1298
+ raise ValueError(
1299
+ f'{parameter_name} was assigned a prior distribution that allows values'
1300
+ f' greater than the parameter maximum {bounds[1]}.'
1301
+ )
974
1302
 
975
- if constants.DISTRIBUTION in a_params or constants.DISTRIBUTION in b_params:
976
- return False
1303
+ # Dictionary of parameters that have a limited parameters space. The tuple
1304
+ # contains the lower and upper bounds, respectively.
1305
+ _parameter_space_bounds = {
1306
+ 'eta_m': (0, np.inf),
1307
+ 'eta_rf': (0, np.inf),
1308
+ 'eta_om': (0, np.inf),
1309
+ 'eta_orf': (0, np.inf),
1310
+ 'xi_c': (0, np.inf),
1311
+ 'xi_n': (0, np.inf),
1312
+ 'alpha_m': (0, 1),
1313
+ 'alpha_rf': (0, 1),
1314
+ 'alpha_om': (0, 1),
1315
+ 'alpha_orf': (0, 1),
1316
+ 'ec_m': (0, np.inf),
1317
+ 'ec_rf': (0, np.inf),
1318
+ 'ec_om': (0, np.inf),
1319
+ 'ec_orf': (0, np.inf),
1320
+ 'slope_m': (0, np.inf),
1321
+ 'slope_rf': (0, np.inf),
1322
+ 'slope_om': (0, np.inf),
1323
+ 'slope_orf': (0, np.inf),
1324
+ 'sigma': (0, np.inf),
1325
+ }
977
1326
 
978
- return a_params == b_params
1327
+ # Dictionary of parameters that do not allow a deterministic prior at one or
1328
+ # more of the parameter space bounds. The boolean tuple indicates whether a
1329
+ # deterministic prior is allowed at the lower bound or upper bound,
1330
+ # respectively, where `True` means "not allowed". This check is specifically for
1331
+ # point mass at finite paramteter space bounds, since point mass at infinity is
1332
+ # generally problematic for all parameters. Note that `sigma` should generally
1333
+ # not have point mass at zero, but this is not checked here because unit tests
1334
+ # require the ability to simulate data with `sigma` set to zero.
1335
+ _prevent_deterministic_prior_at_bounds = {
1336
+ 'alpha_m': (False, True),
1337
+ 'alpha_rf': (False, True),
1338
+ 'alpha_om': (False, True),
1339
+ 'alpha_orf': (False, True),
1340
+ 'ec_m': (True, False),
1341
+ 'ec_rf': (True, False),
1342
+ 'ec_om': (True, False),
1343
+ 'ec_orf': (True, False),
1344
+ 'slope_m': (True, False),
1345
+ 'slope_rf': (True, False),
1346
+ 'slope_om': (True, False),
1347
+ 'slope_orf': (True, False),
1348
+ }