google-meridian 1.1.3__py3-none-any.whl → 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.5.dist-info}/METADATA +2 -2
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.5.dist-info}/RECORD +15 -15
- meridian/analysis/analyzer.py +18 -11
- meridian/analysis/optimizer.py +292 -47
- meridian/constants.py +6 -4
- meridian/data/data_frame_input_data_builder.py +222 -61
- meridian/data/input_data_builder.py +3 -1
- meridian/data/load.py +210 -350
- meridian/model/model.py +3 -10
- meridian/model/prior_distribution.py +7 -4
- meridian/model/prior_sampler.py +2 -0
- meridian/version.py +1 -1
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.5.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.5.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.5.dist-info}/top_level.txt +0 -0
meridian/model/model.py
CHANGED
|
@@ -295,10 +295,6 @@ class Meridian:
|
|
|
295
295
|
def is_national(self) -> bool:
|
|
296
296
|
return self.n_geos == 1
|
|
297
297
|
|
|
298
|
-
@property
|
|
299
|
-
def _sigma_shape(self) -> int:
|
|
300
|
-
return len(self.input_data.geo) if self.unique_sigma_for_each_geo else 1
|
|
301
|
-
|
|
302
298
|
@functools.cached_property
|
|
303
299
|
def knot_info(self) -> knots.KnotInfo:
|
|
304
300
|
return knots.get_knot_info(
|
|
@@ -389,6 +385,7 @@ class Meridian:
|
|
|
389
385
|
@functools.cached_property
|
|
390
386
|
def unique_sigma_for_each_geo(self) -> bool:
|
|
391
387
|
if self.is_national:
|
|
388
|
+
# Should evaluate to False.
|
|
392
389
|
return constants.NATIONAL_MODEL_SPEC_ARGS[
|
|
393
390
|
constants.UNIQUE_SIGMA_FOR_EACH_GEO
|
|
394
391
|
]
|
|
@@ -449,7 +446,7 @@ class Meridian:
|
|
|
449
446
|
n_organic_rf_channels=self.n_organic_rf_channels,
|
|
450
447
|
n_controls=self.n_controls,
|
|
451
448
|
n_non_media_channels=self.n_non_media_channels,
|
|
452
|
-
|
|
449
|
+
unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
|
|
453
450
|
n_knots=self.knot_info.n_knots,
|
|
454
451
|
is_national=self.is_national,
|
|
455
452
|
set_total_media_contribution_prior=self._set_total_media_contribution_prior,
|
|
@@ -663,10 +660,6 @@ class Meridian:
|
|
|
663
660
|
self._validate_injected_inference_data_group_coord(
|
|
664
661
|
inference_data, group, constants.TIME, self.n_times
|
|
665
662
|
)
|
|
666
|
-
if not self.model_spec.unique_sigma_for_each_geo:
|
|
667
|
-
self._validate_injected_inference_data_group_coord(
|
|
668
|
-
inference_data, group, constants.SIGMA_DIM, self._sigma_shape
|
|
669
|
-
)
|
|
670
663
|
self._validate_injected_inference_data_group_coord(
|
|
671
664
|
inference_data,
|
|
672
665
|
group,
|
|
@@ -1429,7 +1422,7 @@ class Meridian:
|
|
|
1429
1422
|
if self.unique_sigma_for_each_geo:
|
|
1430
1423
|
inference_dims[constants.SIGMA] = [constants.GEO]
|
|
1431
1424
|
else:
|
|
1432
|
-
inference_dims[constants.SIGMA] = [
|
|
1425
|
+
inference_dims[constants.SIGMA] = []
|
|
1433
1426
|
|
|
1434
1427
|
return {
|
|
1435
1428
|
param: [constants.CHAIN, constants.DRAW] + list(dims)
|
|
@@ -509,7 +509,7 @@ class PriorDistribution:
|
|
|
509
509
|
n_organic_rf_channels: int,
|
|
510
510
|
n_controls: int,
|
|
511
511
|
n_non_media_channels: int,
|
|
512
|
-
|
|
512
|
+
unique_sigma_for_each_geo: bool,
|
|
513
513
|
n_knots: int,
|
|
514
514
|
is_national: bool,
|
|
515
515
|
set_total_media_contribution_prior: bool,
|
|
@@ -527,9 +527,9 @@ class PriorDistribution:
|
|
|
527
527
|
used.
|
|
528
528
|
n_controls: Number of controls used.
|
|
529
529
|
n_non_media_channels: Number of non-media channels used.
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
530
|
+
unique_sigma_for_each_geo: A boolean indicator whether to use a unique
|
|
531
|
+
sigma parameter for each geo. Only used if `n_geos > 1`. For more
|
|
532
|
+
information, see `ModelSpec`.
|
|
533
533
|
n_knots: Number of knots used.
|
|
534
534
|
is_national: A boolean indicator whether the prior distribution will be
|
|
535
535
|
adapted for a national model.
|
|
@@ -801,6 +801,9 @@ class PriorDistribution:
|
|
|
801
801
|
slope_orf = tfp.distributions.BatchBroadcast(
|
|
802
802
|
self.slope_orf, n_organic_rf_channels, name=constants.SLOPE_ORF
|
|
803
803
|
)
|
|
804
|
+
|
|
805
|
+
# If `unique_sigma_for_each_geo == False`, then make a scalar batch.
|
|
806
|
+
sigma_shape = n_geos if (n_geos > 1 and unique_sigma_for_each_geo) else []
|
|
804
807
|
sigma = tfp.distributions.BatchBroadcast(
|
|
805
808
|
self.sigma, sigma_shape, name=constants.SIGMA
|
|
806
809
|
)
|
meridian/model/prior_sampler.py
CHANGED
|
@@ -510,6 +510,8 @@ class PriorDistributionSampler:
|
|
|
510
510
|
tf.keras.utils.set_random_seed(1)
|
|
511
511
|
|
|
512
512
|
prior = mmm.prior_broadcast
|
|
513
|
+
# `sample_shape` is prepended to the shape of each BatchBroadcast in `prior`
|
|
514
|
+
# when it is sampled.
|
|
513
515
|
sample_shape = [1, n_draws]
|
|
514
516
|
sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
|
|
515
517
|
|
meridian/version.py
CHANGED
|
File without changes
|
|
File without changes
|
|
File without changes
|