google-meridian 1.0.9__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -110,16 +110,11 @@ class PosteriorMCMCSampler:
110
110
  organic_media_tensors = mmm.organic_media_tensors
111
111
  organic_rf_tensors = mmm.organic_rf_tensors
112
112
  controls_scaled = mmm.controls_scaled
113
- non_media_treatments_scaled = mmm.non_media_treatments_scaled
113
+ non_media_treatments_normalized = mmm.non_media_treatments_normalized
114
114
  media_effects_dist = mmm.media_effects_dist
115
115
  adstock_hill_media_fn = mmm.adstock_hill_media
116
116
  adstock_hill_rf_fn = mmm.adstock_hill_rf
117
- get_roi_prior_beta_m_value_fn = (
118
- mmm.prior_sampler_callable.get_roi_prior_beta_m_value
119
- )
120
- get_roi_prior_beta_rf_value_fn = (
121
- mmm.prior_sampler_callable.get_roi_prior_beta_rf_value
122
- )
117
+ total_outcome = mmm.total_outcome
123
118
 
124
119
  @tfp.distributions.JointDistributionCoroutineAutoBatched
125
120
  def joint_dist_unpinned():
@@ -167,26 +162,39 @@ class PosteriorMCMCSampler:
167
162
  ec=ec_m,
168
163
  slope=slope_m,
169
164
  )
170
- prior_type = mmm.model_spec.paid_media_prior_type
171
- if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
172
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
173
- roi_or_mroi_m = yield prior_broadcast.roi_m
165
+ prior_type = mmm.model_spec.effective_media_prior_type
166
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
167
+ beta_m = yield prior_broadcast.beta_m
168
+ else:
169
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
170
+ treatment_parameter_m = yield prior_broadcast.roi_m
171
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
172
+ treatment_parameter_m = yield prior_broadcast.mroi_m
173
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
174
+ treatment_parameter_m = yield prior_broadcast.contribution_m
174
175
  else:
