google-meridian 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -78,386 +78,386 @@ def _xla_windowed_adaptive_nuts(**kwargs):
78
78
  return backend.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
79
79
 
80
80
 
81
- class PosteriorMCMCSampler:
82
- """A callable that samples from posterior distributions using MCMC."""
83
-
84
- def __init__(self, meridian: "model.Meridian"):
85
- self._meridian = meridian
86
-
87
- @property
88
- def model(self) -> "model.Meridian":
89
- return self._meridian
90
-
91
- def _get_joint_dist_unpinned(self) -> backend.tfd.Distribution:
92
- """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
93
- mmm = self.model
94
- mmm.populate_cached_properties()
81
+ def _joint_dist_unpinned(mmm: "model.Meridian"):
82
+ """Returns unpinned joint distribution."""
83
+
84
+ # This lists all the derived properties and states of this Meridian object
85
+ # that are referenced by the joint distribution coroutine.
86
+ # That is, these are the list of captured parameters.
87
+ prior_broadcast = mmm.prior_broadcast
88
+ baseline_geo_idx = mmm.baseline_geo_idx
89
+ knot_info = mmm.knot_info
90
+ n_geos = mmm.n_geos
91
+ n_times = mmm.n_times
92
+ n_media_channels = mmm.n_media_channels
93
+ n_rf_channels = mmm.n_rf_channels
94
+ n_organic_media_channels = mmm.n_organic_media_channels
95
+ n_organic_rf_channels = mmm.n_organic_rf_channels
96
+ n_controls = mmm.n_controls
97
+ n_non_media_channels = mmm.n_non_media_channels
98
+ holdout_id = mmm.holdout_id
99
+ media_tensors = mmm.media_tensors
100
+ rf_tensors = mmm.rf_tensors
101
+ organic_media_tensors = mmm.organic_media_tensors
102
+ organic_rf_tensors = mmm.organic_rf_tensors
103
+ controls_scaled = mmm.controls_scaled
104
+ non_media_treatments_normalized = mmm.non_media_treatments_normalized
105
+ media_effects_dist = mmm.media_effects_dist
106
+ adstock_hill_media_fn = mmm.adstock_hill_media
107
+ adstock_hill_rf_fn = mmm.adstock_hill_rf
108
+ total_outcome = mmm.total_outcome
109
+
110
+ # Sample directly from prior.
111
+ knot_values = yield prior_broadcast.knot_values
112
+ sigma = yield prior_broadcast.sigma
113
+
114
+ tau_g_excl_baseline = yield backend.tfd.Sample(
115
+ prior_broadcast.tau_g_excl_baseline,
116
+ name=constants.TAU_G_EXCL_BASELINE,
117
+ )
118
+ tau_g = yield _get_tau_g(
119
+ tau_g_excl_baseline=tau_g_excl_baseline,
120
+ baseline_geo_idx=baseline_geo_idx,
121
+ )
122
+ mu_t = yield backend.tfd.Deterministic(
123
+ backend.einsum(
124
+ "k,kt->t",
125
+ knot_values,
126
+ backend.to_tensor(knot_info.weights),
127
+ ),
128
+ name=constants.MU_T,
129
+ )
95
130
 
96
- # This lists all the derived properties and states of this Meridian object
97
- # that are referenced by the joint distribution coroutine.
98
- # That is, these are the list of captured parameters.
99
- prior_broadcast = mmm.prior_broadcast
100
- baseline_geo_idx = mmm.baseline_geo_idx
101
- knot_info = mmm.knot_info
102
- n_geos = mmm.n_geos
103
- n_times = mmm.n_times
104
- n_media_channels = mmm.n_media_channels
105
- n_rf_channels = mmm.n_rf_channels
106
- n_organic_media_channels = mmm.n_organic_media_channels
107
- n_organic_rf_channels = mmm.n_organic_rf_channels
108
- n_controls = mmm.n_controls
109
- n_non_media_channels = mmm.n_non_media_channels
110
- holdout_id = mmm.holdout_id
111
- media_tensors = mmm.media_tensors
112
- rf_tensors = mmm.rf_tensors
113
- organic_media_tensors = mmm.organic_media_tensors
114
- organic_rf_tensors = mmm.organic_rf_tensors
115
- controls_scaled = mmm.controls_scaled
116
- non_media_treatments_normalized = mmm.non_media_treatments_normalized
117
- media_effects_dist = mmm.media_effects_dist
118
- adstock_hill_media_fn = mmm.adstock_hill_media
119
- adstock_hill_rf_fn = mmm.adstock_hill_rf
120
- total_outcome = mmm.total_outcome
121
-
122
- @backend.tfd.JointDistributionCoroutineAutoBatched
123
- def joint_dist_unpinned():
124
- # Sample directly from prior.
125
- knot_values = yield prior_broadcast.knot_values
126
- sigma = yield prior_broadcast.sigma
127
-
128
- tau_g_excl_baseline = yield backend.tfd.Sample(
129
- prior_broadcast.tau_g_excl_baseline,
130
- name=constants.TAU_G_EXCL_BASELINE,
131
+ tau_gt = tau_g[:, backend.newaxis] + mu_t
132
+ combined_media_transformed = backend.zeros(
133
+ shape=(n_geos, n_times, 0), dtype=backend.float32
134
+ )
135
+ combined_beta = backend.zeros(shape=(n_geos, 0), dtype=backend.float32)
136
+ if media_tensors.media is not None:
137
+ alpha_m = yield prior_broadcast.alpha_m
138
+ ec_m = yield prior_broadcast.ec_m
139
+ eta_m = yield prior_broadcast.eta_m
140
+ slope_m = yield prior_broadcast.slope_m
141
+ beta_gm_dev = yield backend.tfd.Sample(
142
+ backend.tfd.Normal(0, 1),
143
+ [n_geos, n_media_channels],
144
+ name=constants.BETA_GM_DEV,
145
+ )
146
+ media_transformed = adstock_hill_media_fn(
147
+ media=media_tensors.media_scaled,
148
+ alpha=alpha_m,
149
+ ec=ec_m,
150
+ slope=slope_m,
151
+ decay_functions=mmm.adstock_decay_spec.media,
152
+ )
153
+ prior_type = mmm.model_spec.effective_media_prior_type
154
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
155
+ beta_m = yield prior_broadcast.beta_m
156
+ else:
157
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
158
+ treatment_parameter_m = yield prior_broadcast.roi_m
159
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
160
+ treatment_parameter_m = yield prior_broadcast.mroi_m
161
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
162
+ treatment_parameter_m = yield prior_broadcast.contribution_m
163
+ else:
164
+ raise ValueError(f"Unsupported prior type: {prior_type}")
165
+ incremental_outcome_m = (
166
+ treatment_parameter_m * media_tensors.prior_denominator
131
167
  )
132
- tau_g = yield _get_tau_g(
133
- tau_g_excl_baseline=tau_g_excl_baseline,
134
- baseline_geo_idx=baseline_geo_idx,
168
+ linear_predictor_counterfactual_difference = (
169
+ mmm.linear_predictor_counterfactual_difference_media(
170
+ media_transformed=media_transformed,
171
+ alpha_m=alpha_m,
172
+ ec_m=ec_m,
173
+ slope_m=slope_m,
174
+ )
135
175
  )
136
- mu_t = yield backend.tfd.Deterministic(
137
- backend.einsum(
138
- "k,kt->t",
139
- knot_values,
140
- backend.to_tensor(knot_info.weights),
141
- ),
142
- name=constants.MU_T,
176
+ beta_m_value = mmm.calculate_beta_x(
177
+ is_non_media=False,
178
+ incremental_outcome_x=incremental_outcome_m,
179
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
180
+ eta_x=eta_m,
181
+ beta_gx_dev=beta_gm_dev,
143
182
  )
144
-
145
- tau_gt = tau_g[:, backend.newaxis] + mu_t
146
- combined_media_transformed = backend.zeros(
147
- shape=(n_geos, n_times, 0), dtype=backend.float32
183
+ beta_m = yield backend.tfd.Deterministic(
184
+ beta_m_value, name=constants.BETA_M
148
185
  )
149
- combined_beta = backend.zeros(shape=(n_geos, 0), dtype=backend.float32)
150
- if media_tensors.media is not None:
151
- alpha_m = yield prior_broadcast.alpha_m
152
- ec_m = yield prior_broadcast.ec_m
153
- eta_m = yield prior_broadcast.eta_m
154
- slope_m = yield prior_broadcast.slope_m
155
- beta_gm_dev = yield backend.tfd.Sample(
156
- backend.tfd.Normal(0, 1),
157
- [n_geos, n_media_channels],
158
- name=constants.BETA_GM_DEV,
159
- )
160
- media_transformed = adstock_hill_media_fn(
161
- media=media_tensors.media_scaled,
162
- alpha=alpha_m,
163
- ec=ec_m,
164
- slope=slope_m,
165
- decay_functions=mmm.adstock_decay_spec.media,
166
- )
167
- prior_type = mmm.model_spec.effective_media_prior_type
168
- if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
169
- beta_m = yield prior_broadcast.beta_m
170
- else:
171
- if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
172
- treatment_parameter_m = yield prior_broadcast.roi_m
173
- elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
174
- treatment_parameter_m = yield prior_broadcast.mroi_m
175
- elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
176
- treatment_parameter_m = yield prior_broadcast.contribution_m
177
- else:
178
- raise ValueError(f"Unsupported prior type: {prior_type}")
179
- incremental_outcome_m = (
180
- treatment_parameter_m * media_tensors.prior_denominator
181
- )
182
- linear_predictor_counterfactual_difference = (
183
- mmm.linear_predictor_counterfactual_difference_media(
184
- media_transformed=media_transformed,
185
- alpha_m=alpha_m,
186
- ec_m=ec_m,
187
- slope_m=slope_m,
188
- )
189
- )
190
- beta_m_value = mmm.calculate_beta_x(
191
- is_non_media=False,
192
- incremental_outcome_x=incremental_outcome_m,
193
- linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
194
- eta_x=eta_m,
195
- beta_gx_dev=beta_gm_dev,
196
- )
197
- beta_m = yield backend.tfd.Deterministic(
198
- beta_m_value, name=constants.BETA_M
199
- )
200
186
 
201
- beta_eta_combined = beta_m + eta_m * beta_gm_dev
202
- beta_gm_value = (
203
- beta_eta_combined
204
- if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
205
- else backend.exp(beta_eta_combined)
206
- )
207
- beta_gm = yield backend.tfd.Deterministic(
208
- beta_gm_value, name=constants.BETA_GM
209
- )
210
- combined_media_transformed = backend.concatenate(
211
- [combined_media_transformed, media_transformed], axis=-1
212
- )
213
- combined_beta = backend.concatenate([combined_beta, beta_gm], axis=-1)
214
-
215
- if rf_tensors.reach is not None:
216
- alpha_rf = yield prior_broadcast.alpha_rf
217
- ec_rf = yield prior_broadcast.ec_rf
218
- eta_rf = yield prior_broadcast.eta_rf
219
- slope_rf = yield prior_broadcast.slope_rf
220
- beta_grf_dev = yield backend.tfd.Sample(
221
- backend.tfd.Normal(0, 1),
222
- [n_geos, n_rf_channels],
223
- name=constants.BETA_GRF_DEV,
224
- )
225
- rf_transformed = adstock_hill_rf_fn(
226
- reach=rf_tensors.reach_scaled,
227
- frequency=rf_tensors.frequency,
228
- alpha=alpha_rf,
229
- ec=ec_rf,
230
- slope=slope_rf,
231
- decay_functions=mmm.adstock_decay_spec.rf,
232
- )
187
+ beta_eta_combined = beta_m + eta_m * beta_gm_dev
188
+ beta_gm_value = (
189
+ beta_eta_combined
190
+ if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
191
+ else backend.exp(beta_eta_combined)
192
+ )
193
+ beta_gm = yield backend.tfd.Deterministic(
194
+ beta_gm_value, name=constants.BETA_GM
195
+ )
196
+ combined_media_transformed = backend.concatenate(
197
+ [combined_media_transformed, media_transformed], axis=-1
198
+ )
199
+ combined_beta = backend.concatenate([combined_beta, beta_gm], axis=-1)
200
+
201
+ if rf_tensors.reach is not None:
202
+ alpha_rf = yield prior_broadcast.alpha_rf
203
+ ec_rf = yield prior_broadcast.ec_rf
204
+ eta_rf = yield prior_broadcast.eta_rf
205
+ slope_rf = yield prior_broadcast.slope_rf
206
+ beta_grf_dev = yield backend.tfd.Sample(
207
+ backend.tfd.Normal(0, 1),
208
+ [n_geos, n_rf_channels],
209
+ name=constants.BETA_GRF_DEV,
210
+ )
211
+ rf_transformed = adstock_hill_rf_fn(
212
+ reach=rf_tensors.reach_scaled,
213
+ frequency=rf_tensors.frequency,
214
+ alpha=alpha_rf,
215
+ ec=ec_rf,
216
+ slope=slope_rf,
217
+ decay_functions=mmm.adstock_decay_spec.rf,
218
+ )
233
219
 
234
- prior_type = mmm.model_spec.effective_rf_prior_type
235
- if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
236
- beta_rf = yield prior_broadcast.beta_rf
237
- else:
238
- if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
239
- treatment_parameter_rf = yield prior_broadcast.roi_rf
240
- elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
241
- treatment_parameter_rf = yield prior_broadcast.mroi_rf
242
- elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
243
- treatment_parameter_rf = yield prior_broadcast.contribution_rf
244
- else:
245
- raise ValueError(f"Unsupported prior type: {prior_type}")
246
- incremental_outcome_rf = (
247
- treatment_parameter_rf * rf_tensors.prior_denominator
248
- )
249
- linear_predictor_counterfactual_difference = (
250
- mmm.linear_predictor_counterfactual_difference_rf(
251
- rf_transformed=rf_transformed,
252
- alpha_rf=alpha_rf,
253
- ec_rf=ec_rf,
254
- slope_rf=slope_rf,
255
- )
256
- )
257
- beta_rf_value = mmm.calculate_beta_x(
258
- is_non_media=False,
259
- incremental_outcome_x=incremental_outcome_rf,
260
- linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
261
- eta_x=eta_rf,
262
- beta_gx_dev=beta_grf_dev,
263
- )
264
- beta_rf = yield backend.tfd.Deterministic(
265
- beta_rf_value, name=constants.BETA_RF
220
+ prior_type = mmm.model_spec.effective_rf_prior_type
221
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
222
+ beta_rf = yield prior_broadcast.beta_rf
223
+ else:
224
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
225
+ treatment_parameter_rf = yield prior_broadcast.roi_rf
226
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
227
+ treatment_parameter_rf = yield prior_broadcast.mroi_rf
228
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
229
+ treatment_parameter_rf = yield prior_broadcast.contribution_rf
230
+ else:
231
+ raise ValueError(f"Unsupported prior type: {prior_type}")
232
+ incremental_outcome_rf = (
233
+ treatment_parameter_rf * rf_tensors.prior_denominator
234
+ )
235
+ linear_predictor_counterfactual_difference = (
236
+ mmm.linear_predictor_counterfactual_difference_rf(
237
+ rf_transformed=rf_transformed,
238
+ alpha_rf=alpha_rf,
239
+ ec_rf=ec_rf,
240
+ slope_rf=slope_rf,
266
241
  )
242
+ )
243
+ beta_rf_value = mmm.calculate_beta_x(
244
+ is_non_media=False,
245
+ incremental_outcome_x=incremental_outcome_rf,
246
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
247
+ eta_x=eta_rf,
248
+ beta_gx_dev=beta_grf_dev,
249
+ )
250
+ beta_rf = yield backend.tfd.Deterministic(
251
+ beta_rf_value, name=constants.BETA_RF
252
+ )
267
253
 
268
- beta_eta_combined = beta_rf + eta_rf * beta_grf_dev
269
- beta_grf_value = (
270
- beta_eta_combined
271
- if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
272
- else backend.exp(beta_eta_combined)
273
- )
274
- beta_grf = yield backend.tfd.Deterministic(
275
- beta_grf_value, name=constants.BETA_GRF
276
- )
277
- combined_media_transformed = backend.concatenate(
278
- [combined_media_transformed, rf_transformed], axis=-1
279
- )
280
- combined_beta = backend.concatenate([combined_beta, beta_grf], axis=-1)
281
-
282
- if organic_media_tensors.organic_media is not None:
283
- alpha_om = yield prior_broadcast.alpha_om
284
- ec_om = yield prior_broadcast.ec_om
285
- eta_om = yield prior_broadcast.eta_om
286
- slope_om = yield prior_broadcast.slope_om
287
- beta_gom_dev = yield backend.tfd.Sample(
288
- backend.tfd.Normal(0, 1),
289
- [n_geos, n_organic_media_channels],
290
- name=constants.BETA_GOM_DEV,
291
- )
292
- organic_media_transformed = adstock_hill_media_fn(
293
- media=organic_media_tensors.organic_media_scaled,
294
- alpha=alpha_om,
295
- ec=ec_om,
296
- slope=slope_om,
297
- decay_functions=mmm.adstock_decay_spec.organic_media,
298
- )
299
- prior_type = mmm.model_spec.organic_media_prior_type
300
- if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
301
- beta_om = yield prior_broadcast.beta_om
302
- elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
303
- contribution_om = yield prior_broadcast.contribution_om
304
- incremental_outcome_om = contribution_om * total_outcome
305
- beta_om_value = mmm.calculate_beta_x(
306
- is_non_media=False,
307
- incremental_outcome_x=incremental_outcome_om,
308
- linear_predictor_counterfactual_difference=organic_media_transformed,
309
- eta_x=eta_om,
310
- beta_gx_dev=beta_gom_dev,
311
- )
312
- beta_om = yield backend.tfd.Deterministic(
313
- beta_om_value, name=constants.BETA_OM
314
- )
315
- else:
316
- raise ValueError(f"Unsupported prior type: {prior_type}")
317
-
318
- beta_eta_combined = beta_om + eta_om * beta_gom_dev
319
- beta_gom_value = (
320
- beta_eta_combined
321
- if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
322
- else backend.exp(beta_eta_combined)
323
- )
324
- beta_gom = yield backend.tfd.Deterministic(
325
- beta_gom_value, name=constants.BETA_GOM
326
- )
327
- combined_media_transformed = backend.concatenate(
328
- [combined_media_transformed, organic_media_transformed], axis=-1
329
- )
330
- combined_beta = backend.concatenate([combined_beta, beta_gom], axis=-1)
331
-
332
- if organic_rf_tensors.organic_reach is not None:
333
- alpha_orf = yield prior_broadcast.alpha_orf
334
- ec_orf = yield prior_broadcast.ec_orf
335
- eta_orf = yield prior_broadcast.eta_orf
336
- slope_orf = yield prior_broadcast.slope_orf
337
- beta_gorf_dev = yield backend.tfd.Sample(
338
- backend.tfd.Normal(0, 1),
339
- [n_geos, n_organic_rf_channels],
340
- name=constants.BETA_GORF_DEV,
341
- )
342
- organic_rf_transformed = adstock_hill_rf_fn(
343
- reach=organic_rf_tensors.organic_reach_scaled,
344
- frequency=organic_rf_tensors.organic_frequency,
345
- alpha=alpha_orf,
346
- ec=ec_orf,
347
- slope=slope_orf,
348
- decay_functions=mmm.adstock_decay_spec.organic_rf,
349
- )
254
+ beta_eta_combined = beta_rf + eta_rf * beta_grf_dev
255
+ beta_grf_value = (
256
+ beta_eta_combined
257
+ if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
258
+ else backend.exp(beta_eta_combined)
259
+ )
260
+ beta_grf = yield backend.tfd.Deterministic(
261
+ beta_grf_value, name=constants.BETA_GRF
262
+ )
263
+ combined_media_transformed = backend.concatenate(
264
+ [combined_media_transformed, rf_transformed], axis=-1
265
+ )
266
+ combined_beta = backend.concatenate([combined_beta, beta_grf], axis=-1)
267
+
268
+ if organic_media_tensors.organic_media is not None:
269
+ alpha_om = yield prior_broadcast.alpha_om
270
+ ec_om = yield prior_broadcast.ec_om
271
+ eta_om = yield prior_broadcast.eta_om
272
+ slope_om = yield prior_broadcast.slope_om
273
+ beta_gom_dev = yield backend.tfd.Sample(
274
+ backend.tfd.Normal(0, 1),
275
+ [n_geos, n_organic_media_channels],
276
+ name=constants.BETA_GOM_DEV,
277
+ )
278
+ organic_media_transformed = adstock_hill_media_fn(
279
+ media=organic_media_tensors.organic_media_scaled,
280
+ alpha=alpha_om,
281
+ ec=ec_om,
282
+ slope=slope_om,
283
+ decay_functions=mmm.adstock_decay_spec.organic_media,
284
+ )
285
+ prior_type = mmm.model_spec.organic_media_prior_type
286
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
287
+ beta_om = yield prior_broadcast.beta_om
288
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
289
+ contribution_om = yield prior_broadcast.contribution_om
290
+ incremental_outcome_om = contribution_om * total_outcome
291
+ beta_om_value = mmm.calculate_beta_x(
292
+ is_non_media=False,
293
+ incremental_outcome_x=incremental_outcome_om,
294
+ linear_predictor_counterfactual_difference=organic_media_transformed,
295
+ eta_x=eta_om,
296
+ beta_gx_dev=beta_gom_dev,
297
+ )
298
+ beta_om = yield backend.tfd.Deterministic(
299
+ beta_om_value, name=constants.BETA_OM
300
+ )
301
+ else:
302
+ raise ValueError(f"Unsupported prior type: {prior_type}")
303
+
304
+ beta_eta_combined = beta_om + eta_om * beta_gom_dev
305
+ beta_gom_value = (
306
+ beta_eta_combined
307
+ if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
308
+ else backend.exp(beta_eta_combined)
309
+ )
310
+ beta_gom = yield backend.tfd.Deterministic(
311
+ beta_gom_value, name=constants.BETA_GOM
312
+ )
313
+ combined_media_transformed = backend.concatenate(
314
+ [combined_media_transformed, organic_media_transformed], axis=-1
315
+ )
316
+ combined_beta = backend.concatenate([combined_beta, beta_gom], axis=-1)
317
+
318
+ if organic_rf_tensors.organic_reach is not None:
319
+ alpha_orf = yield prior_broadcast.alpha_orf
320
+ ec_orf = yield prior_broadcast.ec_orf
321
+ eta_orf = yield prior_broadcast.eta_orf
322
+ slope_orf = yield prior_broadcast.slope_orf
323
+ beta_gorf_dev = yield backend.tfd.Sample(
324
+ backend.tfd.Normal(0, 1),
325
+ [n_geos, n_organic_rf_channels],
326
+ name=constants.BETA_GORF_DEV,
327
+ )
328
+ organic_rf_transformed = adstock_hill_rf_fn(
329
+ reach=organic_rf_tensors.organic_reach_scaled,
330
+ frequency=organic_rf_tensors.organic_frequency,
331
+ alpha=alpha_orf,
332
+ ec=ec_orf,
333
+ slope=slope_orf,
334
+ decay_functions=mmm.adstock_decay_spec.organic_rf,
335
+ )
350
336
 
351
- prior_type = mmm.model_spec.organic_rf_prior_type
352
- if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
353
- beta_orf = yield prior_broadcast.beta_orf
354
- elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
355
- contribution_orf = yield prior_broadcast.contribution_orf
356
- incremental_outcome_orf = contribution_orf * total_outcome
357
- beta_orf_value = mmm.calculate_beta_x(
358
- is_non_media=False,
359
- incremental_outcome_x=incremental_outcome_orf,
360
- linear_predictor_counterfactual_difference=organic_rf_transformed,
361
- eta_x=eta_orf,
362
- beta_gx_dev=beta_gorf_dev,
363
- )
364
- beta_orf = yield backend.tfd.Deterministic(
365
- beta_orf_value, name=constants.BETA_ORF
366
- )
367
- else:
368
- raise ValueError(f"Unsupported prior type: {prior_type}")
369
-
370
- beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev
371
- beta_gorf_value = (
372
- beta_eta_combined
373
- if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
374
- else backend.exp(beta_eta_combined)
375
- )
376
- beta_gorf = yield backend.tfd.Deterministic(
377
- beta_gorf_value, name=constants.BETA_GORF
378
- )
379
- combined_media_transformed = backend.concatenate(
380
- [combined_media_transformed, organic_rf_transformed], axis=-1
381
- )
382
- combined_beta = backend.concatenate([combined_beta, beta_gorf], axis=-1)
337
+ prior_type = mmm.model_spec.organic_rf_prior_type
338
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
339
+ beta_orf = yield prior_broadcast.beta_orf
340
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
341
+ contribution_orf = yield prior_broadcast.contribution_orf
342
+ incremental_outcome_orf = contribution_orf * total_outcome
343
+ beta_orf_value = mmm.calculate_beta_x(
344
+ is_non_media=False,
345
+ incremental_outcome_x=incremental_outcome_orf,
346
+ linear_predictor_counterfactual_difference=organic_rf_transformed,
347
+ eta_x=eta_orf,
348
+ beta_gx_dev=beta_gorf_dev,
349
+ )
350
+ beta_orf = yield backend.tfd.Deterministic(
351
+ beta_orf_value, name=constants.BETA_ORF
352
+ )
353
+ else:
354
+ raise ValueError(f"Unsupported prior type: {prior_type}")
355
+
356
+ beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev
357
+ beta_gorf_value = (
358
+ beta_eta_combined
359
+ if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
360
+ else backend.exp(beta_eta_combined)
361
+ )
362
+ beta_gorf = yield backend.tfd.Deterministic(
363
+ beta_gorf_value, name=constants.BETA_GORF
364
+ )
365
+ combined_media_transformed = backend.concatenate(
366
+ [combined_media_transformed, organic_rf_transformed], axis=-1
367
+ )
368
+ combined_beta = backend.concatenate([combined_beta, beta_gorf], axis=-1)
383
369
 
384
- sigma_gt = backend.transpose(
385
- backend.broadcast_to(sigma, [n_times, n_geos])
370
+ sigma_gt = backend.transpose(backend.broadcast_to(sigma, [n_times, n_geos]))
371
+ y_pred_combined_media = tau_gt + backend.einsum(
372
+ "gtm,gm->gt", combined_media_transformed, combined_beta
373
+ )
374
+ # Omit gamma_c, xi_c, and gamma_gc from joint distribution output if
375
+ # there are no control variables in the model.
376
+ if n_controls:
377
+ gamma_c = yield prior_broadcast.gamma_c
378
+ xi_c = yield prior_broadcast.xi_c
379
+ gamma_gc_dev = yield backend.tfd.Sample(
380
+ backend.tfd.Normal(0, 1),
381
+ [n_geos, n_controls],
382
+ name=constants.GAMMA_GC_DEV,
383
+ )
384
+ gamma_gc = yield backend.tfd.Deterministic(
385
+ gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
386
+ )
387
+ y_pred_combined_media += backend.einsum(
388
+ "gtc,gc->gt", controls_scaled, gamma_gc
389
+ )
390
+
391
+ if mmm.non_media_treatments is not None:
392
+ xi_n = yield prior_broadcast.xi_n
393
+ gamma_gn_dev = yield backend.tfd.Sample(
394
+ backend.tfd.Normal(0, 1),
395
+ [n_geos, n_non_media_channels],
396
+ name=constants.GAMMA_GN_DEV,
397
+ )
398
+ prior_type = mmm.model_spec.non_media_treatments_prior_type
399
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
400
+ gamma_n = yield prior_broadcast.gamma_n
401
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
402
+ contribution_n = yield prior_broadcast.contribution_n
403
+ incremental_outcome_n = contribution_n * total_outcome
404
+ baseline_scaled = mmm.non_media_transformer.forward( # pytype: disable=attribute-error
405
+ mmm.compute_non_media_treatments_baseline()
386
406
  )
387
- y_pred_combined_media = tau_gt + backend.einsum(
388
- "gtm,gm->gt", combined_media_transformed, combined_beta
407
+ linear_predictor_counterfactual_difference = (
408
+ non_media_treatments_normalized - baseline_scaled
389
409
  )
390
- # Omit gamma_c, xi_c, and gamma_gc from joint distribution output if
391
- # there are no control variables in the model.
392
- if n_controls:
393
- gamma_c = yield prior_broadcast.gamma_c
394
- xi_c = yield prior_broadcast.xi_c
395
- gamma_gc_dev = yield backend.tfd.Sample(
396
- backend.tfd.Normal(0, 1),
397
- [n_geos, n_controls],
398
- name=constants.GAMMA_GC_DEV,
399
- )
400
- gamma_gc = yield backend.tfd.Deterministic(
401
- gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
402
- )
403
- y_pred_combined_media += backend.einsum(
404
- "gtc,gc->gt", controls_scaled, gamma_gc
405
- )
410
+ gamma_n_value = mmm.calculate_beta_x(
411
+ is_non_media=True,
412
+ incremental_outcome_x=incremental_outcome_n,
413
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
414
+ eta_x=xi_n,
415
+ beta_gx_dev=gamma_gn_dev,
416
+ )
417
+ gamma_n = yield backend.tfd.Deterministic(
418
+ gamma_n_value, name=constants.GAMMA_N
419
+ )
420
+ else:
421
+ raise ValueError(f"Unsupported prior type: {prior_type}")
422
+ gamma_gn = yield backend.tfd.Deterministic(
423
+ gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
424
+ )
425
+ y_pred = y_pred_combined_media + backend.einsum(
426
+ "gtn,gn->gt", non_media_treatments_normalized, gamma_gn
427
+ )
428
+ else:
429
+ y_pred = y_pred_combined_media
430
+
431
+ # If there are any holdout observations, the holdout KPI values will
432
+ # be replaced with zeros using `experimental_pin`. For these
433
+ # observations, we set the posterior mean equal to zero and standard
434
+ # deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the
435
+ # sampled posterior parameter values.
436
+ if holdout_id is not None:
437
+ y_pred_holdout = backend.where(holdout_id, 0.0, y_pred)
438
+ test_sd = backend.cast(1.0 / np.sqrt(2.0 * np.pi), backend.float32)
439
+ sigma_gt_holdout = backend.where(holdout_id, test_sd, sigma_gt)
440
+ yield backend.tfd.Normal(y_pred_holdout, sigma_gt_holdout, name="y")
441
+ else:
442
+ yield backend.tfd.Normal(y_pred, sigma_gt, name="y")
406
443
 
407
- if mmm.non_media_treatments is not None:
408
- xi_n = yield prior_broadcast.xi_n
409
- gamma_gn_dev = yield backend.tfd.Sample(
410
- backend.tfd.Normal(0, 1),
411
- [n_geos, n_non_media_channels],
412
- name=constants.GAMMA_GN_DEV,
413
- )
414
- prior_type = mmm.model_spec.non_media_treatments_prior_type
415
- if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
416
- gamma_n = yield prior_broadcast.gamma_n
417
- elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
418
- contribution_n = yield prior_broadcast.contribution_n
419
- incremental_outcome_n = contribution_n * total_outcome
420
- baseline_scaled = mmm.non_media_transformer.forward( # pytype: disable=attribute-error
421
- mmm.compute_non_media_treatments_baseline()
422
- )
423
- linear_predictor_counterfactual_difference = (
424
- non_media_treatments_normalized - baseline_scaled
425
- )
426
- gamma_n_value = mmm.calculate_beta_x(
427
- is_non_media=True,
428
- incremental_outcome_x=incremental_outcome_n,
429
- linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
430
- eta_x=xi_n,
431
- beta_gx_dev=gamma_gn_dev,
432
- )
433
- gamma_n = yield backend.tfd.Deterministic(
434
- gamma_n_value, name=constants.GAMMA_N
435
- )
436
- else:
437
- raise ValueError(f"Unsupported prior type: {prior_type}")
438
- gamma_gn = yield backend.tfd.Deterministic(
439
- gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
440
- )
441
- y_pred = y_pred_combined_media + backend.einsum(
442
- "gtn,gn->gt", non_media_treatments_normalized, gamma_gn
443
- )
444
- else:
445
- y_pred = y_pred_combined_media
446
-
447
- # If there are any holdout observations, the holdout KPI values will
448
- # be replaced with zeros using `experimental_pin`. For these
449
- # observations, we set the posterior mean equal to zero and standard
450
- # deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the
451
- # sampled posterior parameter values.
452
- if holdout_id is not None:
453
- y_pred_holdout = backend.where(holdout_id, 0.0, y_pred)
454
- test_sd = backend.cast(1.0 / np.sqrt(2.0 * np.pi), backend.float32)
455
- sigma_gt_holdout = backend.where(holdout_id, test_sd, sigma_gt)
456
- yield backend.tfd.Normal(y_pred_holdout, sigma_gt_holdout, name="y")
457
- else:
458
- yield backend.tfd.Normal(y_pred, sigma_gt, name="y")
459
444
 
460
- return joint_dist_unpinned
445
+ class PosteriorMCMCSampler:
446
+ """A callable that samples from posterior distributions using MCMC."""
447
+
448
+ def __init__(self, meridian: "model.Meridian"):
449
+ self._meridian = meridian
450
+
451
+ @property
452
+ def model(self) -> "model.Meridian":
453
+ return self._meridian
454
+
455
+ def _get_joint_dist_unpinned(self) -> backend.tfd.Distribution:
456
+ """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
457
+ mmm = self.model
458
+ mmm.populate_cached_properties()
459
+ fn = lambda: _joint_dist_unpinned(mmm)
460
+ return backend.tfd.JointDistributionCoroutineAutoBatched(fn)
461
461
 
462
462
  def _get_joint_dist(self) -> backend.tfd.Distribution:
463
463
  mmm = self.model