google-meridian: google_meridian-1.2.0-py3-none-any.whl → google_meridian-1.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
  2. google_meridian-1.3.0.dist-info/RECORD +62 -0
  3. meridian/analysis/__init__.py +2 -0
  4. meridian/analysis/analyzer.py +280 -142
  5. meridian/analysis/formatter.py +2 -2
  6. meridian/analysis/optimizer.py +353 -169
  7. meridian/analysis/review/__init__.py +20 -0
  8. meridian/analysis/review/checks.py +721 -0
  9. meridian/analysis/review/configs.py +110 -0
  10. meridian/analysis/review/constants.py +40 -0
  11. meridian/analysis/review/results.py +544 -0
  12. meridian/analysis/review/reviewer.py +186 -0
  13. meridian/analysis/summarizer.py +14 -12
  14. meridian/analysis/templates/chips.html.jinja +12 -0
  15. meridian/analysis/test_utils.py +27 -5
  16. meridian/analysis/visualizer.py +45 -50
  17. meridian/backend/__init__.py +698 -55
  18. meridian/backend/config.py +75 -16
  19. meridian/backend/test_utils.py +127 -1
  20. meridian/constants.py +52 -11
  21. meridian/data/input_data.py +7 -2
  22. meridian/data/test_utils.py +5 -3
  23. meridian/mlflow/autolog.py +2 -2
  24. meridian/model/__init__.py +1 -0
  25. meridian/model/adstock_hill.py +10 -9
  26. meridian/model/eda/__init__.py +3 -0
  27. meridian/model/eda/constants.py +21 -0
  28. meridian/model/eda/eda_engine.py +1580 -84
  29. meridian/model/eda/eda_outcome.py +200 -0
  30. meridian/model/eda/eda_spec.py +84 -0
  31. meridian/model/eda/meridian_eda.py +220 -0
  32. meridian/model/knots.py +56 -50
  33. meridian/model/media.py +10 -8
  34. meridian/model/model.py +79 -16
  35. meridian/model/model_test_data.py +53 -9
  36. meridian/model/posterior_sampler.py +398 -391
  37. meridian/model/prior_distribution.py +114 -39
  38. meridian/model/prior_sampler.py +146 -90
  39. meridian/model/spec.py +7 -8
  40. meridian/model/transformers.py +16 -8
  41. meridian/version.py +1 -1
  42. google_meridian-1.2.0.dist-info/RECORD +0 -52
  43. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
  44. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
  45. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
meridian/model/prior_distribution.py +114 -39
@@ -19,6 +19,7 @@ used by the Meridian model object.
19
19
  """
20
20
 
21
21
  from __future__ import annotations
22
+
22
23
  from collections.abc import MutableMapping, Sequence
23
24
  import dataclasses
24
25
  from typing import Any
@@ -26,7 +27,6 @@ import warnings
26
27
 
27
28
  from meridian import backend
28
29
  from meridian import constants
29
-
30
30
  import numpy as np
31
31
 
32
32
 
@@ -34,6 +34,8 @@ __all__ = [
34
34
  'IndependentMultivariateDistribution',
35
35
  'PriorDistribution',
36
36
  'distributions_are_equal',
37
+ 'lognormal_dist_from_mean_std',
38
+ 'lognormal_dist_from_range',
37
39
  ]
38
40
 
39
41
 
@@ -175,14 +177,14 @@ class PriorDistribution:
175
177
  xi_n: Prior distribution on the hierarchical standard deviation of
176
178
  `gamma_gn` which is the coefficient on non-media channel `n` for geo `g`.
177
179
  Hierarchy is defined over geos. Default distribution is `HalfNormal(5.0)`.
178
- alpha_m: Prior distribution on the `geometric decay` Adstock parameter for
180
+ alpha_m: Prior distribution on the Adstock decay parameter for media input.
181
+ Default distribution is `Uniform(0.0, 1.0)`.
182
+ alpha_rf: Prior distribution on the Adstock decay parameter for RF input.
183
+ Default distribution is `Uniform(0.0, 1.0)`.
184
+ alpha_om: Prior distribution on the Adstock decay parameter for organic
179
185
  media input. Default distribution is `Uniform(0.0, 1.0)`.
180
- alpha_rf: Prior distribution on the `geometric decay` Adstock parameter for
181
- RF input. Default distribution is `Uniform(0.0, 1.0)`.
182
- alpha_om: Prior distribution on the `geometric decay` Adstock parameter for
183
- organic media input. Default distribution is `Uniform(0.0, 1.0)`.
184
- alpha_orf: Prior distribution on the `geometric decay` Adstock parameter for
185
- organic RF input. Default distribution is `Uniform(0.0, 1.0)`.
186
+ alpha_orf: Prior distribution on the Adstock decay parameter for organic RF
187
+ input. Default distribution is `Uniform(0.0, 1.0)`.
186
188
  ec_m: Prior distribution on the `half-saturation` Hill parameter for media
187
189
  input. Default distribution is `TruncatedNormal(0.8, 0.8, 0.1, 10)`.
188
190
  ec_rf: Prior distribution on the `half-saturation` Hill parameter for RF
@@ -772,7 +774,7 @@ class PriorDistribution:
772
774
  )
773
775
  if (
774
776
  not isinstance(self.slope_m, backend.tfd.Deterministic)
775
- or (np.isscalar(self.slope_m.loc.numpy()) and self.slope_m.loc != 1.0)
777
+ or (backend.rank(self.slope_m.loc) == 0 and self.slope_m.loc != 1.0)
776
778
  or (
777
779
  self.slope_m.batch_shape.as_list()
778
780
  and any(x != 1.0 for x in self.slope_m.loc)
@@ -791,7 +793,7 @@ class PriorDistribution:
791
793
  )
792
794
  if (
793
795
  not isinstance(self.slope_om, backend.tfd.Deterministic)
794
- or (np.isscalar(self.slope_om.loc.numpy()) and self.slope_om.loc != 1.0)
796
+ or (backend.rank(self.slope_om.loc) == 0 and self.slope_om.loc != 1.0)
795
797
  or (
796
798
  self.slope_om.batch_shape.as_list()
797
799
  and any(x != 1.0 for x in self.slope_om.loc)
@@ -1000,8 +1002,7 @@ class IndependentMultivariateDistribution(backend.tfd.Distribution):
1000
1002
  """Check for deterministic distributions and raise an error if found."""
1001
1003
 
1002
1004
  if any(
1003
- isinstance(dist, backend.tfd.Deterministic)
1004
- for dist in distributions
1005
+ isinstance(dist, backend.tfd.Deterministic) for dist in distributions
1005
1006
  ):
1006
1007
  raise ValueError(
1007
1008
  f'{self.__class__.__name__} cannot contain `Deterministic` '
@@ -1029,9 +1030,7 @@ class IndependentMultivariateDistribution(backend.tfd.Distribution):
1029
1030
  [dist.batch_shape_tensor() for dist in self._distributions],
1030
1031
  axis=0,
1031
1032
  )
1032
- return backend.reduce_sum(
1033
- distribution_batch_shape_tensors, keepdims=True
1034
- )
1033
+ return backend.reduce_sum(distribution_batch_shape_tensors, keepdims=True)
1035
1034
 
1036
1035
  def _batch_shape(self):
1037
1036
  return backend.TensorShape(sum(self._distribution_batch_shapes))
@@ -1043,10 +1042,7 @@ class IndependentMultivariateDistribution(backend.tfd.Distribution):
1043
1042
 
1044
1043
  def _quantile(self, value):
1045
1044
  value = self._broadcast_value(value)
1046
- split_value = backend.split(
1047
- value,
1048
- self._distribution_batch_shapes, axis=-1
1049
- )
1045
+ split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
1050
1046
  quantiles = [
1051
1047
  dist.quantile(sv) for dist, sv in zip(self._distributions, split_value)
1052
1048
  ]
@@ -1055,11 +1051,7 @@ class IndependentMultivariateDistribution(backend.tfd.Distribution):
1055
1051
 
1056
1052
  def _log_prob(self, value):
1057
1053
  value = self._broadcast_value(value)
1058
- split_value = backend.split(
1059
- value,
1060
- self._distribution_batch_shapes,
1061
- axis=-1
1062
- )
1054
+ split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
1063
1055
  log_probs = [
1064
1056
  dist.log_prob(sv) for dist, sv in zip(self._distributions, split_value)
1065
1057
  ]
@@ -1068,11 +1060,7 @@ class IndependentMultivariateDistribution(backend.tfd.Distribution):
1068
1060
 
1069
1061
  def _log_cdf(self, value):
1070
1062
  value = self._broadcast_value(value)
1071
- split_value = backend.split(
1072
- value,
1073
- self._distribution_batch_shapes,
1074
- axis=-1
1075
- )
1063
+ split_value = backend.split(value, self._distribution_batch_shapes, axis=-1)
1076
1064
 
1077
1065
  log_cdfs = [
1078
1066
  dist.log_cdf(sv) for dist, sv in zip(self._distributions, split_value)
@@ -1173,6 +1161,87 @@ def distributions_are_equal(
1173
1161
  return True
1174
1162
 
1175
1163
 
1164
+ def lognormal_dist_from_mean_std(
1165
+ mean: float | Sequence[float], std: float | Sequence[float]
1166
+ ) -> backend.tfd.LogNormal:
1167
+ """Define a lognormal distribution from its mean and standard deviation.
1168
+
1169
+ This function parameterizes lognormal distributions by their mean and
1170
+ standard deviation.
1171
+
1172
+ Args:
1173
+ mean: A float or array-like object defining the distribution mean. Must be
1174
+ positive.
1175
+ std: A float or array-like object defining the distribution standard
1176
+ deviation. Must be non-negative.
1177
+
1178
+ Returns:
1179
+ A `backend.tfd.LogNormal` object with the input mean and standard deviation.
1180
+ """
1181
+
1182
+ mean = np.asarray(mean)
1183
+ std = np.asarray(std)
1184
+
1185
+ mu = np.log(mean) - 0.5 * np.log((std / mean) ** 2 + 1)
1186
+ sigma = np.sqrt(np.log((std / mean) ** 2 + 1))
1187
+
1188
+ return backend.tfd.LogNormal(mu, sigma)
1189
+
1190
+
1191
+ def lognormal_dist_from_range(
1192
+ low: float | Sequence[float],
1193
+ high: float | Sequence[float],
1194
+ mass_percent: float | Sequence[float] = 0.95,
1195
+ ) -> backend.tfd.LogNormal:
1196
+ """Define a LogNormal distribution from a specified range.
1197
+
1198
+ This function parameterizes lognormal distributions by the bounds of a range,
1199
+ so that the specified probability mass falls within the bounds defined by
1200
+ `low` and `high`. The probability mass is symmetric about the median. For
1201
+ example, to define a lognormal distribution with a 95% probability mass of
1202
+ (1, 10), use:
1203
+
1204
+ ```python
1205
+ lognormal = lognormal_dist_from_range(1.0, 10.0, mass_percent=0.95)
1206
+ ```
1207
+
1208
+ Args:
1209
+ low: Float or array-like denoting the lower bound of the range. Values must
1210
+ be non-negative.
1211
+ high: Float or array-like denoting the upper bound of range. Values must be
1212
+ non-negative.
1213
+ mass_percent: Float or array-like denoting the probability mass. Values must
1214
+ be between 0 and 1 (exclusive). Default: 0.95.
1215
+
1216
+ Returns:
1217
+ A `backend.tfd.LogNormal` object with the input percentage mass falling
1218
+ within the given range.
1219
+ """
1220
+ low = np.asarray(low)
1221
+ high = np.asarray(high)
1222
+ mass_percent = np.asarray(mass_percent)
1223
+
1224
+ if not ((0.0 < low).all() and (low < high).all()): # pytype: disable=attribute-error
1225
+ raise ValueError("'low' and 'high' values must be non-negative and satisfy "
1226
+ "high > low.")
1227
+
1228
+ if not ((0.0 < mass_percent).all() and (mass_percent < 1.0).all()): # pytype: disable=attribute-error
1229
+ raise ValueError(
1230
+ "'mass_percent' values must be between 0 and 1, exclusive."
1231
+ )
1232
+
1233
+ normal = backend.tfd.Normal(0, 1)
1234
+ mass_lower = 0.5 - (mass_percent / 2)
1235
+ mass_upper = 0.5 + (mass_percent / 2)
1236
+
1237
+ sigma = np.log(high / low) / (
1238
+ normal.quantile(mass_upper) - normal.quantile(mass_lower)
1239
+ )
1240
+ mu = np.log(high) - normal.quantile(mass_upper) * sigma
1241
+
1242
+ return backend.tfd.LogNormal(mu, sigma)
1243
+
1244
+
1176
1245
  def _convert_to_deterministic_0_distribution(
1177
1246
  distribution: backend.tfd.Distribution,
1178
1247
  ) -> backend.tfd.Distribution:
@@ -1257,26 +1326,31 @@ def _validate_support(
1257
1326
  """
1258
1327
  # Note that `tfp.distributions.BatchBroadcast` objects have a `distribution`
1259
1328
  # attribute that points to a `tfp.distributions.Distribution` object.
1260
- if isinstance(tfp_dist, backend.tfp.distributions.BatchBroadcast):
1329
+ if isinstance(tfp_dist, backend.tfd.BatchBroadcast):
1261
1330
  tfp_dist = tfp_dist.distribution
1262
1331
  # Note that `tfp.distributions.Deterministic` does not have a `quantile`
1263
1332
  # method implemented, so the min and max values must be extracted from the
1264
1333
  # `loc` attribute instead.
1265
- if isinstance(
1266
- tfp_dist,
1267
- backend.tfp.python.distributions.deterministic.Deterministic
1268
- ):
1334
+ if isinstance(tfp_dist, backend.tfd.Deterministic):
1269
1335
  support_min_vals = tfp_dist.loc
1270
1336
  support_max_vals = tfp_dist.loc
1271
1337
  for i in (0, 1):
1272
- if (
1273
- prevent_deterministic_prior_at_bounds[i]
1274
- and np.any(tfp_dist.loc == bounds[i])
1338
+ if prevent_deterministic_prior_at_bounds[i] and np.any(
1339
+ tfp_dist.loc == bounds[i]
1275
1340
  ):
1276
1341
  raise ValueError(
1277
1342
  f'{parameter_name} was assigned a point mass (deterministic) prior'
1278
1343
  f' at {bounds[i]}, which is not allowed.'
1279
1344
  )
1345
+ elif isinstance(tfp_dist, backend.tfd.TruncatedNormal):
1346
+ # TruncatedNormal quantile method is not reliable, particularly when the
1347
+ # `low` or `high` value falls into extreme percentile of the untruncated
1348
+ # distribution. Note that
1349
+ # `TruncatedNormal.experimental_default_event_space_bijector()([-inf, inf])`
1350
+ # returns the correct support range, so this method could be used if the
1351
+ # `quantile` method is found to be unreliable for other distributions.
1352
+ support_min_vals = tfp_dist.low
1353
+ support_max_vals = tfp_dist.high
1280
1354
  else:
1281
1355
  try:
1282
1356
  support_min_vals = tfp_dist.quantile(0)
@@ -1284,9 +1358,9 @@ def _validate_support(
1284
1358
  except (AttributeError, NotImplementedError):
1285
1359
  warnings.warn(
1286
1360
  f'The prior distribution for {parameter_name} does not have a'
1287
- f' `quantile` method implemented, so the support range validation'
1361
+ ' `quantile` method implemented, so the support range validation'
1288
1362
  f' was skipped. Confirm that your prior for {parameter_name} is'
1289
- f' appropriate.'
1363
+ ' appropriate.'
1290
1364
  )
1291
1365
  return
1292
1366
  if np.any(support_min_vals < bounds[0]):
@@ -1300,6 +1374,7 @@ def _validate_support(
1300
1374
  f' greater than the parameter maximum {bounds[1]}.'
1301
1375
  )
1302
1376
 
1377
+
1303
1378
  # Dictionary of parameters that have a limited parameters space. The tuple
1304
1379
  # contains the lower and upper bounds, respectively.
1305
1380
  _parameter_space_bounds = {