175
- roi_or_mroi_m = yield prior_broadcast.mroi_m
176
- beta_m_value = get_roi_prior_beta_m_value_fn(
177
- alpha_m,
178
- beta_gm_dev,
179
- ec_m,
180
- eta_m,
181
- roi_or_mroi_m,
182
- slope_m,
183
- media_transformed,
176
+ raise ValueError(f"Unsupported prior type: {prior_type}")
177
+ incremental_outcome_m = (
178
+ treatment_parameter_m * media_tensors.prior_denominator
179
+ )
180
+ linear_predictor_counterfactual_difference = (
181
+ mmm.linear_predictor_counterfactual_difference_media(
182
+ media_transformed=media_transformed,
183
+ alpha_m=alpha_m,
184
+ ec_m=ec_m,
185
+ slope_m=slope_m,
186
+ )
187
+ )
188
+ beta_m_value = mmm.calculate_beta_x(
189
+ is_non_media=False,
190
+ incremental_outcome_x=incremental_outcome_m,
191
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
192
+ eta_x=eta_m,
193
+ beta_gx_dev=beta_gm_dev,
184
194
  )
185
195
  beta_m = yield tfp.distributions.Deterministic(
186
196
  beta_m_value, name=constants.BETA_M
187
197
  )
188
- else:
189
- beta_m = yield prior_broadcast.beta_m
190
198
 
191
199
  beta_eta_combined = beta_m + eta_m * beta_gm_dev
192
200
  beta_gm_value = (
@@ -220,27 +228,39 @@ class PosteriorMCMCSampler:
220
228
  slope=slope_rf,
221
229
  )
222
230
 
223
- prior_type = mmm.model_spec.paid_media_prior_type
224
- if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
225
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
226
- roi_or_mroi_rf = yield prior_broadcast.roi_rf
231
+ prior_type = mmm.model_spec.effective_rf_prior_type
232
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
233
+ beta_rf = yield prior_broadcast.beta_rf
234
+ else:
235
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
236
+ treatment_parameter_rf = yield prior_broadcast.roi_rf
237
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
238
+ treatment_parameter_rf = yield prior_broadcast.mroi_rf
239
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
240
+ treatment_parameter_rf = yield prior_broadcast.contribution_rf
227
241
  else:
228
- roi_or_mroi_rf = yield prior_broadcast.mroi_rf
229
- beta_rf_value = get_roi_prior_beta_rf_value_fn(
230
- alpha_rf,
231
- beta_grf_dev,
232
- ec_rf,
233
- eta_rf,
234
- roi_or_mroi_rf,
235
- slope_rf,
236
- rf_transformed,
242
+ raise ValueError(f"Unsupported prior type: {prior_type}")
243
+ incremental_outcome_rf = (
244
+ treatment_parameter_rf * rf_tensors.prior_denominator
245
+ )
246
+ linear_predictor_counterfactual_difference = (
247
+ mmm.linear_predictor_counterfactual_difference_rf(
248
+ rf_transformed=rf_transformed,
249
+ alpha_rf=alpha_rf,
250
+ ec_rf=ec_rf,
251
+ slope_rf=slope_rf,
252
+ )
253
+ )
254
+ beta_rf_value = mmm.calculate_beta_x(
255
+ is_non_media=False,
256
+ incremental_outcome_x=incremental_outcome_rf,
257
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
258
+ eta_x=eta_rf,
259
+ beta_gx_dev=beta_grf_dev,
237
260
  )
238
261
  beta_rf = yield tfp.distributions.Deterministic(
239
- beta_rf_value,
240
- name=constants.BETA_RF,
262
+ beta_rf_value, name=constants.BETA_RF
241
263
  )
242
- else:
243
- beta_rf = yield prior_broadcast.beta_rf
244
264
 
245
265
  beta_eta_combined = beta_rf + eta_rf * beta_grf_dev
246
266
  beta_grf_value = (
@@ -272,7 +292,24 @@ class PosteriorMCMCSampler:
272
292
  ec=ec_om,
273
293
  slope=slope_om,
274
294
  )
275
- beta_om = yield prior_broadcast.beta_om
295
+ prior_type = mmm.model_spec.organic_media_prior_type
296
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
297
+ beta_om = yield prior_broadcast.beta_om
298
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
299
+ contribution_om = yield prior_broadcast.contribution_om
300
+ incremental_outcome_om = contribution_om * total_outcome
301
+ beta_om_value = mmm.calculate_beta_x(
302
+ is_non_media=False,
303
+ incremental_outcome_x=incremental_outcome_om,
304
+ linear_predictor_counterfactual_difference=organic_media_transformed,
305
+ eta_x=eta_om,
306
+ beta_gx_dev=beta_gom_dev,
307
+ )
308
+ beta_om = yield tfp.distributions.Deterministic(
309
+ beta_om_value, name=constants.BETA_OM
310
+ )
311
+ else:
312
+ raise ValueError(f"Unsupported prior type: {prior_type}")
276
313
 
277
314
  beta_eta_combined = beta_om + eta_om * beta_gom_dev
278
315
  beta_gom_value = (
@@ -306,7 +343,24 @@ class PosteriorMCMCSampler:
306
343
  slope=slope_orf,
307
344
  )
308
345
 
309
- beta_orf = yield prior_broadcast.beta_orf
346
+ prior_type = mmm.model_spec.organic_rf_prior_type
347
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
348
+ beta_orf = yield prior_broadcast.beta_orf
349
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
350
+ contribution_orf = yield prior_broadcast.contribution_orf
351
+ incremental_outcome_orf = contribution_orf * total_outcome
352
+ beta_orf_value = mmm.calculate_beta_x(
353
+ is_non_media=False,
354
+ incremental_outcome_x=incremental_outcome_orf,
355
+ linear_predictor_counterfactual_difference=organic_rf_transformed,
356
+ eta_x=eta_orf,
357
+ beta_gx_dev=beta_gorf_dev,
358
+ )
359
+ beta_orf = yield tfp.distributions.Deterministic(
360
+ beta_orf_value, name=constants.BETA_ORF
361
+ )
362
+ else:
363
+ raise ValueError(f"Unsupported prior type: {prior_type}")
310
364
 
311
365
  beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev
312
366
  beta_gorf_value = (
@@ -338,18 +392,41 @@ class PosteriorMCMCSampler:
338
392
  )
339
393
 
340
394
  if mmm.non_media_treatments is not None:
341
- gamma_n = yield prior_broadcast.gamma_n
342
395
  xi_n = yield prior_broadcast.xi_n
343
396
  gamma_gn_dev = yield tfp.distributions.Sample(
344
397
  tfp.distributions.Normal(0, 1),
345
398
  [n_geos, n_non_media_channels],
346
399
  name=constants.GAMMA_GN_DEV,
347
400
  )
401
+ prior_type = mmm.model_spec.non_media_treatments_prior_type
402
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
403
+ gamma_n = yield prior_broadcast.gamma_n
404
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
405
+ contribution_n = yield prior_broadcast.contribution_n
406
+ incremental_outcome_n = contribution_n * total_outcome
407
+ baseline_scaled = mmm.non_media_transformer.forward( # pytype: disable=attribute-error
408
+ mmm.compute_non_media_treatments_baseline()
409
+ )
410
+ linear_predictor_counterfactual_difference = (
411
+ non_media_treatments_normalized - baseline_scaled
412
+ )
413
+ gamma_n_value = mmm.calculate_beta_x(
414
+ is_non_media=True,
415
+ incremental_outcome_x=incremental_outcome_n,
416
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
417
+ eta_x=xi_n,
418
+ beta_gx_dev=gamma_gn_dev,
419
+ )
420
+ gamma_n = yield tfp.distributions.Deterministic(
421
+ gamma_n_value, name=constants.GAMMA_N
422
+ )
423
+ else:
424
+ raise ValueError(f"Unsupported prior type: {prior_type}")
348
425
  gamma_gn = yield tfp.distributions.Deterministic(
349
426
  gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
350
427
  )
351
428
  y_pred = y_pred_combined_media + tf.einsum(
352
- "gtn,gn->gt", non_media_treatments_scaled, gamma_gn
429
+ "gtn,gn->gt", non_media_treatments_normalized, gamma_gn
353
430
  )
354
431
  else:
355
432
  y_pred = y_pred_combined_media
@@ -84,6 +84,11 @@ class PriorDistribution:
84
84
  | `roi_rf` | `n_rf_channels` |
85
85
  | `mroi_m` | `n_media_channels` |
86
86
  | `mroi_rf` | `n_rf_channels` |
87
+ | `contribution_m` | `n_media_channels` |
88
+ | `contribution_rf` | `n_rf_channels` |
89
+ | `contribution_om` | `n_organic_media_channels` |
90
+ | `contribution_orf` | `n_organic_rf_channels` |
91
+ | `contribution_n` | `n_non_media_channels` |
87
92
 
88
93
  (σ) `n_geos` if `unique_sigma_for_each_geo`, otherwise this is `1`
89
94
 
@@ -233,6 +238,36 @@ class PriorDistribution:
233
238
  is interpreted as the marginal incremental KPI units per monetary unit
234
239
  spent. In this case, a default distribution is not provided, so the user
235
240
  must specify it.
241
+ contribution_m: Prior distribution on the contribution of each media channel
242
+ as a percentage of total outcome. This parameter is only used when
243
+ `paid_media_prior_type` is `'contribution'`, in which case `beta_m` is
244
+ calculated as a deterministic function of `contribution_m`, `alpha_m`,
245
+ `ec_m`, `slope_m`, and the total outcome. Default distribution is
246
+ `Beta(1.0, 99.0)`.
247
+ contribution_rf: Prior distribution on the contribution of each Reach &
248
+ Frequency channel as a percentage of total outcome. This parameter is only
249
+ used when `paid_media_prior_type` is `'contribution'`, in which case
250
+ `beta_rf` is calculated as a deterministic function of `contribution_rf`,
251
+ `alpha_rf`, `ec_rf`, `slope_rf`, and the total outcome. Default
252
+ distribution is `Beta(1.0, 99.0)`.
253
+ contribution_om: Prior distribution on the contribution of each organic
254
+ media channel as a percentage of total outcome. This parameter is only
255
+ used when `organic_media_prior_type` is `'contribution'`, in which case
256
+ `beta_om` is calculated as a deterministic function of `contribution_om`,
257
+ `alpha_om`, `ec_om`, `slope_om`, and the total outcome. Default
258
+ distribution is `Beta(1.0, 99.0)`.
259
+ contribution_orf: Prior distribution on the contribution of each organic
260
+ Reach & Frequency channel as a percentage of total outcome. This parameter
261
+ is only used when `organic_rf_prior_type` is `'contribution'`, in which
262
+ case `beta_orf` is calculated as a deterministic function of
263
+ `contribution_orf`, `alpha_orf`, `ec_orf`, `slope_orf`, and the total
264
+ outcome. Default distribution is `Beta(1.0, 99.0)`.
265
+ contribution_n: Prior distribution on the contribution of each non-media
266
+ treatment channel as a percentage of total outcome. This parameter is only
267
+ used when `non_media_treatments_prior_type` is `'contribution'`, in which
268
+ case `gamma_n` is calculated as a deterministic function of
269
+ `contribution_n` and the total outcome. Default distribution is
270
+ `TruncatedNormal(0.0, 0.1, -1.0, 1.0)`.
236
271
  """
237
272
 
238
273
  knot_values: tfp.distributions.Distribution = dataclasses.field(
@@ -394,6 +429,31 @@ class PriorDistribution:
394
429
  0.0, 0.5, name=constants.MROI_RF
395
430
  ),
396
431
  )
432
+ contribution_m: tfp.distributions.Distribution = dataclasses.field(
433
+ default_factory=lambda: tfp.distributions.Beta(
434
+ 1.0, 99.0, name=constants.CONTRIBUTION_M
435
+ ),
436
+ )
437
+ contribution_rf: tfp.distributions.Distribution = dataclasses.field(
438
+ default_factory=lambda: tfp.distributions.Beta(
439
+ 1.0, 99.0, name=constants.CONTRIBUTION_RF
440
+ ),
441
+ )
442
+ contribution_om: tfp.distributions.Distribution = dataclasses.field(
443
+ default_factory=lambda: tfp.distributions.Beta(
444
+ 1.0, 99.0, name=constants.CONTRIBUTION_OM
445
+ ),
446
+ )
447
+ contribution_orf: tfp.distributions.Distribution = dataclasses.field(
448
+ default_factory=lambda: tfp.distributions.Beta(
449
+ 1.0, 99.0, name=constants.CONTRIBUTION_ORF
450
+ ),
451
+ )
452
+ contribution_n: tfp.distributions.Distribution = dataclasses.field(
453
+ default_factory=lambda: tfp.distributions.TruncatedNormal(
454
+ loc=0.0, scale=0.1, low=-1.0, high=1.0, name=constants.CONTRIBUTION_N
455
+ ),
456
+ )
397
457
 
398
458
  def __setstate__(self, state):
399
459
  # Override to support pickling.
@@ -455,7 +515,6 @@ class PriorDistribution:
455
515
  set_total_media_contribution_prior: bool,
456
516
  kpi: float,
457
517
  total_spend: np.ndarray,
458
- media_effects_dist: str,
459
518
  ) -> PriorDistribution:
460
519
  """Returns a new `PriorDistribution` with broadcast distribution attributes.
461
520
 
@@ -481,8 +540,6 @@ class PriorDistribution:
481
540
  `set_total_media_contribution_prior=True`.
482
541
  total_spend: Spend per media channel summed across geos and time. Required
483
542
  if `set_total_media_contribution_prior=True`.
484
- media_effects_dist: A string to specify the distribution of media random
485
- effects across geos.
486
543
 
487
544
  Returns:
488
545
  A new `PriorDistribution` broadcast from this prior distribution,
@@ -508,6 +565,7 @@ class PriorDistribution:
508
565
 
509
566
  _validate_media_custom_priors(self.roi_m)
510
567
  _validate_media_custom_priors(self.mroi_m)
568
+ _validate_media_custom_priors(self.contribution_m)
511
569
  _validate_media_custom_priors(self.alpha_m)
512
570
  _validate_media_custom_priors(self.ec_m)
513
571
  _validate_media_custom_priors(self.slope_m)
@@ -529,6 +587,7 @@ class PriorDistribution:
529
587
  'that channel.'
530
588
  )
531
589
 
590
+ _validate_organic_media_custom_priors(self.contribution_om)
532
591
  _validate_organic_media_custom_priors(self.alpha_om)
533
592
  _validate_organic_media_custom_priors(self.ec_om)
534
593
  _validate_organic_media_custom_priors(self.slope_om)
@@ -550,6 +609,7 @@ class PriorDistribution:
550
609
  'for that channel.'
551
610
  )
552
611
 
612
+ _validate_organic_rf_custom_priors(self.contribution_orf)
553
613
  _validate_organic_rf_custom_priors(self.alpha_orf)
554
614
  _validate_organic_rf_custom_priors(self.ec_orf)
555
615
  _validate_organic_rf_custom_priors(self.slope_orf)
@@ -569,6 +629,7 @@ class PriorDistribution:
569
629
 
570
630
  _validate_rf_custom_priors(self.roi_rf)
571
631
  _validate_rf_custom_priors(self.mroi_rf)
632
+ _validate_rf_custom_priors(self.contribution_rf)
572
633
  _validate_rf_custom_priors(self.alpha_rf)
573
634
  _validate_rf_custom_priors(self.ec_rf)
574
635
  _validate_rf_custom_priors(self.slope_rf)
@@ -604,6 +665,7 @@ class PriorDistribution:
604
665
  'that channel.'
605
666
  )
606
667
 
668
+ _validate_non_media_custom_priors(self.contribution_n)
607
669
  _validate_non_media_custom_priors(self.gamma_n)
608
670
  _validate_non_media_custom_priors(self.xi_n)
609
671
 
@@ -743,57 +805,50 @@ class PriorDistribution:
743
805
  self.sigma, sigma_shape, name=constants.SIGMA
744
806
  )
745
807
 
746
- default_distribution = PriorDistribution()
747
- if set_total_media_contribution_prior and distributions_are_equal(
748
- self.roi_m, default_distribution.roi_m
749
- ):
750
- warnings.warn(
751
- 'Consider setting custom ROI priors, as kpi_type was specified as'
752
- ' `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the'
753
- ' total media contribution prior will be used with'
754
- f' `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further'
755
- ' documentation available at '
756
- ' https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior',
757
- )
808
+ if set_total_media_contribution_prior:
758
809
  roi_m_converted = _get_total_media_contribution_prior(
759
810
  kpi, total_spend, constants.ROI_M
760
811
  )
761
- else:
762
- roi_m_converted = self.roi_m
763
- _check_for_negative_effect(roi_m_converted, media_effects_dist)
764
- roi_m = tfp.distributions.BatchBroadcast(
765
- roi_m_converted, n_media_channels, name=constants.ROI_M
766
- )
767
-
768
- if set_total_media_contribution_prior and distributions_are_equal(
769
- self.roi_rf, default_distribution.roi_rf
770
- ):
771
- warnings.warn(
772
- 'Consider setting custom ROI priors, as kpi_type was specified as'
773
- ' `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the'
774
- ' total media contribution prior will be used with'
775
- f' `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further'
776
- ' documentation available at '
777
- ' https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior',
778
- )
779
812
  roi_rf_converted = _get_total_media_contribution_prior(
780
813
  kpi, total_spend, constants.ROI_RF
781
814
  )
