google-meridian 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -70,6 +70,10 @@ class WithInputDataSamples:
70
70
  _TEST_DIR,
71
71
  "sample_prior_media_only.nc",
72
72
  )
73
+ _TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
74
+ _TEST_DIR,
75
+ "sample_prior_media_only_no_controls.nc",
76
+ )
73
77
  _TEST_SAMPLE_PRIOR_RF_ONLY_PATH = os.path.join(
74
78
  _TEST_DIR,
75
79
  "sample_prior_rf_only.nc",
@@ -82,6 +86,10 @@ class WithInputDataSamples:
82
86
  _TEST_DIR,
83
87
  "sample_posterior_media_only.nc",
84
88
  )
89
+ _TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
90
+ _TEST_DIR,
91
+ "sample_posterior_media_only_no_controls.nc",
92
+ )
85
93
  _TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH = os.path.join(
86
94
  _TEST_DIR,
87
95
  "sample_posterior_rf_only.nc",
@@ -130,6 +138,17 @@ class WithInputDataSamples:
130
138
  seed=0,
131
139
  )
132
140
  )
141
+ self.input_data_media_and_rf_non_revenue_no_revenue_per_kpi = (
142
+ test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
143
+ n_geos=self._N_GEOS,
144
+ n_times=self._N_TIMES,
145
+ n_media_times=self._N_MEDIA_TIMES,
146
+ n_controls=self._N_CONTROLS,
147
+ n_media_channels=self._N_MEDIA_CHANNELS,
148
+ n_rf_channels=self._N_RF_CHANNELS,
149
+ seed=0,
150
+ )
151
+ )
133
152
  self.input_data_with_media_only = (
134
153
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
135
154
  n_geos=self._N_GEOS,
@@ -161,6 +180,17 @@ class WithInputDataSamples:
161
180
  seed=0,
162
181
  )
163
182
  )
183
+ self.input_data_with_media_and_rf_no_controls = (
184
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
185
+ n_geos=self._N_GEOS,
186
+ n_times=self._N_TIMES,
187
+ n_media_times=self._N_MEDIA_TIMES,
188
+ n_controls=None,
189
+ n_media_channels=self._N_MEDIA_CHANNELS,
190
+ n_rf_channels=self._N_RF_CHANNELS,
191
+ seed=0,
192
+ )
193
+ )
164
194
  self.short_input_data_with_media_only = (
165
195
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
166
196
  n_geos=self._N_GEOS,
@@ -171,6 +201,16 @@ class WithInputDataSamples:
171
201
  seed=0,
172
202
  )
173
203
  )
204
+ self.short_input_data_with_media_only_no_controls = (
205
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
206
+ n_geos=self._N_GEOS,
207
+ n_times=self._N_TIMES_SHORT,
208
+ n_media_times=self._N_MEDIA_TIMES_SHORT,
209
+ n_controls=0,
210
+ n_media_channels=self._N_MEDIA_CHANNELS,
211
+ seed=0,
212
+ )
213
+ )
174
214
  self.short_input_data_with_rf_only = (
175
215
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
176
216
  n_geos=self._N_GEOS,
@@ -220,6 +260,9 @@ class WithInputDataSamples:
220
260
  test_prior_media_only = xr.open_dataset(
221
261
  self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
222
262
  )
263
+ test_prior_media_only_no_controls = xr.open_dataset(
264
+ self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
265
+ )
223
266
  test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
224
267
  self.test_dist_media_and_rf = collections.OrderedDict({
225
268
  param: tf.convert_to_tensor(test_prior_media_and_rf[param])
@@ -232,6 +275,18 @@ class WithInputDataSamples:
232
275
  for param in constants.COMMON_PARAMETER_NAMES
233
276
  + constants.MEDIA_PARAMETER_NAMES
234
277
  })
278
+ self.test_dist_media_only_no_controls = collections.OrderedDict({
279
+ param: tf.convert_to_tensor(test_prior_media_only_no_controls[param])
280
+ for param in (
281
+ set(
282
+ constants.COMMON_PARAMETER_NAMES
283
+ + constants.MEDIA_PARAMETER_NAMES
284
+ )
285
+ - set(
286
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
287
+ )
288
+ )
289
+ })
235
290
  self.test_dist_rf_only = collections.OrderedDict({
236
291
  param: tf.convert_to_tensor(test_prior_rf_only[param])
237
292
  for param in constants.COMMON_PARAMETER_NAMES
@@ -244,6 +299,9 @@ class WithInputDataSamples:
244
299
  test_posterior_media_only = xr.open_dataset(
245
300
  self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
246
301
  )
302
+ test_posterior_media_only_no_controls = xr.open_dataset(
303
+ self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
304
+ )
247
305
  test_posterior_rf_only = xr.open_dataset(
248
306
  self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
249
307
  )
@@ -262,6 +320,21 @@ class WithInputDataSamples:
262
320
  for param in constants.COMMON_PARAMETER_NAMES
263
321
  + constants.MEDIA_PARAMETER_NAMES
264
322
  }
323
+ posterior_params_to_tensors_media_only_no_controls = {
324
+ param: _convert_with_swap(
325
+ test_posterior_media_only_no_controls[param],
326
+ n_burnin=self._N_BURNIN,
327
+ )
328
+ for param in (
329
+ set(
330
+ constants.COMMON_PARAMETER_NAMES
331
+ + constants.MEDIA_PARAMETER_NAMES
332
+ )
333
+ - set(
334
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
335
+ )
336
+ )
337
+ }
265
338
  posterior_params_to_tensors_rf_only = {
266
339
  param: _convert_with_swap(
267
340
  test_posterior_rf_only[param], n_burnin=self._N_BURNIN
@@ -279,6 +352,18 @@ class WithInputDataSamples:
279
352
  "StructTuple",
280
353
  constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
281
354
  )(**posterior_params_to_tensors_media_only)
355
+ self.test_posterior_states_media_only_no_controls = collections.namedtuple(
356
+ "StructTuple",
357
+ (
358
+ set(
359
+ constants.COMMON_PARAMETER_NAMES
360
+ + constants.MEDIA_PARAMETER_NAMES
361
+ )
362
+ - set(
363
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
364
+ )
365
+ ),
366
+ )(**posterior_params_to_tensors_media_only_no_controls)
282
367
  self.test_posterior_states_rf_only = collections.namedtuple(
283
368
  "StructTuple",
284
369
  constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -110,23 +110,16 @@ class PosteriorMCMCSampler:
110
110
  organic_media_tensors = mmm.organic_media_tensors
111
111
  organic_rf_tensors = mmm.organic_rf_tensors
112
112
  controls_scaled = mmm.controls_scaled
113
- non_media_treatments_scaled = mmm.non_media_treatments_scaled
113
+ non_media_treatments_normalized = mmm.non_media_treatments_normalized
114
114
  media_effects_dist = mmm.media_effects_dist
115
115
  adstock_hill_media_fn = mmm.adstock_hill_media
116
116
  adstock_hill_rf_fn = mmm.adstock_hill_rf
117
- get_roi_prior_beta_m_value_fn = (
118
- mmm.prior_sampler_callable.get_roi_prior_beta_m_value
119
- )
120
- get_roi_prior_beta_rf_value_fn = (
121
- mmm.prior_sampler_callable.get_roi_prior_beta_rf_value
122
- )
117
+ total_outcome = mmm.total_outcome
123
118
 
124
119
  @tfp.distributions.JointDistributionCoroutineAutoBatched
125
120
  def joint_dist_unpinned():
126
121
  # Sample directly from prior.
127
122
  knot_values = yield prior_broadcast.knot_values
128
- gamma_c = yield prior_broadcast.gamma_c
129
- xi_c = yield prior_broadcast.xi_c
130
123
  sigma = yield prior_broadcast.sigma
131
124
 
132
125
  tau_g_excl_baseline = yield tfp.distributions.Sample(
@@ -167,26 +160,39 @@ class PosteriorMCMCSampler:
167
160
  ec=ec_m,
168
161
  slope=slope_m,
169
162
  )
170
- prior_type = mmm.model_spec.paid_media_prior_type
171
- if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
172
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
173
- roi_or_mroi_m = yield prior_broadcast.roi_m
163
+ prior_type = mmm.model_spec.effective_media_prior_type
164
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
165
+ beta_m = yield prior_broadcast.beta_m
166
+ else:
167
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
168
+ treatment_parameter_m = yield prior_broadcast.roi_m
169
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
170
+ treatment_parameter_m = yield prior_broadcast.mroi_m
171
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
172
+ treatment_parameter_m = yield prior_broadcast.contribution_m
174
173
  else:
175
- roi_or_mroi_m = yield prior_broadcast.mroi_m
176
- beta_m_value = get_roi_prior_beta_m_value_fn(
177
- alpha_m,
178
- beta_gm_dev,
179
- ec_m,
180
- eta_m,
181
- roi_or_mroi_m,
182
- slope_m,
183
- media_transformed,
174
+ raise ValueError(f"Unsupported prior type: {prior_type}")
175
+ incremental_outcome_m = (
176
+ treatment_parameter_m * media_tensors.prior_denominator
177
+ )
178
+ linear_predictor_counterfactual_difference = (
179
+ mmm.linear_predictor_counterfactual_difference_media(
180
+ media_transformed=media_transformed,
181
+ alpha_m=alpha_m,
182
+ ec_m=ec_m,
183
+ slope_m=slope_m,
184
+ )
185
+ )
186
+ beta_m_value = mmm.calculate_beta_x(
187
+ is_non_media=False,
188
+ incremental_outcome_x=incremental_outcome_m,
189
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
190
+ eta_x=eta_m,
191
+ beta_gx_dev=beta_gm_dev,
184
192
  )
185
193
  beta_m = yield tfp.distributions.Deterministic(
186
194
  beta_m_value, name=constants.BETA_M
187
195
  )
188
- else:
189
- beta_m = yield prior_broadcast.beta_m
190
196
 
191
197
  beta_eta_combined = beta_m + eta_m * beta_gm_dev
192
198
  beta_gm_value = (
@@ -220,27 +226,39 @@ class PosteriorMCMCSampler:
220
226
  slope=slope_rf,
221
227
  )
222
228
 
223
- prior_type = mmm.model_spec.paid_media_prior_type
224
- if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
225
- if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
226
- roi_or_mroi_rf = yield prior_broadcast.roi_rf
229
+ prior_type = mmm.model_spec.effective_rf_prior_type
230
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
231
+ beta_rf = yield prior_broadcast.beta_rf
232
+ else:
233
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_ROI:
234
+ treatment_parameter_rf = yield prior_broadcast.roi_rf
235
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_MROI:
236
+ treatment_parameter_rf = yield prior_broadcast.mroi_rf
237
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
238
+ treatment_parameter_rf = yield prior_broadcast.contribution_rf
227
239
  else:
228
- roi_or_mroi_rf = yield prior_broadcast.mroi_rf
229
- beta_rf_value = get_roi_prior_beta_rf_value_fn(
230
- alpha_rf,
231
- beta_grf_dev,
232
- ec_rf,
233
- eta_rf,
234
- roi_or_mroi_rf,
235
- slope_rf,
236
- rf_transformed,
240
+ raise ValueError(f"Unsupported prior type: {prior_type}")
241
+ incremental_outcome_rf = (
242
+ treatment_parameter_rf * rf_tensors.prior_denominator
243
+ )
244
+ linear_predictor_counterfactual_difference = (
245
+ mmm.linear_predictor_counterfactual_difference_rf(
246
+ rf_transformed=rf_transformed,
247
+ alpha_rf=alpha_rf,
248
+ ec_rf=ec_rf,
249
+ slope_rf=slope_rf,
250
+ )
251
+ )
252
+ beta_rf_value = mmm.calculate_beta_x(
253
+ is_non_media=False,
254
+ incremental_outcome_x=incremental_outcome_rf,
255
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
256
+ eta_x=eta_rf,
257
+ beta_gx_dev=beta_grf_dev,
237
258
  )
238
259
  beta_rf = yield tfp.distributions.Deterministic(
239
- beta_rf_value,
240
- name=constants.BETA_RF,
260
+ beta_rf_value, name=constants.BETA_RF
241
261
  )
242
- else:
243
- beta_rf = yield prior_broadcast.beta_rf
244
262
 
245
263
  beta_eta_combined = beta_rf + eta_rf * beta_grf_dev
246
264
  beta_grf_value = (
@@ -272,7 +290,24 @@ class PosteriorMCMCSampler:
272
290
  ec=ec_om,
273
291
  slope=slope_om,
274
292
  )
275
- beta_om = yield prior_broadcast.beta_om
293
+ prior_type = mmm.model_spec.organic_media_prior_type
294
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
295
+ beta_om = yield prior_broadcast.beta_om
296
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
297
+ contribution_om = yield prior_broadcast.contribution_om
298
+ incremental_outcome_om = contribution_om * total_outcome
299
+ beta_om_value = mmm.calculate_beta_x(
300
+ is_non_media=False,
301
+ incremental_outcome_x=incremental_outcome_om,
302
+ linear_predictor_counterfactual_difference=organic_media_transformed,
303
+ eta_x=eta_om,
304
+ beta_gx_dev=beta_gom_dev,
305
+ )
306
+ beta_om = yield tfp.distributions.Deterministic(
307
+ beta_om_value, name=constants.BETA_OM
308
+ )
309
+ else:
310
+ raise ValueError(f"Unsupported prior type: {prior_type}")
276
311
 
277
312
  beta_eta_combined = beta_om + eta_om * beta_gom_dev
278
313
  beta_gom_value = (
@@ -306,7 +341,24 @@ class PosteriorMCMCSampler:
306
341
  slope=slope_orf,
307
342
  )
308
343
 
309
- beta_orf = yield prior_broadcast.beta_orf
344
+ prior_type = mmm.model_spec.organic_rf_prior_type
345
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
346
+ beta_orf = yield prior_broadcast.beta_orf
347
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
348
+ contribution_orf = yield prior_broadcast.contribution_orf
349
+ incremental_outcome_orf = contribution_orf * total_outcome
350
+ beta_orf_value = mmm.calculate_beta_x(
351
+ is_non_media=False,
352
+ incremental_outcome_x=incremental_outcome_orf,
353
+ linear_predictor_counterfactual_difference=organic_rf_transformed,
354
+ eta_x=eta_orf,
355
+ beta_gx_dev=beta_gorf_dev,
356
+ )
357
+ beta_orf = yield tfp.distributions.Deterministic(
358
+ beta_orf_value, name=constants.BETA_ORF
359
+ )
360
+ else:
361
+ raise ValueError(f"Unsupported prior type: {prior_type}")
310
362
 
311
363
  beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev
312
364
  beta_gorf_value = (
@@ -323,33 +375,62 @@ class PosteriorMCMCSampler:
323
375
  combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1)
324
376
 
325
377
  sigma_gt = tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos]))
326
- gamma_gc_dev = yield tfp.distributions.Sample(
327
- tfp.distributions.Normal(0, 1),
328
- [n_geos, n_controls],
329
- name=constants.GAMMA_GC_DEV,
330
- )
331
- gamma_gc = yield tfp.distributions.Deterministic(
332
- gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
333
- )
334
- y_pred_combined_media = (
335
- tau_gt
336
- + tf.einsum("gtm,gm->gt", combined_media_transformed, combined_beta)
337
- + tf.einsum("gtc,gc->gt", controls_scaled, gamma_gc)
378
+ y_pred_combined_media = tau_gt + tf.einsum(
379
+ "gtm,gm->gt", combined_media_transformed, combined_beta
338
380
  )
381
+ # Omit gamma_c, xi_c, and gamma_gc from joint distribution output if
382
+ # there are no control variables in the model.
383
+ if n_controls:
384
+ gamma_c = yield prior_broadcast.gamma_c
385
+ xi_c = yield prior_broadcast.xi_c
386
+ gamma_gc_dev = yield tfp.distributions.Sample(
387
+ tfp.distributions.Normal(0, 1),
388
+ [n_geos, n_controls],
389
+ name=constants.GAMMA_GC_DEV,
390
+ )
391
+ gamma_gc = yield tfp.distributions.Deterministic(
392
+ gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
393
+ )
394
+ y_pred_combined_media += tf.einsum(
395
+ "gtc,gc->gt", controls_scaled, gamma_gc
396
+ )
339
397
 
340
398
  if mmm.non_media_treatments is not None:
341
- gamma_n = yield prior_broadcast.gamma_n
342
399
  xi_n = yield prior_broadcast.xi_n
343
400
  gamma_gn_dev = yield tfp.distributions.Sample(
344
401
  tfp.distributions.Normal(0, 1),
345
402
  [n_geos, n_non_media_channels],
346
403
  name=constants.GAMMA_GN_DEV,
347
404
  )
405
+ prior_type = mmm.model_spec.non_media_treatments_prior_type
406
+ if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
407
+ gamma_n = yield prior_broadcast.gamma_n
408
+ elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
409
+ contribution_n = yield prior_broadcast.contribution_n
410
+ incremental_outcome_n = contribution_n * total_outcome
411
+ baseline_scaled = mmm.non_media_transformer.forward( # pytype: disable=attribute-error
412
+ mmm.compute_non_media_treatments_baseline()
413
+ )
414
+ linear_predictor_counterfactual_difference = (
415
+ non_media_treatments_normalized - baseline_scaled
416
+ )
417
+ gamma_n_value = mmm.calculate_beta_x(
418
+ is_non_media=True,
419
+ incremental_outcome_x=incremental_outcome_n,
420
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
421
+ eta_x=xi_n,
422
+ beta_gx_dev=gamma_gn_dev,
423
+ )
424
+ gamma_n = yield tfp.distributions.Deterministic(
425
+ gamma_n_value, name=constants.GAMMA_N
426
+ )
427
+ else:
428
+ raise ValueError(f"Unsupported prior type: {prior_type}")
348
429
  gamma_gn = yield tfp.distributions.Deterministic(
349
430
  gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
350
431
  )
351
432
  y_pred = y_pred_combined_media + tf.einsum(
352
- "gtn,gn->gt", non_media_treatments_scaled, gamma_gn
433
+ "gtn,gn->gt", non_media_treatments_normalized, gamma_gn
353
434
  )
354
435
  else:
355
436
  y_pred = y_pred_combined_media
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -84,6 +84,11 @@ class PriorDistribution:
84
84
  | `roi_rf` | `n_rf_channels` |
85
85
  | `mroi_m` | `n_media_channels` |
86
86
  | `mroi_rf` | `n_rf_channels` |
87
+ | `contribution_m` | `n_media_channels` |
88
+ | `contribution_rf` | `n_rf_channels` |
89
+ | `contribution_om` | `n_organic_media_channels` |
90
+ | `contribution_orf` | `n_organic_rf_channels` |
91
+ | `contribution_n` | `n_non_media_channels` |
87
92
 
88
93
  (σ) `n_geos` if `unique_sigma_for_each_geo`, otherwise this is `1`
89
94
 
@@ -233,6 +238,36 @@ class PriorDistribution:
233
238
  is interpreted as the marginal incremental KPI units per monetary unit
234
239
  spent. In this case, a default distribution is not provided, so the user
235
240
  must specify it.
241
+ contribution_m: Prior distribution on the contribution of each media channel
242
+ as a fraction of total outcome. This parameter is only used when
243
+ `paid_media_prior_type` is `'contribution'`, in which case `beta_m` is
244
+ calculated as a deterministic function of `contribution_m`, `alpha_m`,
245
+ `ec_m`, `slope_m`, and the total outcome. Default distribution is
246
+ `Beta(1.0, 99.0)`.
247
+ contribution_rf: Prior distribution on the contribution of each Reach &
248
+ Frequency channel as a fraction of total outcome. This parameter is only
249
+ used when `paid_media_prior_type` is `'contribution'`, in which case
250
+ `beta_rf` is calculated as a deterministic function of `contribution_rf`,
251
+ `alpha_rf`, `ec_rf`, `slope_rf`, and the total outcome. Default
252
+ distribution is `Beta(1.0, 99.0)`.
253
+ contribution_om: Prior distribution on the contribution of each organic
254
+ media channel as a fraction of total outcome. This parameter is only
255
+ used when `organic_media_prior_type` is `'contribution'`, in which case
256
+ `beta_om` is calculated as a deterministic function of `contribution_om`,
257
+ `alpha_om`, `ec_om`, `slope_om`, and the total outcome. Default
258
+ distribution is `Beta(1.0, 99.0)`.
259
+ contribution_orf: Prior distribution on the contribution of each organic
260
+ Reach & Frequency channel as a fraction of total outcome. This parameter
261
+ is only used when `organic_rf_prior_type` is `'contribution'`, in which
262
+ case `beta_orf` is calculated as a deterministic function of
263
+ `contribution_orf`, `alpha_orf`, `ec_orf`, `slope_orf`, and the total
264
+ outcome. Default distribution is `Beta(1.0, 99.0)`.
265
+ contribution_n: Prior distribution on the contribution of each non-media
266
+ treatment channel as a fraction of total outcome. This parameter is only
267
+ used when `non_media_treatments_prior_type` is `'contribution'`, in which
268
+ case `gamma_n` is calculated as a deterministic function of
269
+ `contribution_n` and the total outcome. Default distribution is
270
+ `TruncatedNormal(0.0, 0.1, -1.0, 1.0)`.
236
271
  """
237
272
 
238
273
  knot_values: tfp.distributions.Distribution = dataclasses.field(
@@ -394,6 +429,31 @@ class PriorDistribution:
394
429
  0.0, 0.5, name=constants.MROI_RF
395
430
  ),
396
431
  )
432
+ contribution_m: tfp.distributions.Distribution = dataclasses.field(
433
+ default_factory=lambda: tfp.distributions.Beta(
434
+ 1.0, 99.0, name=constants.CONTRIBUTION_M
435
+ ),
436
+ )
437
+ contribution_rf: tfp.distributions.Distribution = dataclasses.field(
438
+ default_factory=lambda: tfp.distributions.Beta(
439
+ 1.0, 99.0, name=constants.CONTRIBUTION_RF
440
+ ),
441
+ )
442
+ contribution_om: tfp.distributions.Distribution = dataclasses.field(
443
+ default_factory=lambda: tfp.distributions.Beta(
444
+ 1.0, 99.0, name=constants.CONTRIBUTION_OM
445
+ ),
446
+ )
447
+ contribution_orf: tfp.distributions.Distribution = dataclasses.field(
448
+ default_factory=lambda: tfp.distributions.Beta(
449
+ 1.0, 99.0, name=constants.CONTRIBUTION_ORF
450
+ ),
451
+ )
452
+ contribution_n: tfp.distributions.Distribution = dataclasses.field(
453
+ default_factory=lambda: tfp.distributions.TruncatedNormal(
454
+ loc=0.0, scale=0.1, low=-1.0, high=1.0, name=constants.CONTRIBUTION_N
455
+ ),
456
+ )
397
457
 
398
458
  def __setstate__(self, state):
399
459
  # Override to support pickling.
@@ -455,7 +515,6 @@ class PriorDistribution:
455
515
  set_total_media_contribution_prior: bool,
456
516
  kpi: float,
457
517
  total_spend: np.ndarray,
458
- media_effects_dist: str,
459
518
  ) -> PriorDistribution:
460
519
  """Returns a new `PriorDistribution` with broadcast distribution attributes.
461
520
 
@@ -481,8 +540,6 @@ class PriorDistribution:
481
540
  `set_total_media_contribution_prior=True`.
482
541
  total_spend: Spend per media channel summed across geos and time. Required
483
542
  if `set_total_media_contribution_prior=True`.
484
- media_effects_dist: A string to specify the distribution of media random
485
- effects across geos.
486
543
 
487
544
  Returns:
488
545
  A new `PriorDistribution` broadcast from this prior distribution,
@@ -508,6 +565,7 @@ class PriorDistribution:
508
565
 
509
566
  _validate_media_custom_priors(self.roi_m)
510
567
  _validate_media_custom_priors(self.mroi_m)
568
+ _validate_media_custom_priors(self.contribution_m)
511
569
  _validate_media_custom_priors(self.alpha_m)
512
570
  _validate_media_custom_priors(self.ec_m)
513
571
  _validate_media_custom_priors(self.slope_m)
@@ -529,6 +587,7 @@ class PriorDistribution:
529
587
  'that channel.'
530
588
  )
531
589
 
590
+ _validate_organic_media_custom_priors(self.contribution_om)
532
591
  _validate_organic_media_custom_priors(self.alpha_om)
533
592
  _validate_organic_media_custom_priors(self.ec_om)
534
593
  _validate_organic_media_custom_priors(self.slope_om)
@@ -550,6 +609,7 @@ class PriorDistribution:
550
609
  'for that channel.'
551
610
  )
552
611
 
612
+ _validate_organic_rf_custom_priors(self.contribution_orf)
553
613
  _validate_organic_rf_custom_priors(self.alpha_orf)
554
614
  _validate_organic_rf_custom_priors(self.ec_orf)
555
615
  _validate_organic_rf_custom_priors(self.slope_orf)
@@ -569,6 +629,7 @@ class PriorDistribution:
569
629
 
570
630
  _validate_rf_custom_priors(self.roi_rf)
571
631
  _validate_rf_custom_priors(self.mroi_rf)
632
+ _validate_rf_custom_priors(self.contribution_rf)
572
633
  _validate_rf_custom_priors(self.alpha_rf)
573
634
  _validate_rf_custom_priors(self.ec_rf)
574
635
  _validate_rf_custom_priors(self.slope_rf)
@@ -604,6 +665,7 @@ class PriorDistribution:
604
665
  'that channel.'
605
666
  )
606
667
 
668
+ _validate_non_media_custom_priors(self.contribution_n)
607
669
  _validate_non_media_custom_priors(self.gamma_n)
608
670
  _validate_non_media_custom_priors(self.xi_n)
609
671
 
@@ -743,57 +805,50 @@ class PriorDistribution:
743
805
  self.sigma, sigma_shape, name=constants.SIGMA
744
806
  )
745
807
 
746
- default_distribution = PriorDistribution()
747
- if set_total_media_contribution_prior and distributions_are_equal(
748
- self.roi_m, default_distribution.roi_m
749
- ):
750
- warnings.warn(
751
- 'Consider setting custom ROI priors, as kpi_type was specified as'
752
- ' `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the'
753
- ' total media contribution prior will be used with'
754
- f' `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further'
755
- ' documentation available at '
756
- ' https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior',
757
- )
808
+ if set_total_media_contribution_prior:
758
809
  roi_m_converted = _get_total_media_contribution_prior(
759
810
  kpi, total_spend, constants.ROI_M
760
811
  )
761
- else:
762
- roi_m_converted = self.roi_m
763
- _check_for_negative_effect(roi_m_converted, media_effects_dist)
764
- roi_m = tfp.distributions.BatchBroadcast(
765
- roi_m_converted, n_media_channels, name=constants.ROI_M
766
- )
767
-
768
- if set_total_media_contribution_prior and distributions_are_equal(
769
- self.roi_rf, default_distribution.roi_rf
770
- ):
771
- warnings.warn(
772
- 'Consider setting custom ROI priors, as kpi_type was specified as'
773
- ' `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the'
774
- ' total media contribution prior will be used with'
775
- f' `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further'
776
- ' documentation available at '
777
- ' https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior',
778
- )
779
812
  roi_rf_converted = _get_total_media_contribution_prior(
780
813
  kpi, total_spend, constants.ROI_RF
781
814
  )
782
815
  else:
816
+ roi_m_converted = self.roi_m
783
817
  roi_rf_converted = self.roi_rf
784
- _check_for_negative_effect(roi_rf_converted, media_effects_dist)
818
+ roi_m = tfp.distributions.BatchBroadcast(
819
+ roi_m_converted, n_media_channels, name=constants.ROI_M
820
+ )
785
821
  roi_rf = tfp.distributions.BatchBroadcast(
786
822
  roi_rf_converted, n_rf_channels, name=constants.ROI_RF
787
823
  )
788
- _check_for_negative_effect(self.mroi_m, media_effects_dist)
824
+
789
825
  mroi_m = tfp.distributions.BatchBroadcast(
790
826
  self.mroi_m, n_media_channels, name=constants.MROI_M
791
827
  )
792
- _check_for_negative_effect(self.mroi_rf, media_effects_dist)
793
828
  mroi_rf = tfp.distributions.BatchBroadcast(
794
829
  self.mroi_rf, n_rf_channels, name=constants.MROI_RF
795
830
  )
796
831
 
832
+ contribution_m = tfp.distributions.BatchBroadcast(
833
+ self.contribution_m, n_media_channels, name=constants.CONTRIBUTION_M
834
+ )
835
+ contribution_rf = tfp.distributions.BatchBroadcast(
836
+ self.contribution_rf, n_rf_channels, name=constants.CONTRIBUTION_RF
837
+ )
838
+ contribution_om = tfp.distributions.BatchBroadcast(
839
+ self.contribution_om,
840
+ n_organic_media_channels,
841
+ name=constants.CONTRIBUTION_OM,
842
+ )
843
+ contribution_orf = tfp.distributions.BatchBroadcast(
844
+ self.contribution_orf,
845
+ n_organic_rf_channels,
846
+ name=constants.CONTRIBUTION_ORF,
847
+ )
848
+ contribution_n = tfp.distributions.BatchBroadcast(
849
+ self.contribution_n, n_non_media_channels, name=constants.CONTRIBUTION_N
850
+ )
851
+
797
852
  return PriorDistribution(
798
853
  knot_values=knot_values,
799
854
  tau_g_excl_baseline=tau_g_excl_baseline,
@@ -826,6 +881,11 @@ class PriorDistribution:
826
881
  roi_rf=roi_rf,
827
882
  mroi_m=mroi_m,
828
883
  mroi_rf=mroi_rf,
884
+ contribution_m=contribution_m,
885
+ contribution_rf=contribution_rf,
886
+ contribution_om=contribution_om,
887
+ contribution_orf=contribution_orf,
888
+ contribution_n=contribution_n,
829
889
  )
830
890
 
831
891
 
@@ -891,21 +951,6 @@ def _get_total_media_contribution_prior(
891
951
  return tfp.distributions.LogNormal(lognormal_mu, lognormal_sigma, name=name)
892
952
 
893
953
 
894
- def _check_for_negative_effect(
895
- dist: tfp.distributions.Distribution, media_effects_dist: str
896
- ):
897
- """Checks for negative effect in the model."""
898
- if (
899
- media_effects_dist == constants.MEDIA_EFFECTS_LOG_NORMAL
900
- and np.any(dist.cdf(0)) > 0
901
- ):
902
- raise ValueError(
903
- 'Media priors must have non-negative support when'
904
- f' `media_effects_dist`="{media_effects_dist}". Found negative effect'
905
- f' in {dist.name}.'
906
- )
907
-
908
-
909
954
  def distributions_are_equal(
910
955
  a: tfp.distributions.Distribution, b: tfp.distributions.Distribution
911
956
  ) -> bool: