google-meridian 1.1.3__py3-none-any.whl → 1.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/model/model.py CHANGED
@@ -295,10 +295,6 @@ class Meridian:
295
295
  def is_national(self) -> bool:
296
296
  return self.n_geos == 1
297
297
 
298
- @property
299
- def _sigma_shape(self) -> int:
300
- return len(self.input_data.geo) if self.unique_sigma_for_each_geo else 1
301
-
302
298
  @functools.cached_property
303
299
  def knot_info(self) -> knots.KnotInfo:
304
300
  return knots.get_knot_info(
@@ -389,6 +385,7 @@ class Meridian:
389
385
  @functools.cached_property
390
386
  def unique_sigma_for_each_geo(self) -> bool:
391
387
  if self.is_national:
388
+ # Should evaluate to False.
392
389
  return constants.NATIONAL_MODEL_SPEC_ARGS[
393
390
  constants.UNIQUE_SIGMA_FOR_EACH_GEO
394
391
  ]
@@ -449,7 +446,7 @@ class Meridian:
449
446
  n_organic_rf_channels=self.n_organic_rf_channels,
450
447
  n_controls=self.n_controls,
451
448
  n_non_media_channels=self.n_non_media_channels,
452
- sigma_shape=self._sigma_shape,
449
+ unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
453
450
  n_knots=self.knot_info.n_knots,
454
451
  is_national=self.is_national,
455
452
  set_total_media_contribution_prior=self._set_total_media_contribution_prior,
@@ -663,10 +660,6 @@ class Meridian:
663
660
  self._validate_injected_inference_data_group_coord(
664
661
  inference_data, group, constants.TIME, self.n_times
665
662
  )
666
- if not self.model_spec.unique_sigma_for_each_geo:
667
- self._validate_injected_inference_data_group_coord(
668
- inference_data, group, constants.SIGMA_DIM, self._sigma_shape
669
- )
670
663
  self._validate_injected_inference_data_group_coord(
671
664
  inference_data,
672
665
  group,
@@ -1429,7 +1422,7 @@ class Meridian:
1429
1422
  if self.unique_sigma_for_each_geo:
1430
1423
  inference_dims[constants.SIGMA] = [constants.GEO]
1431
1424
  else:
1432
- inference_dims[constants.SIGMA] = [constants.SIGMA_DIM]
1425
+ inference_dims[constants.SIGMA] = []
1433
1426
 
1434
1427
  return {
1435
1428
  param: [constants.CHAIN, constants.DRAW] + list(dims)
@@ -509,7 +509,7 @@ class PriorDistribution:
509
509
  n_organic_rf_channels: int,
510
510
  n_controls: int,
511
511
  n_non_media_channels: int,
512
- sigma_shape: int,
512
+ unique_sigma_for_each_geo: bool,
513
513
  n_knots: int,
514
514
  is_national: bool,
515
515
  set_total_media_contribution_prior: bool,
@@ -527,9 +527,9 @@ class PriorDistribution:
527
527
  used.
528
528
  n_controls: Number of controls used.
529
529
  n_non_media_channels: Number of non-media channels used.
530
- sigma_shape: A number describing the shape of the sigma parameter. It's
531
- either `1` (if `sigma_for_each_geo=False`) or `n_geos` (if
532
- `sigma_for_each_geo=True`). For more information, see `ModelSpec`.
530
+ unique_sigma_for_each_geo: A boolean indicator whether to use a unique
531
+ sigma parameter for each geo. Only used if `n_geos > 1`. For more
532
+ information, see `ModelSpec`.
533
533
  n_knots: Number of knots used.
534
534
  is_national: A boolean indicator whether the prior distribution will be
535
535
  adapted for a national model.
@@ -801,6 +801,9 @@ class PriorDistribution:
801
801
  slope_orf = tfp.distributions.BatchBroadcast(
802
802
  self.slope_orf, n_organic_rf_channels, name=constants.SLOPE_ORF
803
803
  )
804
+
805
+ # Make a scalar batch unless `n_geos > 1` and `unique_sigma_for_each_geo`.
806
+ sigma_shape = n_geos if (n_geos > 1 and unique_sigma_for_each_geo) else []
804
807
  sigma = tfp.distributions.BatchBroadcast(
805
808
  self.sigma, sigma_shape, name=constants.SIGMA
806
809
  )
@@ -510,6 +510,8 @@ class PriorDistributionSampler:
510
510
  tf.keras.utils.set_random_seed(1)
511
511
 
512
512
  prior = mmm.prior_broadcast
513
+ # `sample_shape` is prepended to the shape of each BatchBroadcast in `prior`
514
+ # when it is sampled.
513
515
  sample_shape = [1, n_draws]
514
516
  sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
515
517
 
meridian/version.py CHANGED
@@ -14,4 +14,4 @@
14
14
 
15
15
  """Module for Meridian version."""
16
16
 
17
- __version__ = "1.1.3"
17
+ __version__ = "1.1.5"