google-meridian 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -68,147 +68,6 @@ class PriorDistributionSampler:
68
68
  def __init__(self, meridian: "model.Meridian"):
69
69
  self._meridian = meridian
70
70
 
71
- def get_roi_prior_beta_m_value(
72
- self,
73
- alpha_m: tf.Tensor,
74
- beta_gm_dev: tf.Tensor,
75
- ec_m: tf.Tensor,
76
- eta_m: tf.Tensor,
77
- roi_or_mroi_m: tf.Tensor,
78
- slope_m: tf.Tensor,
79
- media_transformed: tf.Tensor,
80
- ) -> tf.Tensor:
81
- """Returns a tensor to be used in `beta_m`."""
82
- mmm = self._meridian
83
-
84
- # The `roi_or_mroi_m` parameter represents either ROI or mROI. For reach &
85
- # frequency channels, marginal ROI priors are defined as "mROI by reach",
86
- # which is equivalent to ROI.
87
- media_spend = mmm.media_tensors.media_spend
88
- media_spend_counterfactual = mmm.media_tensors.media_spend_counterfactual
89
- media_counterfactual_scaled = mmm.media_tensors.media_counterfactual_scaled
90
- # If we got here, then we should already have media tensors derived from
91
- # non-None InputData.media data.
92
- assert media_spend is not None
93
- assert media_spend_counterfactual is not None
94
- assert media_counterfactual_scaled is not None
95
-
96
- # Use absolute value here because this difference will be negative for
97
- # marginal ROI priors.
98
- inc_revenue_m = roi_or_mroi_m * tf.reduce_sum(
99
- tf.abs(media_spend - media_spend_counterfactual),
100
- range(media_spend.ndim - 1),
101
- )
102
-
103
- if (
104
- mmm.model_spec.roi_calibration_period is None
105
- and mmm.model_spec.paid_media_prior_type
106
- == constants.PAID_MEDIA_PRIOR_TYPE_ROI
107
- ):
108
- # We can skip the adstock/hill computation step in this case.
109
- media_counterfactual_transformed = tf.zeros_like(media_transformed)
110
- else:
111
- media_counterfactual_transformed = mmm.adstock_hill_media(
112
- media=media_counterfactual_scaled,
113
- alpha=alpha_m,
114
- ec=ec_m,
115
- slope=slope_m,
116
- )
117
-
118
- revenue_per_kpi = mmm.revenue_per_kpi
119
- if mmm.input_data.revenue_per_kpi is None:
120
- revenue_per_kpi = tf.ones([mmm.n_geos, mmm.n_times], dtype=tf.float32)
121
- # Note: use absolute value here because this difference will be negative for
122
- # marginal ROI priors.
123
- media_contrib_gm = tf.einsum(
124
- "...gtm,g,,gt->...gm",
125
- tf.abs(media_transformed - media_counterfactual_transformed),
126
- mmm.population,
127
- mmm.kpi_transformer.population_scaled_stdev,
128
- revenue_per_kpi,
129
- )
130
-
131
- if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
132
- media_contrib_m = tf.einsum("...gm->...m", media_contrib_gm)
133
- random_effect_m = tf.einsum(
134
- "...m,...gm,...gm->...m", eta_m, beta_gm_dev, media_contrib_gm
135
- )
136
- return (inc_revenue_m - random_effect_m) / media_contrib_m
137
- else:
138
- # For log_normal, beta_m and eta_m are not mean & std.
139
- # The parameterization is beta_gm ~ exp(beta_m + eta_m * N(0, 1)).
140
- random_effect_m = tf.einsum(
141
- "...gm,...gm->...m",
142
- tf.math.exp(beta_gm_dev * eta_m[..., tf.newaxis, :]),
143
- media_contrib_gm,
144
- )
145
- return tf.math.log(inc_revenue_m) - tf.math.log(random_effect_m)
146
-
147
- def get_roi_prior_beta_rf_value(
148
- self,
149
- alpha_rf: tf.Tensor,
150
- beta_grf_dev: tf.Tensor,
151
- ec_rf: tf.Tensor,
152
- eta_rf: tf.Tensor,
153
- roi_or_mroi_rf: tf.Tensor,
154
- slope_rf: tf.Tensor,
155
- rf_transformed: tf.Tensor,
156
- ) -> tf.Tensor:
157
- """Returns a tensor to be used in `beta_rf`."""
158
- mmm = self._meridian
159
-
160
- rf_spend = mmm.rf_tensors.rf_spend
161
- rf_spend_counterfactual = mmm.rf_tensors.rf_spend_counterfactual
162
- reach_counterfactual_scaled = mmm.rf_tensors.reach_counterfactual_scaled
163
- frequency = mmm.rf_tensors.frequency
164
- # If we got here, then we should already have RF media tensors derived from
165
- # non-None InputData.reach data.
166
- assert rf_spend is not None
167
- assert rf_spend_counterfactual is not None
168
- assert reach_counterfactual_scaled is not None
169
- assert frequency is not None
170
-
171
- inc_revenue_rf = roi_or_mroi_rf * tf.reduce_sum(
172
- rf_spend - rf_spend_counterfactual,
173
- range(rf_spend.ndim - 1),
174
- )
175
- if mmm.model_spec.rf_roi_calibration_period is not None:
176
- rf_counterfactual_transformed = mmm.adstock_hill_rf(
177
- reach=reach_counterfactual_scaled,
178
- frequency=frequency,
179
- alpha=alpha_rf,
180
- ec=ec_rf,
181
- slope=slope_rf,
182
- )
183
- else:
184
- rf_counterfactual_transformed = tf.zeros_like(rf_transformed)
185
- revenue_per_kpi = mmm.revenue_per_kpi
186
- if mmm.input_data.revenue_per_kpi is None:
187
- revenue_per_kpi = tf.ones([mmm.n_geos, mmm.n_times], dtype=tf.float32)
188
-
189
- media_contrib_grf = tf.einsum(
190
- "...gtm,g,,gt->...gm",
191
- rf_transformed - rf_counterfactual_transformed,
192
- mmm.population,
193
- mmm.kpi_transformer.population_scaled_stdev,
194
- revenue_per_kpi,
195
- )
196
- if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
197
- media_contrib_rf = tf.einsum("...gm->...m", media_contrib_grf)
198
- random_effect_rf = tf.einsum(
199
- "...m,...gm,...gm->...m", eta_rf, beta_grf_dev, media_contrib_grf
200
- )
201
- return (inc_revenue_rf - random_effect_rf) / media_contrib_rf
202
- else:
203
- # For log_normal, beta_rf and eta_rf are not mean & std.
204
- # The parameterization is beta_grf ~ exp(beta_rf + eta_rf * N(0, 1)).
205
- random_effect_rf = tf.einsum(
206
- "...gm,...gm->...m",
207
- tf.math.exp(beta_grf_dev * eta_rf[..., tf.newaxis, :]),
208
- media_contrib_grf,
209
- )
210
- return tf.math.log(inc_revenue_rf) - tf.math.log(random_effect_rf)
211
-
212
71
  def _sample_media_priors(
213
72
  self,
214
73
  n_draws: int,
@@ -243,40 +102,49 @@ class PriorDistributionSampler:
243
102
  [mmm.n_geos, mmm.n_media_channels],
244
103
  name=constants.BETA_GM_DEV,
245
104
  ).sample(**sample_kwargs)
246
- media_transformed = mmm.adstock_hill_media(
247
- media=mmm.media_tensors.media_scaled,
248
- alpha=media_vars[constants.ALPHA_M],
249
- ec=media_vars[constants.EC_M],
250
- slope=media_vars[constants.SLOPE_M],
251
- )
252
105
 
253
- prior_type = mmm.model_spec.paid_media_prior_type
254
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
255
- roi_m = prior.roi_m.sample(**sample_kwargs)
256
- beta_m_value = self.get_roi_prior_beta_m_value(
257
- beta_gm_dev=beta_gm_dev,
258
- media_transformed=media_transformed,
259
- roi_or_mroi_m=roi_m,
260
- **media_vars,
106
+ prior_type = mmm.model_spec.effective_media_prior_type
107
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
108
+ media_vars[constants.BETA_M] = prior.beta_m.sample(**sample_kwargs)
109
+ else:
110
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
111
+ treatment_parameter_m = prior.roi_m.sample(**sample_kwargs)
112
+ media_vars[constants.ROI_M] = treatment_parameter_m
113
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
114
+ treatment_parameter_m = prior.mroi_m.sample(**sample_kwargs)
115
+ media_vars[constants.MROI_M] = treatment_parameter_m
116
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
117
+ treatment_parameter_m = prior.contribution_m.sample(**sample_kwargs)
118
+ media_vars[constants.CONTRIBUTION_M] = treatment_parameter_m
119
+ else:
120
+ raise ValueError(f"Unsupported prior type: {prior_type}")
121
+ incremental_outcome_m = (
122
+ treatment_parameter_m * mmm.media_tensors.prior_denominator
261
123
  )
262
- media_vars[constants.ROI_M] = roi_m
263
- media_vars[constants.BETA_M] = tfp.distributions.Deterministic(
264
- beta_m_value, name=constants.BETA_M
265
- ).sample()
266
- elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI:
267
- mroi_m = prior.mroi_m.sample(**sample_kwargs)
268
- beta_m_value = self.get_roi_prior_beta_m_value(
269
- beta_gm_dev=beta_gm_dev,
270
- media_transformed=media_transformed,
271
- roi_or_mroi_m=mroi_m,
272
- **media_vars,
124
+ media_transformed = mmm.adstock_hill_media(
125
+ media=mmm.media_tensors.media_scaled,
126
+ alpha=media_vars[constants.ALPHA_M],
127
+ ec=media_vars[constants.EC_M],
128
+ slope=media_vars[constants.SLOPE_M],
129
+ )
130
+ linear_predictor_counterfactual_difference = (
131
+ mmm.linear_predictor_counterfactual_difference_media(
132
+ media_transformed=media_transformed,
133
+ alpha_m=media_vars[constants.ALPHA_M],
134
+ ec_m=media_vars[constants.EC_M],
135
+ slope_m=media_vars[constants.SLOPE_M],
136
+ )
137
+ )
138
+ beta_m_value = mmm.calculate_beta_x(
139
+ is_non_media=False,
140
+ incremental_outcome_x=incremental_outcome_m,
141
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
142
+ eta_x=media_vars[constants.ETA_M],
143
+ beta_gx_dev=beta_gm_dev,
273
144
  )
274
- media_vars[constants.MROI_M] = mroi_m
275
145
  media_vars[constants.BETA_M] = tfp.distributions.Deterministic(
276
146
  beta_m_value, name=constants.BETA_M
277
147
  ).sample()
278
- else:
279
- media_vars[constants.BETA_M] = prior.beta_m.sample(**sample_kwargs)
280
148
 
281
149
  beta_eta_combined = (
282
150
  media_vars[constants.BETA_M][..., tf.newaxis, :]
@@ -307,8 +175,9 @@ class PriorDistributionSampler:
307
175
  (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
308
176
 
309
177
  Returns:
310
- A mapping of RF parameter names to a tensor of shape `[n_draws, n_geos,
311
- n_rf_channels]` or `[n_draws, n_rf_channels]` containing the samples.
178
+ A mapping of RF parameter names to a tensor of shape
179
+ `[n_draws, n_geos, n_rf_channels]` or `[n_draws, n_rf_channels]`
180
+ containing the samples.
312
181
  """
313
182
  mmm = self._meridian
314
183
 
@@ -326,43 +195,51 @@ class PriorDistributionSampler:
326
195
  [mmm.n_geos, mmm.n_rf_channels],
327
196
  name=constants.BETA_GRF_DEV,
328
197
  ).sample(**sample_kwargs)
329
- rf_transformed = mmm.adstock_hill_rf(
330
- reach=mmm.rf_tensors.reach_scaled,
331
- frequency=mmm.rf_tensors.frequency,
332
- alpha=rf_vars[constants.ALPHA_RF],
333
- ec=rf_vars[constants.EC_RF],
334
- slope=rf_vars[constants.SLOPE_RF],
335
- )
336
198
 
337
- prior_type = mmm.model_spec.paid_media_prior_type
338
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
339
- roi_rf = prior.roi_rf.sample(**sample_kwargs)
340
- beta_rf_value = self.get_roi_prior_beta_rf_value(
341
- beta_grf_dev=beta_grf_dev,
342
- rf_transformed=rf_transformed,
343
- roi_or_mroi_rf=roi_rf,
344
- **rf_vars,
199
+ prior_type = mmm.model_spec.effective_rf_prior_type
200
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
201
+ rf_vars[constants.BETA_RF] = prior.beta_rf.sample(**sample_kwargs)
202
+ else:
203
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
204
+ treatment_parameter_rf = prior.roi_rf.sample(**sample_kwargs)
205
+ rf_vars[constants.ROI_RF] = treatment_parameter_rf
206
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
207
+ treatment_parameter_rf = prior.mroi_rf.sample(**sample_kwargs)
208
+ rf_vars[constants.MROI_RF] = treatment_parameter_rf
209
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
210
+ treatment_parameter_rf = prior.contribution_rf.sample(**sample_kwargs)
211
+ rf_vars[constants.CONTRIBUTION_RF] = treatment_parameter_rf
212
+ else:
213
+ raise ValueError(f"Unsupported prior type: {prior_type}")
214
+ incremental_outcome_rf = (
215
+ treatment_parameter_rf * mmm.rf_tensors.prior_denominator
345
216
  )
346
- rf_vars[constants.ROI_RF] = roi_rf
347
- rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic(
348
- beta_rf_value,
349
- name=constants.BETA_RF,
350
- ).sample()
351
- elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI:
352
- mroi_rf = prior.mroi_rf.sample(**sample_kwargs)
353
- beta_rf_value = self.get_roi_prior_beta_rf_value(
354
- beta_grf_dev=beta_grf_dev,
355
- rf_transformed=rf_transformed,
356
- roi_or_mroi_rf=mroi_rf,
357
- **rf_vars,
217
+ rf_transformed = mmm.adstock_hill_rf(
218
+ reach=mmm.rf_tensors.reach_scaled,
219
+ frequency=mmm.rf_tensors.frequency,
220
+ alpha=rf_vars[constants.ALPHA_RF],
221
+ ec=rf_vars[constants.EC_RF],
222
+ slope=rf_vars[constants.SLOPE_RF],
223
+ )
224
+ linear_predictor_counterfactual_difference = (
225
+ mmm.linear_predictor_counterfactual_difference_rf(
226
+ rf_transformed=rf_transformed,
227
+ alpha_rf=rf_vars[constants.ALPHA_RF],
228
+ ec_rf=rf_vars[constants.EC_RF],
229
+ slope_rf=rf_vars[constants.SLOPE_RF],
230
+ )
231
+ )
232
+ beta_rf_value = mmm.calculate_beta_x(
233
+ is_non_media=False,
234
+ incremental_outcome_x=incremental_outcome_rf,
235
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
236
+ eta_x=rf_vars[constants.ETA_RF],
237
+ beta_gx_dev=beta_grf_dev,
358
238
  )
359
- rf_vars[constants.MROI_RF] = mroi_rf
360
239
  rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic(
361
240
  beta_rf_value,
362
241
  name=constants.BETA_RF,
363
242
  ).sample()
364
- else:
365
- rf_vars[constants.BETA_RF] = prior.beta_rf.sample(**sample_kwargs)
366
243
 
367
244
  beta_eta_combined = (
368
245
  rf_vars[constants.BETA_RF][..., tf.newaxis, :]
@@ -393,9 +270,9 @@ class PriorDistributionSampler:
393
270
  (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
394
271
 
395
272
  Returns:
396
- A mapping of organic media parameter names to a tensor of shape [n_draws,
397
- n_geos, n_organic_media_channels] or [n_draws, n_organic_media_channels]
398
- containing the samples.
273
+ A mapping of organic media parameter names to a tensor of shape
274
+ `[n_draws, n_geos, n_organic_media_channels]` or
275
+ `[n_draws, n_organic_media_channels]` containing the samples.
399
276
  """
400
277
  mmm = self._meridian
401
278
 
@@ -414,9 +291,37 @@ class PriorDistributionSampler:
414
291
  name=constants.BETA_GOM_DEV,
415
292
  ).sample(**sample_kwargs)
416
293
 
417
- organic_media_vars[constants.BETA_OM] = prior.beta_om.sample(
418
- **sample_kwargs
419
- )
294
+ prior_type = mmm.model_spec.organic_media_prior_type
295
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
296
+ organic_media_vars[constants.BETA_OM] = prior.beta_om.sample(
297
+ **sample_kwargs
298
+ )
299
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
300
+ organic_media_vars[constants.CONTRIBUTION_OM] = (
301
+ prior.contribution_om.sample(**sample_kwargs)
302
+ )
303
+ incremental_outcome_om = (
304
+ organic_media_vars[constants.CONTRIBUTION_OM] * mmm.total_outcome
305
+ )
306
+ organic_media_transformed = mmm.adstock_hill_media(
307
+ media=mmm.organic_media_tensors.organic_media_scaled,
308
+ alpha=organic_media_vars[constants.ALPHA_OM],
309
+ ec=organic_media_vars[constants.EC_OM],
310
+ slope=organic_media_vars[constants.SLOPE_OM],
311
+ )
312
+ beta_om_value = mmm.calculate_beta_x(
313
+ is_non_media=False,
314
+ incremental_outcome_x=incremental_outcome_om,
315
+ linear_predictor_counterfactual_difference=organic_media_transformed,
316
+ eta_x=organic_media_vars[constants.ETA_OM],
317
+ beta_gx_dev=beta_gom_dev,
318
+ )
319
+ organic_media_vars[constants.BETA_OM] = tfp.distributions.Deterministic(
320
+ beta_om_value,
321
+ name=constants.BETA_OM,
322
+ ).sample()
323
+ else:
324
+ raise ValueError(f"Unsupported prior type: {prior_type}")
420
325
 
421
326
  beta_eta_combined = (
422
327
  organic_media_vars[constants.BETA_OM][..., tf.newaxis, :]
@@ -448,9 +353,9 @@ class PriorDistributionSampler:
448
353
  (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
449
354
 
450
355
  Returns:
451
- A mapping of organic RF parameter names to a tensor of shape [n_draws,
452
- n_geos, n_organic_rf_channels] or [n_draws, n_organic_rf_channels]
453
- containing the samples.
356
+ A mapping of organic RF parameter names to a tensor of shape
357
+ `[n_draws, n_geos, n_organic_rf_channels]` or
358
+ `[n_draws, n_organic_rf_channels]` containing the samples.
454
359
  """
455
360
  mmm = self._meridian
456
361
 
@@ -469,7 +374,38 @@ class PriorDistributionSampler:
469
374
  name=constants.BETA_GORF_DEV,
470
375
  ).sample(**sample_kwargs)
471
376
 
472
- organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(**sample_kwargs)
377
+ prior_type = mmm.model_spec.organic_media_prior_type
378
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
379
+ organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(
380
+ **sample_kwargs
381
+ )
382
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
383
+ organic_rf_vars[constants.CONTRIBUTION_ORF] = (
384
+ prior.contribution_orf.sample(**sample_kwargs)
385
+ )
386
+ incremental_outcome_orf = (
387
+ organic_rf_vars[constants.CONTRIBUTION_ORF] * mmm.total_outcome
388
+ )
389
+ organic_rf_transformed = mmm.adstock_hill_rf(
390
+ reach=mmm.organic_rf_tensors.organic_reach_scaled,
391
+ frequency=mmm.organic_rf_tensors.organic_frequency,
392
+ alpha=organic_rf_vars[constants.ALPHA_ORF],
393
+ ec=organic_rf_vars[constants.EC_ORF],
394
+ slope=organic_rf_vars[constants.SLOPE_ORF],
395
+ )
396
+ beta_orf_value = mmm.calculate_beta_x(
397
+ is_non_media=False,
398
+ incremental_outcome_x=incremental_outcome_orf,
399
+ linear_predictor_counterfactual_difference=organic_rf_transformed,
400
+ eta_x=organic_rf_vars[constants.ETA_ORF],
401
+ beta_gx_dev=beta_gorf_dev,
402
+ )
403
+ organic_rf_vars[constants.BETA_ORF] = tfp.distributions.Deterministic(
404
+ beta_orf_value,
405
+ name=constants.BETA_ORF,
406
+ ).sample()
407
+ else:
408
+ raise ValueError(f"Unsupported prior type: {prior_type}")
473
409
 
474
410
  beta_eta_combined = (
475
411
  organic_rf_vars[constants.BETA_ORF][..., tf.newaxis, :]
@@ -501,9 +437,8 @@ class PriorDistributionSampler:
501
437
 
502
438
  Returns:
503
439
  A mapping of non-media treatment parameter names to a tensor of shape
504
- [n_draws,
505
- n_geos, n_non_media_channels] or [n_draws, n_non_media_channels]
506
- containing the samples.
440
+ `[n_draws, n_geos, n_non_media_channels]` or
441
+ `[n_draws, n_non_media_channels]` containing the samples.
507
442
  """
508
443
  mmm = self._meridian
509
444
 
@@ -511,7 +446,6 @@ class PriorDistributionSampler:
511
446
  sample_shape = [1, n_draws]
512
447
  sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed}
513
448
  non_media_treatments_vars = {
514
- constants.GAMMA_N: prior.gamma_n.sample(**sample_kwargs),
515
449
  constants.XI_N: prior.xi_n.sample(**sample_kwargs),
516
450
  }
517
451
  gamma_gn_dev = tfp.distributions.Sample(
@@ -519,6 +453,39 @@ class PriorDistributionSampler:
519
453
  [mmm.n_geos, mmm.n_non_media_channels],
520
454
  name=constants.GAMMA_GN_DEV,
521
455
  ).sample(**sample_kwargs)
456
+ prior_type = mmm.model_spec.non_media_treatments_prior_type
457
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
458
+ non_media_treatments_vars[constants.GAMMA_N] = prior.gamma_n.sample(
459
+ **sample_kwargs
460
+ )
461
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
462
+ non_media_treatments_vars[constants.CONTRIBUTION_N] = (
463
+ prior.contribution_n.sample(**sample_kwargs)
464
+ )
465
+ incremental_outcome_n = (
466
+ non_media_treatments_vars[constants.CONTRIBUTION_N]
467
+ * mmm.total_outcome
468
+ )
469
+ baseline_scaled = mmm.non_media_transformer.forward( # pytype: disable=attribute-error
470
+ mmm.compute_non_media_treatments_baseline()
471
+ )
472
+ linear_predictor_counterfactual_difference = (
473
+ mmm.non_media_treatments_normalized - baseline_scaled
474
+ )
475
+ gamma_n_value = mmm.calculate_beta_x(
476
+ is_non_media=True,
477
+ incremental_outcome_x=incremental_outcome_n,
478
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
479
+ eta_x=non_media_treatments_vars[constants.XI_N],
480
+ beta_gx_dev=gamma_gn_dev,
481
+ )
482
+ non_media_treatments_vars[constants.GAMMA_N] = (
483
+ tfp.distributions.Deterministic(
484
+ gamma_n_value, name=constants.GAMMA_N
485
+ ).sample()
486
+ )
487
+ else:
488
+ raise ValueError(f"Unsupported prior type: {prior_type}")
522
489
  non_media_treatments_vars[constants.GAMMA_GN] = (
523
490
  tfp.distributions.Deterministic(
524
491
  non_media_treatments_vars[constants.GAMMA_N][..., tf.newaxis, :]
@@ -549,14 +516,15 @@ class PriorDistributionSampler:
549
516
  tau_g_excl_baseline = prior.tau_g_excl_baseline.sample(**sample_kwargs)
550
517
  base_vars = {
551
518
  constants.KNOT_VALUES: prior.knot_values.sample(**sample_kwargs),
552
- constants.GAMMA_C: prior.gamma_c.sample(**sample_kwargs),
553
- constants.XI_C: prior.xi_c.sample(**sample_kwargs),
554
519
  constants.SIGMA: prior.sigma.sample(**sample_kwargs),
555
- constants.TAU_G: _get_tau_g(
556
- tau_g_excl_baseline=tau_g_excl_baseline,
557
- baseline_geo_idx=mmm.baseline_geo_idx,
558
- ).sample(),
520
+ constants.TAU_G: (
521
+ _get_tau_g(
522
+ tau_g_excl_baseline=tau_g_excl_baseline,
523
+ baseline_geo_idx=mmm.baseline_geo_idx,
524
+ ).sample()
525
+ ),
559
526
  }
527
+
560
528
  base_vars[constants.MU_T] = tfp.distributions.Deterministic(
561
529
  tf.einsum(
562
530
  "...k,kt->...t",
@@ -566,16 +534,24 @@ class PriorDistributionSampler:
566
534
  name=constants.MU_T,
567
535
  ).sample()
568
536
 
569
- gamma_gc_dev = tfp.distributions.Sample(
570
- tfp.distributions.Normal(0, 1),
571
- [mmm.n_geos, mmm.n_controls],
572
- name=constants.GAMMA_GC_DEV,
573
- ).sample(**sample_kwargs)
574
- base_vars[constants.GAMMA_GC] = tfp.distributions.Deterministic(
575
- base_vars[constants.GAMMA_C][..., tf.newaxis, :]
576
- + base_vars[constants.XI_C][..., tf.newaxis, :] * gamma_gc_dev,
577
- name=constants.GAMMA_GC,
578
- ).sample()
537
+ # Omit gamma_c, xi_c, and gamma_gc parameters from sampled distributions if
538
+ # there are no control variables in the model.
539
+ if mmm.n_controls:
540
+ base_vars |= {
541
+ constants.GAMMA_C: prior.gamma_c.sample(**sample_kwargs),
542
+ constants.XI_C: prior.xi_c.sample(**sample_kwargs),
543
+ }
544
+
545
+ gamma_gc_dev = tfp.distributions.Sample(
546
+ tfp.distributions.Normal(0, 1),
547
+ [mmm.n_geos, mmm.n_controls],
548
+ name=constants.GAMMA_GC_DEV,
549
+ ).sample(**sample_kwargs)
550
+ base_vars[constants.GAMMA_GC] = tfp.distributions.Deterministic(
551
+ base_vars[constants.GAMMA_C][..., tf.newaxis, :]
552
+ + base_vars[constants.XI_C][..., tf.newaxis, :] * gamma_gc_dev,
553
+ name=constants.GAMMA_GC,
554
+ ).sample()
579
555
 
580
556
  media_vars = (
581
557
  self._sample_media_priors(n_draws, seed)
@@ -599,7 +575,7 @@ class PriorDistributionSampler:
599
575
  )
600
576
  non_media_treatments_vars = (
601
577
  self._sample_non_media_treatments_priors(n_draws, seed)
602
- if mmm.non_media_treatments_scaled is not None
578
+ if mmm.non_media_treatments_normalized is not None
603
579
  else {}
604
580
  )
605
581