782
815
  else:
816
+ roi_m_converted = self.roi_m
783
817
  roi_rf_converted = self.roi_rf
784
- _check_for_negative_effect(roi_rf_converted, media_effects_dist)
818
+ roi_m = tfp.distributions.BatchBroadcast(
819
+ roi_m_converted, n_media_channels, name=constants.ROI_M
820
+ )
785
821
  roi_rf = tfp.distributions.BatchBroadcast(
786
822
  roi_rf_converted, n_rf_channels, name=constants.ROI_RF
787
823
  )
788
- _check_for_negative_effect(self.mroi_m, media_effects_dist)
824
+
789
825
  mroi_m = tfp.distributions.BatchBroadcast(
790
826
  self.mroi_m, n_media_channels, name=constants.MROI_M
791
827
  )
792
- _check_for_negative_effect(self.mroi_rf, media_effects_dist)
793
828
  mroi_rf = tfp.distributions.BatchBroadcast(
794
829
  self.mroi_rf, n_rf_channels, name=constants.MROI_RF
795
830
  )
796
831
 
832
+ contribution_m = tfp.distributions.BatchBroadcast(
833
+ self.contribution_m, n_media_channels, name=constants.CONTRIBUTION_M
834
+ )
835
+ contribution_rf = tfp.distributions.BatchBroadcast(
836
+ self.contribution_rf, n_rf_channels, name=constants.CONTRIBUTION_RF
837
+ )
838
+ contribution_om = tfp.distributions.BatchBroadcast(
839
+ self.contribution_om,
840
+ n_organic_media_channels,
841
+ name=constants.CONTRIBUTION_OM,
842
+ )
843
+ contribution_orf = tfp.distributions.BatchBroadcast(
844
+ self.contribution_orf,
845
+ n_organic_rf_channels,
846
+ name=constants.CONTRIBUTION_ORF,
847
+ )
848
+ contribution_n = tfp.distributions.BatchBroadcast(
849
+ self.contribution_n, n_non_media_channels, name=constants.CONTRIBUTION_N
850
+ )
851
+
797
852
  return PriorDistribution(
798
853
  knot_values=knot_values,
799
854
  tau_g_excl_baseline=tau_g_excl_baseline,
@@ -826,6 +881,11 @@ class PriorDistribution:
826
881
  roi_rf=roi_rf,
827
882
  mroi_m=mroi_m,
828
883
  mroi_rf=mroi_rf,
884
+ contribution_m=contribution_m,
885
+ contribution_rf=contribution_rf,
886
+ contribution_om=contribution_om,
887
+ contribution_orf=contribution_orf,
888
+ contribution_n=contribution_n,
829
889
  )
830
890
 
831
891
 
@@ -891,21 +951,6 @@ def _get_total_media_contribution_prior(
891
951
  return tfp.distributions.LogNormal(lognormal_mu, lognormal_sigma, name=name)
892
952
 
893
953
 
894
- def _check_for_negative_effect(
895
- dist: tfp.distributions.Distribution, media_effects_dist: str
896
- ):
897
- """Checks for negative effect in the model."""
898
- if (
899
- media_effects_dist == constants.MEDIA_EFFECTS_LOG_NORMAL
900
- and np.any(dist.cdf(0)) > 0
901
- ):
902
- raise ValueError(
903
- 'Media priors must have non-negative support when'
904
- f' `media_effects_dist`="{media_effects_dist}". Found negative effect'
905
- f' in {dist.name}.'
906
- )
907
-
908
-
909
954
  def distributions_are_equal(
910
955
  a: tfp.distributions.Distribution, b: tfp.distributions.Distribution
911
956
  ) -> bool: