google-meridian 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/model/model.py CHANGED
@@ -27,12 +27,13 @@ from meridian.data import time_coordinates as tc
27
27
  from meridian.model import adstock_hill
28
28
  from meridian.model import knots
29
29
  from meridian.model import media
30
+ from meridian.model import posterior_sampler
30
31
  from meridian.model import prior_distribution
32
+ from meridian.model import prior_sampler
31
33
  from meridian.model import spec
32
34
  from meridian.model import transformers
33
35
  import numpy as np
34
36
  import tensorflow as tf
35
- import tensorflow_probability as tfp
36
37
 
37
38
 
38
39
  __all__ = [
@@ -49,12 +50,8 @@ class NotFittedModelError(Exception):
49
50
  """Model has not been fitted."""
50
51
 
51
52
 
52
- class MCMCSamplingError(Exception):
53
- """The Markov Chain Monte Carlo (MCMC) sampling failed."""
54
-
55
-
56
- class MCMCOOMError(Exception):
57
- """The Markov Chain Monte Carlo (MCMC) exceeds memory limits."""
53
+ MCMCSamplingError = posterior_sampler.MCMCSamplingError
54
+ MCMCOOMError = posterior_sampler.MCMCOOMError
58
55
 
59
56
 
60
57
  def _warn_setting_national_args(**kwargs):
@@ -70,43 +67,6 @@ def _warn_setting_national_args(**kwargs):
70
67
  )
71
68
 
72
69
 
73
- def _get_tau_g(
74
- tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int
75
- ) -> tfp.distributions.Distribution:
76
- """Computes `tau_g` from `tau_g_excl_baseline`.
77
-
78
- This function computes `tau_g` by inserting a column of zeros at the
79
- `baseline_geo` position in `tau_g_excl_baseline`.
80
-
81
- Args:
82
- tau_g_excl_baseline: A tensor of shape `[..., n_geos - 1]` for the
83
- user-defined dimensions of the `tau_g` parameter distribution.
84
- baseline_geo_idx: The index of the baseline geo to be set to zero.
85
-
86
- Returns:
87
- A tensor of shape `[..., n_geos]` with the final distribution of the `tau_g`
88
- parameter with zero at position `baseline_geo_idx` and matching
89
- `tau_g_excl_baseline` elsewhere.
90
- """
91
- rank = len(tau_g_excl_baseline.shape)
92
- shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1
93
- tau_g = tf.concat(
94
- [
95
- tau_g_excl_baseline[..., :baseline_geo_idx],
96
- tf.zeros(shape, dtype=tau_g_excl_baseline.dtype),
97
- tau_g_excl_baseline[..., baseline_geo_idx:],
98
- ],
99
- axis=rank - 1,
100
- )
101
- return tfp.distributions.Deterministic(tau_g, name="tau_g")
102
-
103
-
104
- @tf.function(autograph=False, jit_compile=True)
105
- def _xla_windowed_adaptive_nuts(**kwargs):
106
- """XLA wrapper for windowed_adaptive_nuts."""
107
- return tfp.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
108
-
109
-
110
70
  class Meridian:
111
71
  """Contains the main functionality for fitting the Meridian MMM model.
112
72
 
@@ -452,6 +412,18 @@ class Meridian:
452
412
  total_spend=agg_total_spend,
453
413
  )
454
414
 
415
+ @functools.cached_property
416
+ def prior_sampler_callable(self) -> prior_sampler.PriorDistributionSampler:
417
+ """A `PriorDistributionSampler` callable bound to this model."""
418
+ return prior_sampler.PriorDistributionSampler(self)
419
+
420
+ @functools.cached_property
421
+ def posterior_sampler_callable(
422
+ self,
423
+ ) -> posterior_sampler.PosteriorMCMCSampler:
424
+ """A `PosteriorMCMCSampler` callable bound to this model."""
425
+ return posterior_sampler.PosteriorMCMCSampler(self)
426
+
455
427
  def expand_selected_time_dims(
456
428
  self,
457
429
  start_date: tc.Date | None = None,
@@ -565,9 +537,10 @@ class Meridian:
565
537
  self._validate_injected_inference_data_group_coord(
566
538
  inference_data, group, constants.TIME, self.n_times
567
539
  )
568
- self._validate_injected_inference_data_group_coord(
569
- inference_data, group, constants.SIGMA_DIM, self._sigma_shape
570
- )
540
+ if not self.model_spec.unique_sigma_for_each_geo:
541
+ self._validate_injected_inference_data_group_coord(
542
+ inference_data, group, constants.SIGMA_DIM, self._sigma_shape
543
+ )
571
544
  self._validate_injected_inference_data_group_coord(
572
545
  inference_data,
573
546
  group,
@@ -720,7 +693,7 @@ class Meridian:
720
693
  raise ValueError(
721
694
  f"Custom priors should be set on `{constants.MROI_M}` and"
722
695
  f" `{constants.MROI_RF}` when KPI is non-revenue and revenue per kpi"
723
- f" data is missing."
696
+ " data is missing."
724
697
  )
725
698
 
726
699
  def _validate_geo_invariants(self):
@@ -955,143 +928,6 @@ class Meridian:
955
928
 
956
929
  return rf_out
957
930
 
958
- def _get_roi_prior_beta_m_value(
959
- self,
960
- alpha_m: tf.Tensor,
961
- beta_gm_dev: tf.Tensor,
962
- ec_m: tf.Tensor,
963
- eta_m: tf.Tensor,
964
- roi_or_mroi_m: tf.Tensor,
965
- slope_m: tf.Tensor,
966
- media_transformed: tf.Tensor,
967
- ) -> tf.Tensor:
968
- """Returns a tensor to be used in `beta_m`."""
969
- # The `roi_or_mroi_m` parameter represents either ROI or mROI. For reach &
970
- # frequency channels, marginal ROI priors are defined as "mROI by reach",
971
- # which is equivalent to ROI.
972
- media_spend = self.media_tensors.media_spend
973
- media_spend_counterfactual = self.media_tensors.media_spend_counterfactual
974
- media_counterfactual_scaled = self.media_tensors.media_counterfactual_scaled
975
- # If we got here, then we should already have media tensors derived from
976
- # non-None InputData.media data.
977
- assert media_spend is not None
978
- assert media_spend_counterfactual is not None
979
- assert media_counterfactual_scaled is not None
980
-
981
- # Use absolute value here because this difference will be negative for
982
- # marginal ROI priors.
983
- inc_revenue_m = roi_or_mroi_m * tf.reduce_sum(
984
- tf.abs(media_spend - media_spend_counterfactual),
985
- range(media_spend.ndim - 1),
986
- )
987
-
988
- if (
989
- self.model_spec.roi_calibration_period is None
990
- and self.model_spec.paid_media_prior_type
991
- == constants.PAID_MEDIA_PRIOR_TYPE_ROI
992
- ):
993
- # We can skip the adstock/hill computation step in this case.
994
- media_counterfactual_transformed = tf.zeros_like(media_transformed)
995
- else:
996
- media_counterfactual_transformed = self.adstock_hill_media(
997
- media=media_counterfactual_scaled,
998
- alpha=alpha_m,
999
- ec=ec_m,
1000
- slope=slope_m,
1001
- )
1002
-
1003
- revenue_per_kpi = self.revenue_per_kpi
1004
- if self.input_data.revenue_per_kpi is None:
1005
- revenue_per_kpi = tf.ones([self.n_geos, self.n_times], dtype=tf.float32)
1006
- # Note: use absolute value here because this difference will be negative for
1007
- # marginal ROI priors.
1008
- media_contrib_gm = tf.einsum(
1009
- "...gtm,g,,gt->...gm",
1010
- tf.abs(media_transformed - media_counterfactual_transformed),
1011
- self.population,
1012
- self.kpi_transformer.population_scaled_stdev,
1013
- revenue_per_kpi,
1014
- )
1015
-
1016
- if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
1017
- media_contrib_m = tf.einsum("...gm->...m", media_contrib_gm)
1018
- random_effect_m = tf.einsum(
1019
- "...m,...gm,...gm->...m", eta_m, beta_gm_dev, media_contrib_gm
1020
- )
1021
- return (inc_revenue_m - random_effect_m) / media_contrib_m
1022
- else:
1023
- # For log_normal, beta_m and eta_m are not mean & std.
1024
- # The parameterization is beta_gm ~ exp(beta_m + eta_m * N(0, 1)).
1025
- random_effect_m = tf.einsum(
1026
- "...gm,...gm->...m",
1027
- tf.math.exp(beta_gm_dev * eta_m[..., tf.newaxis, :]),
1028
- media_contrib_gm,
1029
- )
1030
- return tf.math.log(inc_revenue_m) - tf.math.log(random_effect_m)
1031
-
1032
- def _get_roi_prior_beta_rf_value(
1033
- self,
1034
- alpha_rf: tf.Tensor,
1035
- beta_grf_dev: tf.Tensor,
1036
- ec_rf: tf.Tensor,
1037
- eta_rf: tf.Tensor,
1038
- roi_or_mroi_rf: tf.Tensor,
1039
- slope_rf: tf.Tensor,
1040
- rf_transformed: tf.Tensor,
1041
- ) -> tf.Tensor:
1042
- """Returns a tensor to be used in `beta_rf`."""
1043
- rf_spend = self.rf_tensors.rf_spend
1044
- rf_spend_counterfactual = self.rf_tensors.rf_spend_counterfactual
1045
- reach_counterfactual_scaled = self.rf_tensors.reach_counterfactual_scaled
1046
- frequency = self.rf_tensors.frequency
1047
- # If we got here, then we should already have RF media tensors derived from
1048
- # non-None InputData.reach data.
1049
- assert rf_spend is not None
1050
- assert rf_spend_counterfactual is not None
1051
- assert reach_counterfactual_scaled is not None
1052
- assert frequency is not None
1053
-
1054
- inc_revenue_rf = roi_or_mroi_rf * tf.reduce_sum(
1055
- rf_spend - rf_spend_counterfactual,
1056
- range(rf_spend.ndim - 1),
1057
- )
1058
- if self.model_spec.rf_roi_calibration_period is not None:
1059
- rf_counterfactual_transformed = self.adstock_hill_rf(
1060
- reach=reach_counterfactual_scaled,
1061
- frequency=frequency,
1062
- alpha=alpha_rf,
1063
- ec=ec_rf,
1064
- slope=slope_rf,
1065
- )
1066
- else:
1067
- rf_counterfactual_transformed = tf.zeros_like(rf_transformed)
1068
- revenue_per_kpi = self.revenue_per_kpi
1069
- if self.input_data.revenue_per_kpi is None:
1070
- revenue_per_kpi = tf.ones([self.n_geos, self.n_times], dtype=tf.float32)
1071
-
1072
- media_contrib_grf = tf.einsum(
1073
- "...gtm,g,,gt->...gm",
1074
- rf_transformed - rf_counterfactual_transformed,
1075
- self.population,
1076
- self.kpi_transformer.population_scaled_stdev,
1077
- revenue_per_kpi,
1078
- )
1079
- if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
1080
- media_contrib_rf = tf.einsum("...gm->...m", media_contrib_grf)
1081
- random_effect_rf = tf.einsum(
1082
- "...m,...gm,...gm->...m", eta_rf, beta_grf_dev, media_contrib_grf
1083
- )
1084
- return (inc_revenue_rf - random_effect_rf) / media_contrib_rf
1085
- else:
1086
- # For log_normal, beta_rf and eta_rf are not mean & std.
1087
- # The parameterization is beta_grf ~ exp(beta_rf + eta_rf * N(0, 1)).
1088
- random_effect_rf = tf.einsum(
1089
- "...gm,...gm->...m",
1090
- tf.math.exp(beta_grf_dev * eta_rf[..., tf.newaxis, :]),
1091
- media_contrib_grf,
1092
- )
1093
- return tf.math.log(inc_revenue_rf) - tf.math.log(random_effect_rf)
1094
-
1095
931
  def populate_cached_properties(self):
1096
932
  """Eagerly activates all cached properties.
1097
933
 
@@ -1111,301 +947,7 @@ class Meridian:
1111
947
  for attr in cached_properties:
1112
948
  _ = getattr(self, attr)
1113
949
 
1114
- def _get_joint_dist_unpinned(self) -> tfp.distributions.Distribution:
1115
- """Returns JointDistributionCoroutineAutoBatched function for MCMC."""
1116
-
1117
- self.populate_cached_properties()
1118
-
1119
- # This lists all the derived properties and states of this Meridian object
1120
- # that are referenced by the joint distribution coroutine.
1121
- # That is, these are the list of captured parameters.
1122
- prior_broadcast = self.prior_broadcast
1123
- baseline_geo_idx = self.baseline_geo_idx
1124
- knot_info = self.knot_info
1125
- n_geos = self.n_geos
1126
- n_times = self.n_times
1127
- n_media_channels = self.n_media_channels
1128
- n_rf_channels = self.n_rf_channels
1129
- n_organic_media_channels = self.n_organic_media_channels
1130
- n_organic_rf_channels = self.n_organic_rf_channels
1131
- n_controls = self.n_controls
1132
- n_non_media_channels = self.n_non_media_channels
1133
- holdout_id = self.holdout_id
1134
- media_tensors = self.media_tensors
1135
- rf_tensors = self.rf_tensors
1136
- organic_media_tensors = self.organic_media_tensors
1137
- organic_rf_tensors = self.organic_rf_tensors
1138
- controls_scaled = self.controls_scaled
1139
- non_media_treatments_scaled = self.non_media_treatments_scaled
1140
- media_effects_dist = self.media_effects_dist
1141
- adstock_hill_media_fn = self.adstock_hill_media
1142
- adstock_hill_rf_fn = self.adstock_hill_rf
1143
- get_roi_prior_beta_m_value_fn = self._get_roi_prior_beta_m_value
1144
- get_roi_prior_beta_rf_value_fn = self._get_roi_prior_beta_rf_value
1145
-
1146
- # TODO: Extract this coroutine to be unittestable on its own.
1147
- # This MCMC sampling technique is complex enough to have its own abstraction
1148
- # and testable API, rather than being embedded as a private method in the
1149
- # Meridian class.
1150
- @tfp.distributions.JointDistributionCoroutineAutoBatched
1151
- def joint_dist_unpinned():
1152
- # Sample directly from prior.
1153
- knot_values = yield prior_broadcast.knot_values
1154
- gamma_c = yield prior_broadcast.gamma_c
1155
- xi_c = yield prior_broadcast.xi_c
1156
- sigma = yield prior_broadcast.sigma
1157
-
1158
- tau_g_excl_baseline = yield tfp.distributions.Sample(
1159
- prior_broadcast.tau_g_excl_baseline,
1160
- name=constants.TAU_G_EXCL_BASELINE,
1161
- )
1162
- tau_g = yield _get_tau_g(
1163
- tau_g_excl_baseline=tau_g_excl_baseline,
1164
- baseline_geo_idx=baseline_geo_idx,
1165
- )
1166
- mu_t = yield tfp.distributions.Deterministic(
1167
- tf.einsum(
1168
- "k,kt->t",
1169
- knot_values,
1170
- tf.convert_to_tensor(knot_info.weights),
1171
- ),
1172
- name=constants.MU_T,
1173
- )
1174
-
1175
- tau_gt = tau_g[:, tf.newaxis] + mu_t
1176
- combined_media_transformed = tf.zeros(
1177
- shape=(n_geos, n_times, 0), dtype=tf.float32
1178
- )
1179
- combined_beta = tf.zeros(shape=(n_geos, 0), dtype=tf.float32)
1180
- if media_tensors.media is not None:
1181
- alpha_m = yield prior_broadcast.alpha_m
1182
- ec_m = yield prior_broadcast.ec_m
1183
- eta_m = yield prior_broadcast.eta_m
1184
- slope_m = yield prior_broadcast.slope_m
1185
- beta_gm_dev = yield tfp.distributions.Sample(
1186
- tfp.distributions.Normal(0, 1),
1187
- [n_geos, n_media_channels],
1188
- name=constants.BETA_GM_DEV,
1189
- )
1190
- media_transformed = adstock_hill_media_fn(
1191
- media=media_tensors.media_scaled,
1192
- alpha=alpha_m,
1193
- ec=ec_m,
1194
- slope=slope_m,
1195
- )
1196
- prior_type = self.model_spec.paid_media_prior_type
1197
- if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
1198
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
1199
- roi_or_mroi_m = yield prior_broadcast.roi_m
1200
- else:
1201
- roi_or_mroi_m = yield prior_broadcast.mroi_m
1202
- beta_m_value = get_roi_prior_beta_m_value_fn(
1203
- alpha_m,
1204
- beta_gm_dev,
1205
- ec_m,
1206
- eta_m,
1207
- roi_or_mroi_m,
1208
- slope_m,
1209
- media_transformed,
1210
- )
1211
- beta_m = yield tfp.distributions.Deterministic(
1212
- beta_m_value, name=constants.BETA_M
1213
- )
1214
- else:
1215
- beta_m = yield prior_broadcast.beta_m
1216
-
1217
- beta_eta_combined = beta_m + eta_m * beta_gm_dev
1218
- beta_gm_value = (
1219
- beta_eta_combined
1220
- if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1221
- else tf.math.exp(beta_eta_combined)
1222
- )
1223
- beta_gm = yield tfp.distributions.Deterministic(
1224
- beta_gm_value, name=constants.BETA_GM
1225
- )
1226
- combined_media_transformed = tf.concat(
1227
- [combined_media_transformed, media_transformed], axis=-1
1228
- )
1229
- combined_beta = tf.concat([combined_beta, beta_gm], axis=-1)
1230
-
1231
- if rf_tensors.reach is not None:
1232
- alpha_rf = yield prior_broadcast.alpha_rf
1233
- ec_rf = yield prior_broadcast.ec_rf
1234
- eta_rf = yield prior_broadcast.eta_rf
1235
- slope_rf = yield prior_broadcast.slope_rf
1236
- beta_grf_dev = yield tfp.distributions.Sample(
1237
- tfp.distributions.Normal(0, 1),
1238
- [n_geos, n_rf_channels],
1239
- name=constants.BETA_GRF_DEV,
1240
- )
1241
- rf_transformed = adstock_hill_rf_fn(
1242
- reach=rf_tensors.reach_scaled,
1243
- frequency=rf_tensors.frequency,
1244
- alpha=alpha_rf,
1245
- ec=ec_rf,
1246
- slope=slope_rf,
1247
- )
1248
-
1249
- prior_type = self.model_spec.paid_media_prior_type
1250
- if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
1251
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
1252
- roi_or_mroi_rf = yield prior_broadcast.roi_rf
1253
- else:
1254
- roi_or_mroi_rf = yield prior_broadcast.mroi_rf
1255
- beta_rf_value = get_roi_prior_beta_rf_value_fn(
1256
- alpha_rf,
1257
- beta_grf_dev,
1258
- ec_rf,
1259
- eta_rf,
1260
- roi_or_mroi_rf,
1261
- slope_rf,
1262
- rf_transformed,
1263
- )
1264
- beta_rf = yield tfp.distributions.Deterministic(
1265
- beta_rf_value,
1266
- name=constants.BETA_RF,
1267
- )
1268
- else:
1269
- beta_rf = yield prior_broadcast.beta_rf
1270
-
1271
- beta_eta_combined = beta_rf + eta_rf * beta_grf_dev
1272
- beta_grf_value = (
1273
- beta_eta_combined
1274
- if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1275
- else tf.math.exp(beta_eta_combined)
1276
- )
1277
- beta_grf = yield tfp.distributions.Deterministic(
1278
- beta_grf_value, name=constants.BETA_GRF
1279
- )
1280
- combined_media_transformed = tf.concat(
1281
- [combined_media_transformed, rf_transformed], axis=-1
1282
- )
1283
- combined_beta = tf.concat([combined_beta, beta_grf], axis=-1)
1284
-
1285
- if organic_media_tensors.organic_media is not None:
1286
- alpha_om = yield prior_broadcast.alpha_om
1287
- ec_om = yield prior_broadcast.ec_om
1288
- eta_om = yield prior_broadcast.eta_om
1289
- slope_om = yield prior_broadcast.slope_om
1290
- beta_gom_dev = yield tfp.distributions.Sample(
1291
- tfp.distributions.Normal(0, 1),
1292
- [n_geos, n_organic_media_channels],
1293
- name=constants.BETA_GOM_DEV,
1294
- )
1295
- organic_media_transformed = adstock_hill_media_fn(
1296
- media=organic_media_tensors.organic_media_scaled,
1297
- alpha=alpha_om,
1298
- ec=ec_om,
1299
- slope=slope_om,
1300
- )
1301
- beta_om = yield prior_broadcast.beta_om
1302
-
1303
- beta_eta_combined = beta_om + eta_om * beta_gom_dev
1304
- beta_gom_value = (
1305
- beta_eta_combined
1306
- if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1307
- else tf.math.exp(beta_eta_combined)
1308
- )
1309
- beta_gom = yield tfp.distributions.Deterministic(
1310
- beta_gom_value, name=constants.BETA_GOM
1311
- )
1312
- combined_media_transformed = tf.concat(
1313
- [combined_media_transformed, organic_media_transformed], axis=-1
1314
- )
1315
- combined_beta = tf.concat([combined_beta, beta_gom], axis=-1)
1316
-
1317
- if organic_rf_tensors.organic_reach is not None:
1318
- alpha_orf = yield prior_broadcast.alpha_orf
1319
- ec_orf = yield prior_broadcast.ec_orf
1320
- eta_orf = yield prior_broadcast.eta_orf
1321
- slope_orf = yield prior_broadcast.slope_orf
1322
- beta_gorf_dev = yield tfp.distributions.Sample(
1323
- tfp.distributions.Normal(0, 1),
1324
- [n_geos, n_organic_rf_channels],
1325
- name=constants.BETA_GORF_DEV,
1326
- )
1327
- organic_rf_transformed = adstock_hill_rf_fn(
1328
- reach=organic_rf_tensors.organic_reach_scaled,
1329
- frequency=organic_rf_tensors.organic_frequency,
1330
- alpha=alpha_orf,
1331
- ec=ec_orf,
1332
- slope=slope_orf,
1333
- )
1334
-
1335
- beta_orf = yield prior_broadcast.beta_orf
1336
-
1337
- beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev
1338
- beta_gorf_value = (
1339
- beta_eta_combined
1340
- if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1341
- else tf.math.exp(beta_eta_combined)
1342
- )
1343
- beta_gorf = yield tfp.distributions.Deterministic(
1344
- beta_gorf_value, name=constants.BETA_GORF
1345
- )
1346
- combined_media_transformed = tf.concat(
1347
- [combined_media_transformed, organic_rf_transformed], axis=-1
1348
- )
1349
- combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1)
1350
-
1351
- sigma_gt = tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos]))
1352
- gamma_gc_dev = yield tfp.distributions.Sample(
1353
- tfp.distributions.Normal(0, 1),
1354
- [n_geos, n_controls],
1355
- name=constants.GAMMA_GC_DEV,
1356
- )
1357
- gamma_gc = yield tfp.distributions.Deterministic(
1358
- gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
1359
- )
1360
- y_pred_combined_media = (
1361
- tau_gt
1362
- + tf.einsum("gtm,gm->gt", combined_media_transformed, combined_beta)
1363
- + tf.einsum("gtc,gc->gt", controls_scaled, gamma_gc)
1364
- )
1365
-
1366
- if self.non_media_treatments is not None:
1367
- gamma_n = yield prior_broadcast.gamma_n
1368
- xi_n = yield prior_broadcast.xi_n
1369
- gamma_gn_dev = yield tfp.distributions.Sample(
1370
- tfp.distributions.Normal(0, 1),
1371
- [n_geos, n_non_media_channels],
1372
- name=constants.GAMMA_GN_DEV,
1373
- )
1374
- gamma_gn = yield tfp.distributions.Deterministic(
1375
- gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
1376
- )
1377
- y_pred = y_pred_combined_media + tf.einsum(
1378
- "gtn,gn->gt", non_media_treatments_scaled, gamma_gn
1379
- )
1380
- else:
1381
- y_pred = y_pred_combined_media
1382
-
1383
- # If there are any holdout observations, the holdout KPI values will
1384
- # be replaced with zeros using `experimental_pin`. For these
1385
- # observations, we set the posterior mean equal to zero and standard
1386
- # deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the
1387
- # sampled posterior parameter values.
1388
- if holdout_id is not None:
1389
- y_pred_holdout = tf.where(holdout_id, 0.0, y_pred)
1390
- test_sd = tf.cast(1.0 / np.sqrt(2.0 * np.pi), tf.float32)
1391
- sigma_gt_holdout = tf.where(holdout_id, test_sd, sigma_gt)
1392
- yield tfp.distributions.Normal(
1393
- y_pred_holdout, sigma_gt_holdout, name="y"
1394
- )
1395
- else:
1396
- yield tfp.distributions.Normal(y_pred, sigma_gt, name="y")
1397
-
1398
- return joint_dist_unpinned
1399
-
1400
- def _get_joint_dist(self) -> tfp.distributions.Distribution:
1401
- y = (
1402
- tf.where(self.holdout_id, 0.0, self.kpi_scaled)
1403
- if self.holdout_id is not None
1404
- else self.kpi_scaled
1405
- )
1406
- return self._get_joint_dist_unpinned().experimental_pin(y=y)
1407
-
1408
- def _create_inference_data_coords(
950
+ def create_inference_data_coords(
1409
951
  self, n_chains: int, n_draws: int
1410
952
  ) -> Mapping[str, np.ndarray | Sequence[str]]:
1411
953
  """Creates data coordinates for inference data."""
@@ -1449,7 +991,7 @@ class Meridian:
1449
991
  constants.ORGANIC_RF_CHANNEL: organic_rf_channel_values,
1450
992
  }
1451
993
 
1452
- def _create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
994
+ def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
1453
995
  inference_dims = dict(constants.INFERENCE_DIMS)
1454
996
  if self.unique_sigma_for_each_geo:
1455
997
  inference_dims[constants.SIGMA] = [constants.GEO]
@@ -1461,412 +1003,18 @@ class Meridian:
1461
1003
  for param, dims in inference_dims.items()
1462
1004
  }
1463
1005
 
1464
- def _sample_media_priors(
1465
- self,
1466
- n_draws: int,
1467
- seed: int | None = None,
1468
- ) -> Mapping[str, tf.Tensor]:
1469
- """Draws samples from the prior distributions of the media variables.
1470
-
1471
- Args:
1472
- n_draws: Number of samples drawn from the prior distribution.
1473
- seed: Used to set the seed for reproducible results. For more information,
1474
- see [PRNGS and seeds]
1475
- (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
1476
-
1477
- Returns:
1478
- A mapping of media parameter names to a tensor of shape [n_draws, n_geos,
1479
- n_media_channels] or [n_draws, n_media_channels] containing the
1480
- samples.
1481
- """
1482
- prior = self.prior_broadcast
1483
- sample_shape = [1, n_draws]
1484
- sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
1485
- media_vars = {
1486
- constants.ALPHA_M: prior.alpha_m.sample(**sample_kwargs),
1487
- constants.EC_M: prior.ec_m.sample(**sample_kwargs),
1488
- constants.ETA_M: prior.eta_m.sample(**sample_kwargs),
1489
- constants.SLOPE_M: prior.slope_m.sample(**sample_kwargs),
1490
- }
1491
- beta_gm_dev = tfp.distributions.Sample(
1492
- tfp.distributions.Normal(0, 1),
1493
- [self.n_geos, self.n_media_channels],
1494
- name=constants.BETA_GM_DEV,
1495
- ).sample(**sample_kwargs)
1496
- media_transformed = self.adstock_hill_media(
1497
- media=self.media_tensors.media_scaled,
1498
- alpha=media_vars[constants.ALPHA_M],
1499
- ec=media_vars[constants.EC_M],
1500
- slope=media_vars[constants.SLOPE_M],
1501
- )
1502
-
1503
- prior_type = self.model_spec.paid_media_prior_type
1504
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
1505
- roi_m = prior.roi_m.sample(**sample_kwargs)
1506
- beta_m_value = self._get_roi_prior_beta_m_value(
1507
- beta_gm_dev=beta_gm_dev,
1508
- media_transformed=media_transformed,
1509
- roi_or_mroi_m=roi_m,
1510
- **media_vars,
1511
- )
1512
- media_vars[constants.ROI_M] = roi_m
1513
- media_vars[constants.BETA_M] = tfp.distributions.Deterministic(
1514
- beta_m_value, name=constants.BETA_M
1515
- ).sample()
1516
- elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI:
1517
- mroi_m = prior.mroi_m.sample(**sample_kwargs)
1518
- beta_m_value = self._get_roi_prior_beta_m_value(
1519
- beta_gm_dev=beta_gm_dev,
1520
- media_transformed=media_transformed,
1521
- roi_or_mroi_m=mroi_m,
1522
- **media_vars,
1523
- )
1524
- media_vars[constants.MROI_M] = mroi_m
1525
- media_vars[constants.BETA_M] = tfp.distributions.Deterministic(
1526
- beta_m_value, name=constants.BETA_M
1527
- ).sample()
1528
- else:
1529
- media_vars[constants.BETA_M] = prior.beta_m.sample(**sample_kwargs)
1530
-
1531
- beta_eta_combined = (
1532
- media_vars[constants.BETA_M][..., tf.newaxis, :]
1533
- + media_vars[constants.ETA_M][..., tf.newaxis, :] * beta_gm_dev
1534
- )
1535
- beta_gm_value = (
1536
- beta_eta_combined
1537
- if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1538
- else tf.math.exp(beta_eta_combined)
1539
- )
1540
- media_vars[constants.BETA_GM] = tfp.distributions.Deterministic(
1541
- beta_gm_value, name=constants.BETA_GM
1542
- ).sample()
1543
-
1544
- return media_vars
1545
-
1546
- def _sample_rf_priors(
1547
- self,
1548
- n_draws: int,
1549
- seed: int | None = None,
1550
- ) -> Mapping[str, tf.Tensor]:
1551
- """Draws samples from the prior distributions of the RF variables.
1552
-
1553
- Args:
1554
- n_draws: Number of samples drawn from the prior distribution.
1555
- seed: Used to set the seed for reproducible results. For more information,
1556
- see [PRNGS and seeds]
1557
- (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
1558
-
1559
- Returns:
1560
- A mapping of RF parameter names to a tensor of shape [n_draws, n_geos,
1561
- n_rf_channels] or [n_draws, n_rf_channels] containing the samples.
1562
- """
1563
- prior = self.prior_broadcast
1564
- sample_shape = [1, n_draws]
1565
- sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
1566
- rf_vars = {
1567
- constants.ALPHA_RF: prior.alpha_rf.sample(**sample_kwargs),
1568
- constants.EC_RF: prior.ec_rf.sample(**sample_kwargs),
1569
- constants.ETA_RF: prior.eta_rf.sample(**sample_kwargs),
1570
- constants.SLOPE_RF: prior.slope_rf.sample(**sample_kwargs),
1571
- }
1572
- beta_grf_dev = tfp.distributions.Sample(
1573
- tfp.distributions.Normal(0, 1),
1574
- [self.n_geos, self.n_rf_channels],
1575
- name=constants.BETA_GRF_DEV,
1576
- ).sample(**sample_kwargs)
1577
- rf_transformed = self.adstock_hill_rf(
1578
- reach=self.rf_tensors.reach_scaled,
1579
- frequency=self.rf_tensors.frequency,
1580
- alpha=rf_vars[constants.ALPHA_RF],
1581
- ec=rf_vars[constants.EC_RF],
1582
- slope=rf_vars[constants.SLOPE_RF],
1583
- )
1584
-
1585
- prior_type = self.model_spec.paid_media_prior_type
1586
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
1587
- roi_rf = prior.roi_rf.sample(**sample_kwargs)
1588
- beta_rf_value = self._get_roi_prior_beta_rf_value(
1589
- beta_grf_dev=beta_grf_dev,
1590
- rf_transformed=rf_transformed,
1591
- roi_or_mroi_rf=roi_rf,
1592
- **rf_vars,
1593
- )
1594
- rf_vars[constants.ROI_RF] = roi_rf
1595
- rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic(
1596
- beta_rf_value,
1597
- name=constants.BETA_RF,
1598
- ).sample()
1599
- elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI:
1600
- mroi_rf = prior.mroi_rf.sample(**sample_kwargs)
1601
- beta_rf_value = self._get_roi_prior_beta_rf_value(
1602
- beta_grf_dev=beta_grf_dev,
1603
- rf_transformed=rf_transformed,
1604
- roi_or_mroi_rf=mroi_rf,
1605
- **rf_vars,
1606
- )
1607
- rf_vars[constants.MROI_RF] = mroi_rf
1608
- rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic(
1609
- beta_rf_value,
1610
- name=constants.BETA_RF,
1611
- ).sample()
1612
- else:
1613
- rf_vars[constants.BETA_RF] = prior.beta_rf.sample(**sample_kwargs)
1614
-
1615
- beta_eta_combined = (
1616
- rf_vars[constants.BETA_RF][..., tf.newaxis, :]
1617
- + rf_vars[constants.ETA_RF][..., tf.newaxis, :] * beta_grf_dev
1618
- )
1619
- beta_grf_value = (
1620
- beta_eta_combined
1621
- if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1622
- else tf.math.exp(beta_eta_combined)
1623
- )
1624
- rf_vars[constants.BETA_GRF] = tfp.distributions.Deterministic(
1625
- beta_grf_value, name=constants.BETA_GRF
1626
- ).sample()
1627
-
1628
- return rf_vars
1629
-
1630
- def _sample_organic_media_priors(
1631
- self,
1632
- n_draws: int,
1633
- seed: int | None = None,
1634
- ) -> Mapping[str, tf.Tensor]:
1635
- """Draws samples from the prior distributions of organic media variables.
1636
-
1637
- Args:
1638
- n_draws: Number of samples drawn from the prior distribution.
1639
- seed: Used to set the seed for reproducible results. For more information,
1640
- see [PRNGS and seeds]
1641
- (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
1642
-
1643
- Returns:
1644
- A mapping of organic media parameter names to a tensor of shape [n_draws,
1645
- n_geos, n_organic_media_channels] or [n_draws, n_organic_media_channels]
1646
- containing the samples.
1647
- """
1648
- prior = self.prior_broadcast
1649
- sample_shape = [1, n_draws]
1650
- sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
1651
- organic_media_vars = {
1652
- constants.ALPHA_OM: prior.alpha_om.sample(**sample_kwargs),
1653
- constants.EC_OM: prior.ec_om.sample(**sample_kwargs),
1654
- constants.ETA_OM: prior.eta_om.sample(**sample_kwargs),
1655
- constants.SLOPE_OM: prior.slope_om.sample(**sample_kwargs),
1656
- }
1657
- beta_gom_dev = tfp.distributions.Sample(
1658
- tfp.distributions.Normal(0, 1),
1659
- [self.n_geos, self.n_organic_media_channels],
1660
- name=constants.BETA_GOM_DEV,
1661
- ).sample(**sample_kwargs)
1662
-
1663
- organic_media_vars[constants.BETA_OM] = prior.beta_om.sample(
1664
- **sample_kwargs
1665
- )
1666
-
1667
- beta_eta_combined = (
1668
- organic_media_vars[constants.BETA_OM][..., tf.newaxis, :]
1669
- + organic_media_vars[constants.ETA_OM][..., tf.newaxis, :]
1670
- * beta_gom_dev
1671
- )
1672
- beta_gom_value = (
1673
- beta_eta_combined
1674
- if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1675
- else tf.math.exp(beta_eta_combined)
1676
- )
1677
- organic_media_vars[constants.BETA_GOM] = tfp.distributions.Deterministic(
1678
- beta_gom_value, name=constants.BETA_GOM
1679
- ).sample()
1680
-
1681
- return organic_media_vars
1682
-
1683
- def _sample_organic_rf_priors(
1684
- self,
1685
- n_draws: int,
1686
- seed: int | None = None,
1687
- ) -> Mapping[str, tf.Tensor]:
1688
- """Draws samples from the prior distributions of the organic RF variables.
1689
-
1690
- Args:
1691
- n_draws: Number of samples drawn from the prior distribution.
1692
- seed: Used to set the seed for reproducible results. For more information,
1693
- see [PRNGS and seeds]
1694
- (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
1695
-
1696
- Returns:
1697
- A mapping of organic RF parameter names to a tensor of shape [n_draws,
1698
- n_geos, n_organic_rf_channels] or [n_draws, n_organic_rf_channels]
1699
- containing the samples.
1700
- """
1701
- prior = self.prior_broadcast
1702
- sample_shape = [1, n_draws]
1703
- sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
1704
- organic_rf_vars = {
1705
- constants.ALPHA_ORF: prior.alpha_orf.sample(**sample_kwargs),
1706
- constants.EC_ORF: prior.ec_orf.sample(**sample_kwargs),
1707
- constants.ETA_ORF: prior.eta_orf.sample(**sample_kwargs),
1708
- constants.SLOPE_ORF: prior.slope_orf.sample(**sample_kwargs),
1709
- }
1710
- beta_gorf_dev = tfp.distributions.Sample(
1711
- tfp.distributions.Normal(0, 1),
1712
- [self.n_geos, self.n_organic_rf_channels],
1713
- name=constants.BETA_GORF_DEV,
1714
- ).sample(**sample_kwargs)
1715
-
1716
- organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(**sample_kwargs)
1717
-
1718
- beta_eta_combined = (
1719
- organic_rf_vars[constants.BETA_ORF][..., tf.newaxis, :]
1720
- + organic_rf_vars[constants.ETA_ORF][..., tf.newaxis, :] * beta_gorf_dev
1721
- )
1722
- beta_gorf_value = (
1723
- beta_eta_combined
1724
- if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1725
- else tf.math.exp(beta_eta_combined)
1726
- )
1727
- organic_rf_vars[constants.BETA_GORF] = tfp.distributions.Deterministic(
1728
- beta_gorf_value, name=constants.BETA_GORF
1729
- ).sample()
1730
-
1731
- return organic_rf_vars
1732
-
1733
- def _sample_non_media_treatments_priors(
1734
- self,
1735
- n_draws: int,
1736
- seed: int | None = None,
1737
- ) -> Mapping[str, tf.Tensor]:
1738
- """Draws from the prior distributions of the non-media treatment variables.
1739
-
1740
- Args:
1741
- n_draws: Number of samples drawn from the prior distribution.
1742
- seed: Used to set the seed for reproducible results. For more information,
1743
- see [PRNGS and seeds]
1744
- (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
1745
-
1746
- Returns:
1747
- A mapping of non-media treatment parameter names to a tensor of shape
1748
- [n_draws,
1749
- n_geos, n_non_media_channels] or [n_draws, n_non_media_channels]
1750
- containing the samples.
1751
- """
1752
- prior = self.prior_broadcast
1753
- sample_shape = [1, n_draws]
1754
- sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
1755
- non_media_treatments_vars = {
1756
- constants.GAMMA_N: prior.gamma_n.sample(**sample_kwargs),
1757
- constants.XI_N: prior.xi_n.sample(**sample_kwargs),
1758
- }
1759
- gamma_gn_dev = tfp.distributions.Sample(
1760
- tfp.distributions.Normal(0, 1),
1761
- [self.n_geos, self.n_non_media_channels],
1762
- name=constants.GAMMA_GN_DEV,
1763
- ).sample(**sample_kwargs)
1764
- non_media_treatments_vars[constants.GAMMA_GN] = (
1765
- tfp.distributions.Deterministic(
1766
- non_media_treatments_vars[constants.GAMMA_N][..., tf.newaxis, :]
1767
- + non_media_treatments_vars[constants.XI_N][..., tf.newaxis, :]
1768
- * gamma_gn_dev,
1769
- name=constants.GAMMA_GN,
1770
- ).sample()
1771
- )
1772
- return non_media_treatments_vars
1773
-
1774
- def _sample_prior_fn(
1775
- self,
1776
- n_draws: int,
1777
- seed: int | None = None,
1778
- ) -> Mapping[str, tf.Tensor]:
1779
- """Returns a mapping of prior parameters to tensors of the samples."""
1780
- # For stateful sampling, the random seed must be set to ensure that any
1781
- # random numbers that are generated are deterministic.
1782
- if seed is not None:
1783
- tf.keras.utils.set_random_seed(1)
1784
- prior = self.prior_broadcast
1785
- sample_shape = [1, n_draws]
1786
- sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
1787
-
1788
- tau_g_excl_baseline = prior.tau_g_excl_baseline.sample(**sample_kwargs)
1789
- base_vars = {
1790
- constants.KNOT_VALUES: prior.knot_values.sample(**sample_kwargs),
1791
- constants.GAMMA_C: prior.gamma_c.sample(**sample_kwargs),
1792
- constants.XI_C: prior.xi_c.sample(**sample_kwargs),
1793
- constants.SIGMA: prior.sigma.sample(**sample_kwargs),
1794
- constants.TAU_G: _get_tau_g(
1795
- tau_g_excl_baseline=tau_g_excl_baseline,
1796
- baseline_geo_idx=self.baseline_geo_idx,
1797
- ).sample(),
1798
- }
1799
- base_vars[constants.MU_T] = tfp.distributions.Deterministic(
1800
- tf.einsum(
1801
- "...k,kt->...t",
1802
- base_vars[constants.KNOT_VALUES],
1803
- tf.convert_to_tensor(self.knot_info.weights),
1804
- ),
1805
- name=constants.MU_T,
1806
- ).sample()
1807
-
1808
- gamma_gc_dev = tfp.distributions.Sample(
1809
- tfp.distributions.Normal(0, 1),
1810
- [self.n_geos, self.n_controls],
1811
- name=constants.GAMMA_GC_DEV,
1812
- ).sample(**sample_kwargs)
1813
- base_vars[constants.GAMMA_GC] = tfp.distributions.Deterministic(
1814
- base_vars[constants.GAMMA_C][..., tf.newaxis, :]
1815
- + base_vars[constants.XI_C][..., tf.newaxis, :] * gamma_gc_dev,
1816
- name=constants.GAMMA_GC,
1817
- ).sample()
1818
-
1819
- media_vars = (
1820
- self._sample_media_priors(n_draws, seed)
1821
- if self.media_tensors.media is not None
1822
- else {}
1823
- )
1824
- rf_vars = (
1825
- self._sample_rf_priors(n_draws, seed)
1826
- if self.rf_tensors.reach is not None
1827
- else {}
1828
- )
1829
- organic_media_vars = (
1830
- self._sample_organic_media_priors(n_draws, seed)
1831
- if self.organic_media_tensors.organic_media is not None
1832
- else {}
1833
- )
1834
- organic_rf_vars = (
1835
- self._sample_organic_rf_priors(n_draws, seed)
1836
- if self.organic_rf_tensors.organic_reach is not None
1837
- else {}
1838
- )
1839
- non_media_treatments_vars = (
1840
- self._sample_non_media_treatments_priors(n_draws, seed)
1841
- if self.non_media_treatments_scaled is not None
1842
- else {}
1843
- )
1844
-
1845
- return (
1846
- base_vars
1847
- | media_vars
1848
- | rf_vars
1849
- | organic_media_vars
1850
- | organic_rf_vars
1851
- | non_media_treatments_vars
1852
- )
1853
-
1854
1006
  def sample_prior(self, n_draws: int, seed: int | None = None):
1855
1007
  """Draws samples from the prior distributions.
1856
1008
 
1009
+ Drawn samples are merged into this model's Arviz `inference_data` property.
1010
+
1857
1011
  Args:
1858
1012
  n_draws: Number of samples drawn from the prior distribution.
1859
1013
  seed: Used to set the seed for reproducible results. For more information,
1860
1014
  see [PRNGS and seeds]
1861
1015
  (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
1862
1016
  """
1863
- prior_draws = self._sample_prior_fn(n_draws, seed=seed)
1864
- # Create Arviz InferenceData for prior draws.
1865
- prior_coords = self._create_inference_data_coords(1, n_draws)
1866
- prior_dims = self._create_inference_data_dims()
1867
- prior_inference_data = az.convert_to_inference_data(
1868
- prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR
1869
- )
1017
+ prior_inference_data = self.prior_sampler_callable(n_draws, seed)
1870
1018
  self.inference_data.extend(prior_inference_data, join="right")
1871
1019
 
1872
1020
  def sample_posterior(
@@ -1890,6 +1038,8 @@ class Meridian:
1890
1038
  For more information about the arguments, see [`windowed_adaptive_nuts`]
1891
1039
  (https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/mcmc/windowed_adaptive_nuts).
1892
1040
 
1041
+ Drawn samples are merged into this model's Arviz `inference_data` property.
1042
+
1893
1043
  Args:
1894
1044
  n_chains: Number of MCMC chains. Given a sequence of integers,
1895
1045
  `windowed_adaptive_nuts` will be called once for each element. The
@@ -1943,112 +1093,20 @@ class Meridian:
1943
1093
  [ResourceExhaustedError when running Meridian.sample_posterior]
1944
1094
  (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
1945
1095
  """
1946
- seed = tfp.random.sanitize_seed(seed) if seed else None
1947
- n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
1948
- total_chains = np.sum(n_chains_list)
1949
-
1950
- states = []
1951
- traces = []
1952
- for n_chains_batch in n_chains_list:
1953
- try:
1954
- mcmc = _xla_windowed_adaptive_nuts(
1955
- n_draws=n_burnin + n_keep,
1956
- joint_dist=self._get_joint_dist(),
1957
- n_chains=n_chains_batch,
1958
- num_adaptation_steps=n_adapt,
1959
- current_state=current_state,
1960
- init_step_size=init_step_size,
1961
- dual_averaging_kwargs=dual_averaging_kwargs,
1962
- max_tree_depth=max_tree_depth,
1963
- max_energy_diff=max_energy_diff,
1964
- unrolled_leapfrog_steps=unrolled_leapfrog_steps,
1965
- parallel_iterations=parallel_iterations,
1966
- seed=seed,
1967
- **pins,
1968
- )
1969
- except tf.errors.ResourceExhaustedError as error:
1970
- raise MCMCOOMError(
1971
- "ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
1972
- " integers as `n_chains` to sample chains serially (see"
1973
- " https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error)"
1974
- ) from error
1975
- states.append(mcmc.all_states._asdict())
1976
- traces.append(mcmc.trace)
1977
-
1978
- mcmc_states = {
1979
- k: tf.einsum(
1980
- "ij...->ji...",
1981
- tf.concat([state[k] for state in states], axis=1)[n_burnin:, ...],
1982
- )
1983
- for k in states[0].keys()
1984
- if k not in constants.UNSAVED_PARAMETERS
1985
- }
1986
- # Create Arviz InferenceData for posterior draws.
1987
- posterior_coords = self._create_inference_data_coords(total_chains, n_keep)
1988
- posterior_dims = self._create_inference_data_dims()
1989
- infdata_posterior = az.convert_to_inference_data(
1990
- mcmc_states, coords=posterior_coords, dims=posterior_dims
1991
- )
1992
-
1993
- # Save trace metrics in InferenceData.
1994
- mcmc_trace = {}
1995
- for k in traces[0].keys():
1996
- if k not in constants.IGNORED_TRACE_METRICS:
1997
- mcmc_trace[k] = tf.concat(
1998
- [
1999
- tf.broadcast_to(
2000
- tf.transpose(trace[k][n_burnin:, ...]),
2001
- [n_chains_list[i], n_keep],
2002
- )
2003
- for i, trace in enumerate(traces)
2004
- ],
2005
- axis=0,
2006
- )
2007
-
2008
- trace_coords = {
2009
- constants.CHAIN: np.arange(total_chains),
2010
- constants.DRAW: np.arange(n_keep),
2011
- }
2012
- trace_dims = {
2013
- k: [constants.CHAIN, constants.DRAW] for k in mcmc_trace.keys()
2014
- }
2015
- infdata_trace = az.convert_to_inference_data(
2016
- mcmc_trace, coords=trace_coords, dims=trace_dims, group="trace"
2017
- )
2018
-
2019
- # Create Arviz InferenceData for divergent transitions and other sampling
2020
- # statistics. Note that InferenceData has a different naming convention
2021
- # than Tensorflow, and only certain variables are recongnized.
2022
- # https://arviz-devs.github.io/arviz/schema/schema.html#sample-stats
2023
- # The list of values returned by windowed_adaptive_nuts() is the following:
2024
- # 'step_size', 'tune', 'target_log_prob', 'diverging', 'accept_ratio',
2025
- # 'variance_scaling', 'n_steps', 'is_accepted'.
2026
-
2027
- sample_stats = {
2028
- constants.SAMPLE_STATS_METRICS[k]: v
2029
- for k, v in mcmc_trace.items()
2030
- if k in constants.SAMPLE_STATS_METRICS
2031
- }
2032
- sample_stats_dims = {
2033
- constants.SAMPLE_STATS_METRICS[k]: v
2034
- for k, v in trace_dims.items()
2035
- if k in constants.SAMPLE_STATS_METRICS
2036
- }
2037
- # Tensorflow does not include a "draw" dimension on step size metric if same
2038
- # step size is used for all chains. Step size must be broadcast to the
2039
- # correct shape.
2040
- sample_stats[constants.STEP_SIZE] = tf.broadcast_to(
2041
- sample_stats[constants.STEP_SIZE], [total_chains, n_keep]
2042
- )
2043
- sample_stats_dims[constants.STEP_SIZE] = [constants.CHAIN, constants.DRAW]
2044
- infdata_sample_stats = az.convert_to_inference_data(
2045
- sample_stats,
2046
- coords=trace_coords,
2047
- dims=sample_stats_dims,
2048
- group="sample_stats",
2049
- )
2050
- posterior_inference_data = az.concat(
2051
- infdata_posterior, infdata_trace, infdata_sample_stats
1096
+ posterior_inference_data = self.posterior_sampler_callable(
1097
+ n_chains,
1098
+ n_adapt,
1099
+ n_burnin,
1100
+ n_keep,
1101
+ current_state,
1102
+ init_step_size,
1103
+ dual_averaging_kwargs,
1104
+ max_tree_depth,
1105
+ max_energy_diff,
1106
+ unrolled_leapfrog_steps,
1107
+ parallel_iterations,
1108
+ seed,
1109
+ **pins,
2052
1110
  )
2053
1111
  self.inference_data.extend(posterior_inference_data, join="right")
2054
1112