google-meridian 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/METADATA +8 -4
  2. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/RECORD +49 -17
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/summarizer.py +7 -2
  5. meridian/analysis/test_utils.py +934 -485
  6. meridian/analysis/visualizer.py +10 -6
  7. meridian/constants.py +1 -0
  8. meridian/data/test_utils.py +82 -10
  9. meridian/model/__init__.py +2 -0
  10. meridian/model/context.py +925 -0
  11. meridian/model/eda/constants.py +1 -0
  12. meridian/model/equations.py +418 -0
  13. meridian/model/knots.py +58 -47
  14. meridian/model/model.py +93 -792
  15. meridian/version.py +1 -1
  16. scenarioplanner/__init__.py +42 -0
  17. scenarioplanner/converters/__init__.py +25 -0
  18. scenarioplanner/converters/dataframe/__init__.py +28 -0
  19. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  20. scenarioplanner/converters/dataframe/common.py +71 -0
  21. scenarioplanner/converters/dataframe/constants.py +137 -0
  22. scenarioplanner/converters/dataframe/converter.py +42 -0
  23. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  24. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  25. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  26. scenarioplanner/converters/mmm.py +743 -0
  27. scenarioplanner/converters/mmm_converter.py +58 -0
  28. scenarioplanner/converters/sheets.py +156 -0
  29. scenarioplanner/converters/test_data.py +714 -0
  30. scenarioplanner/linkingapi/__init__.py +47 -0
  31. scenarioplanner/linkingapi/constants.py +27 -0
  32. scenarioplanner/linkingapi/url_generator.py +131 -0
  33. scenarioplanner/mmm_ui_proto_generator.py +354 -0
  34. schema/__init__.py +5 -2
  35. schema/mmm_proto_generator.py +71 -0
  36. schema/model_consumer.py +133 -0
  37. schema/processors/__init__.py +77 -0
  38. schema/processors/budget_optimization_processor.py +832 -0
  39. schema/processors/common.py +64 -0
  40. schema/processors/marketing_processor.py +1136 -0
  41. schema/processors/model_fit_processor.py +367 -0
  42. schema/processors/model_kernel_processor.py +117 -0
  43. schema/processors/model_processor.py +412 -0
  44. schema/processors/reach_frequency_optimization_processor.py +584 -0
  45. schema/test_data.py +380 -0
  46. schema/utils/__init__.py +1 -0
  47. schema/utils/date_range_bucketing.py +117 -0
  48. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
  49. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,925 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Defines ModelContext class for Meridian."""
16
+
17
+ from collections.abc import Mapping, Sequence
18
+ import functools
19
+ import warnings
20
+
21
+ from meridian import backend
22
+ from meridian import constants
23
+ from meridian.data import input_data as data
24
+ from meridian.model import adstock_hill
25
+ from meridian.model import knots
26
+ from meridian.model import media
27
+ from meridian.model import prior_distribution
28
+ from meridian.model import spec
29
+ from meridian.model import transformers
30
+ import numpy as np
31
+
32
+ __all__ = [
33
+ "ModelContext",
34
+ ]
35
+
36
+
37
+ class ModelContext:
38
+ """Model context for Meridian.
39
+
40
+ This class contains all model parameters that do not change between the runs
41
+ of Meridian.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ input_data: data.InputData,
47
+ model_spec: spec.ModelSpec,
48
+ ):
49
+ self._input_data = input_data
50
+ self._model_spec = model_spec
51
+
52
+ self._validate_data_dependent_model_spec()
53
+ self._validate_model_spec_shapes()
54
+
55
+ self._set_total_media_contribution_prior = False
56
+ self._warn_setting_ignored_priors()
57
+ self._validate_mroi_priors_non_revenue()
58
+ self._validate_roi_priors_non_revenue()
59
+ self._check_media_prior_support()
60
+ self._validate_geo_invariants()
61
+ self._validate_time_invariants()
62
+
63
+ def _validate_data_dependent_model_spec(self):
64
+ """Validates that the data dependent model specs have correct shapes."""
65
+
66
+ if self._model_spec.roi_calibration_period is not None and (
67
+ self._model_spec.roi_calibration_period.shape
68
+ != (
69
+ self.n_media_times,
70
+ self.n_media_channels,
71
+ )
72
+ ):
73
+ raise ValueError(
74
+ "The shape of `roi_calibration_period`"
75
+ f" {self._model_spec.roi_calibration_period.shape} is different from"
76
+ f" `(n_media_times, n_media_channels) = ({self.n_media_times},"
77
+ f" {self.n_media_channels})`."
78
+ )
79
+
80
+ if self._model_spec.rf_roi_calibration_period is not None and (
81
+ self._model_spec.rf_roi_calibration_period.shape
82
+ != (
83
+ self.n_media_times,
84
+ self.n_rf_channels,
85
+ )
86
+ ):
87
+ raise ValueError(
88
+ "The shape of `rf_roi_calibration_period`"
89
+ f" {self._model_spec.rf_roi_calibration_period.shape} is different"
90
+ f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
91
+ f" {self.n_rf_channels})`."
92
+ )
93
+
94
+ if self._model_spec.holdout_id is not None:
95
+ if self.is_national and (
96
+ self._model_spec.holdout_id.shape != (self.n_times,)
97
+ ):
98
+ raise ValueError(
99
+ f"The shape of `holdout_id` {self._model_spec.holdout_id.shape} is"
100
+ f" different from `(n_times,) = ({self.n_times},)`."
101
+ )
102
+ elif not self.is_national and (
103
+ self._model_spec.holdout_id.shape
104
+ != (
105
+ self.n_geos,
106
+ self.n_times,
107
+ )
108
+ ):
109
+ raise ValueError(
110
+ f"The shape of `holdout_id` {self._model_spec.holdout_id.shape} is"
111
+ f" different from `(n_geos, n_times) = ({self.n_geos},"
112
+ f" {self.n_times})`."
113
+ )
114
+
115
+ if self._model_spec.control_population_scaling_id is not None and (
116
+ self._model_spec.control_population_scaling_id.shape
117
+ != (self.n_controls,)
118
+ ):
119
+ raise ValueError(
120
+ "The shape of `control_population_scaling_id`"
121
+ f" {self._model_spec.control_population_scaling_id.shape} is"
122
+ f" different from `(n_controls,) = ({self.n_controls},)`."
123
+ )
124
+
125
+ if self._model_spec.non_media_population_scaling_id is not None and (
126
+ self._model_spec.non_media_population_scaling_id.shape
127
+ != (self.n_non_media_channels,)
128
+ ):
129
+ raise ValueError(
130
+ "The shape of `non_media_population_scaling_id`"
131
+ f" {self._model_spec.non_media_population_scaling_id.shape} is"
132
+ " different from `(n_non_media_channels,) ="
133
+ f" ({self.n_non_media_channels},)`."
134
+ )
135
+
136
+ def _validate_model_spec_shapes(self):
137
+ """Validate shapes of model_spec attributes."""
138
+ if self._model_spec.roi_calibration_period is not None:
139
+ if self._model_spec.roi_calibration_period.shape != (
140
+ self.n_media_times,
141
+ self.n_media_channels,
142
+ ):
143
+ raise ValueError(
144
+ "The shape of `roi_calibration_period`"
145
+ f" {self._model_spec.roi_calibration_period.shape} is different"
146
+ f" from `(n_media_times, n_media_channels) = ({self.n_media_times},"
147
+ f" {self.n_media_channels})`."
148
+ )
149
+
150
+ if self._model_spec.rf_roi_calibration_period is not None:
151
+ if self._model_spec.rf_roi_calibration_period.shape != (
152
+ self.n_media_times,
153
+ self.n_rf_channels,
154
+ ):
155
+ raise ValueError(
156
+ "The shape of `rf_roi_calibration_period`"
157
+ f" {self._model_spec.rf_roi_calibration_period.shape} is different"
158
+ f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
159
+ f" {self.n_rf_channels})`."
160
+ )
161
+
162
+ if self._model_spec.holdout_id is not None:
163
+ expected_shape = (
164
+ (self.n_times,) if self.is_national else (self.n_geos, self.n_times)
165
+ )
166
+ if self._model_spec.holdout_id.shape != expected_shape:
167
+ raise ValueError(
168
+ f"The shape of `holdout_id` {self._model_spec.holdout_id.shape} is"
169
+ " different from"
170
+ f" {'`(n_times,)`' if self.is_national else '`(n_geos, n_times)`'} ="
171
+ f" {expected_shape}."
172
+ )
173
+
174
+ if self._model_spec.control_population_scaling_id is not None:
175
+ if self._model_spec.control_population_scaling_id.shape != (
176
+ self.n_controls,
177
+ ):
178
+ raise ValueError(
179
+ "The shape of `control_population_scaling_id`"
180
+ f" {self._model_spec.control_population_scaling_id.shape} is"
181
+ f" different from `(n_controls,) = ({self.n_controls},)`."
182
+ )
183
+
184
+ def _validate_geo_invariants(self):
185
+ """Validates non-national model invariants."""
186
+ if self.is_national:
187
+ return
188
+
189
+ if self._input_data.controls is not None:
190
+ self._check_if_no_geo_variation(
191
+ self.controls_scaled,
192
+ constants.CONTROLS,
193
+ self._input_data.controls.coords[constants.CONTROL_VARIABLE].values,
194
+ )
195
+ if self._input_data.non_media_treatments is not None:
196
+ self._check_if_no_geo_variation(
197
+ self.non_media_treatments_normalized,
198
+ constants.NON_MEDIA_TREATMENTS,
199
+ self._input_data.non_media_treatments.coords[
200
+ constants.NON_MEDIA_CHANNEL
201
+ ].values,
202
+ )
203
+ if self._input_data.media is not None:
204
+ self._check_if_no_geo_variation(
205
+ self.media_tensors.media_scaled,
206
+ constants.MEDIA,
207
+ self._input_data.media.coords[constants.MEDIA_CHANNEL].values,
208
+ )
209
+ if self._input_data.reach is not None:
210
+ self._check_if_no_geo_variation(
211
+ self.rf_tensors.reach_scaled,
212
+ constants.REACH,
213
+ self._input_data.reach.coords[constants.RF_CHANNEL].values,
214
+ )
215
+ if self._input_data.organic_media is not None:
216
+ self._check_if_no_geo_variation(
217
+ self.organic_media_tensors.organic_media_scaled,
218
+ "organic_media",
219
+ self._input_data.organic_media.coords[
220
+ constants.ORGANIC_MEDIA_CHANNEL
221
+ ].values,
222
+ )
223
+ if self._input_data.organic_reach is not None:
224
+ self._check_if_no_geo_variation(
225
+ self.organic_rf_tensors.organic_reach_scaled,
226
+ "organic_reach",
227
+ self._input_data.organic_reach.coords[
228
+ constants.ORGANIC_RF_CHANNEL
229
+ ].values,
230
+ )
231
+
232
+ def _check_if_no_geo_variation(
233
+ self,
234
+ scaled_data: backend.Tensor,
235
+ data_name: str,
236
+ data_dims: Sequence[str],
237
+ epsilon=1e-4,
238
+ ):
239
+ """Raise an error if `n_knots == n_time` and data lacks geo variation."""
240
+
241
+ # Result shape: [n, d], where d is the number of axes of condition.
242
+ col_idx_full = backend.get_indices_where(
243
+ backend.reduce_std(scaled_data, axis=0) < epsilon
244
+ )[:, 1]
245
+ col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
246
+ # We use the shape of scaled_data (instead of `n_time`) because the data may
247
+ # be padded to account for lagged effects.
248
+ data_n_time = scaled_data.shape[1]
249
+ mask = backend.equal(counts, data_n_time)
250
+ col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
251
+ dims_bad = backend.gather(data_dims, col_idx_bad)
252
+
253
+ if col_idx_bad.shape[0] and self.knot_info.n_knots == self.n_times:
254
+ raise ValueError(
255
+ f"The following {data_name} variables do not vary across geos, making"
256
+ f" a model with n_knots=n_time unidentifiable: {dims_bad}. This can"
257
+ " lead to poor model convergence. Since these variables only vary"
258
+ " across time and not across geo, they are collinear with time and"
259
+ " redundant in a model with a parameter for each time period. To"
260
+ " address this, you can either: (1) decrease the number of knots"
261
+ " (n_knots < n_time), or (2) drop the listed variables that do not"
262
+ " vary across geos."
263
+ )
264
+
265
+ def _validate_time_invariants(self):
266
+ """Validates model time invariants."""
267
+ if self._input_data.controls is not None:
268
+ self._check_if_no_time_variation(
269
+ self.controls_scaled,
270
+ constants.CONTROLS,
271
+ self._input_data.controls.coords[constants.CONTROL_VARIABLE].values,
272
+ )
273
+ if self._input_data.non_media_treatments is not None:
274
+ self._check_if_no_time_variation(
275
+ self.non_media_treatments_normalized,
276
+ constants.NON_MEDIA_TREATMENTS,
277
+ self._input_data.non_media_treatments.coords[
278
+ constants.NON_MEDIA_CHANNEL
279
+ ].values,
280
+ )
281
+ if self._input_data.media is not None:
282
+ self._check_if_no_time_variation(
283
+ self.media_tensors.media_scaled,
284
+ constants.MEDIA,
285
+ self._input_data.media.coords[constants.MEDIA_CHANNEL].values,
286
+ )
287
+ if self._input_data.reach is not None:
288
+ self._check_if_no_time_variation(
289
+ self.rf_tensors.reach_scaled,
290
+ constants.REACH,
291
+ self._input_data.reach.coords[constants.RF_CHANNEL].values,
292
+ )
293
+ if self._input_data.organic_media is not None:
294
+ self._check_if_no_time_variation(
295
+ self.organic_media_tensors.organic_media_scaled,
296
+ constants.ORGANIC_MEDIA,
297
+ self._input_data.organic_media.coords[
298
+ constants.ORGANIC_MEDIA_CHANNEL
299
+ ].values,
300
+ )
301
+ if self._input_data.organic_reach is not None:
302
+ self._check_if_no_time_variation(
303
+ self.organic_rf_tensors.organic_reach_scaled,
304
+ constants.ORGANIC_REACH,
305
+ self._input_data.organic_reach.coords[
306
+ constants.ORGANIC_RF_CHANNEL
307
+ ].values,
308
+ )
309
+
310
+ def _check_if_no_time_variation(
311
+ self,
312
+ scaled_data: backend.Tensor,
313
+ data_name: str,
314
+ data_dims: Sequence[str],
315
+ epsilon=1e-4,
316
+ ):
317
+ """Raise an error if data lacks time variation."""
318
+
319
+ # Result shape: [n, d], where d is the number of axes of condition.
320
+ col_idx_full = backend.get_indices_where(
321
+ backend.reduce_std(scaled_data, axis=1) < epsilon
322
+ )[:, 1]
323
+ col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
324
+ mask = backend.equal(counts, self.n_geos)
325
+ col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
326
+ dims_bad = backend.gather(data_dims, col_idx_bad)
327
+ if col_idx_bad.shape[0]:
328
+ if self.is_national:
329
+ raise ValueError(
330
+ f"The following {data_name} variables do not vary across time,"
331
+ " which is equivalent to no signal at all in a national model:"
332
+ f" {dims_bad}. This can lead to poor model convergence. To address"
333
+ " this, drop the listed variables that do not vary across time."
334
+ )
335
+ else:
336
+ raise ValueError(
337
+ f"The following {data_name} variables do not vary across time,"
338
+ f" making a model with geo main effects unidentifiable: {dims_bad}."
339
+ " This can lead to poor model convergence. Since these variables"
340
+ " only vary across geo and not across time, they are collinear"
341
+ " with geo and redundant in a model with geo main effects. To"
342
+ " address this, drop the listed variables that do not vary across"
343
+ " time."
344
+ )
345
+
346
+ @property
347
+ def input_data(self) -> data.InputData:
348
+ return self._input_data
349
+
350
+ @property
351
+ def model_spec(self) -> spec.ModelSpec:
352
+ return self._model_spec
353
+
354
+ @functools.cached_property
355
+ def media_tensors(self) -> media.MediaTensors:
356
+ return media.build_media_tensors(self._input_data, self._model_spec)
357
+
358
+ @functools.cached_property
359
+ def rf_tensors(self) -> media.RfTensors:
360
+ return media.build_rf_tensors(self._input_data, self._model_spec)
361
+
362
+ @functools.cached_property
363
+ def organic_media_tensors(self) -> media.OrganicMediaTensors:
364
+ return media.build_organic_media_tensors(self._input_data)
365
+
366
+ @functools.cached_property
367
+ def organic_rf_tensors(self) -> media.OrganicRfTensors:
368
+ return media.build_organic_rf_tensors(self._input_data)
369
+
370
+ @functools.cached_property
371
+ def kpi(self) -> backend.Tensor:
372
+ return backend.to_tensor(self._input_data.kpi, dtype=backend.float32)
373
+
374
+ @functools.cached_property
375
+ def revenue_per_kpi(self) -> backend.Tensor | None:
376
+ if self._input_data.revenue_per_kpi is None:
377
+ return None
378
+ return backend.to_tensor(
379
+ self._input_data.revenue_per_kpi, dtype=backend.float32
380
+ )
381
+
382
+ @functools.cached_property
383
+ def controls(self) -> backend.Tensor | None:
384
+ if self._input_data.controls is None:
385
+ return None
386
+ return backend.to_tensor(self._input_data.controls, dtype=backend.float32)
387
+
388
+ @functools.cached_property
389
+ def non_media_treatments(self) -> backend.Tensor | None:
390
+ if self._input_data.non_media_treatments is None:
391
+ return None
392
+ return backend.to_tensor(
393
+ self._input_data.non_media_treatments, dtype=backend.float32
394
+ )
395
+
396
+ @functools.cached_property
397
+ def population(self) -> backend.Tensor:
398
+ return backend.to_tensor(self._input_data.population, dtype=backend.float32)
399
+
400
+ @functools.cached_property
401
+ def total_spend(self) -> backend.Tensor:
402
+ return backend.to_tensor(
403
+ self._input_data.get_total_spend(), dtype=backend.float32
404
+ )
405
+
406
+ @functools.cached_property
407
+ def total_outcome(self) -> backend.Tensor:
408
+ return backend.to_tensor(
409
+ self._input_data.get_total_outcome(), dtype=backend.float32
410
+ )
411
+
412
+ @property
413
+ def n_geos(self) -> int:
414
+ return len(self._input_data.geo)
415
+
416
+ @property
417
+ def n_media_channels(self) -> int:
418
+ if self._input_data.media_channel is None:
419
+ return 0
420
+ return len(self._input_data.media_channel)
421
+
422
+ @property
423
+ def n_rf_channels(self) -> int:
424
+ if self._input_data.rf_channel is None:
425
+ return 0
426
+ return len(self._input_data.rf_channel)
427
+
428
+ @property
429
+ def n_organic_media_channels(self) -> int:
430
+ if self._input_data.organic_media_channel is None:
431
+ return 0
432
+ return len(self._input_data.organic_media_channel)
433
+
434
+ @property
435
+ def n_organic_rf_channels(self) -> int:
436
+ if self._input_data.organic_rf_channel is None:
437
+ return 0
438
+ return len(self._input_data.organic_rf_channel)
439
+
440
+ @property
441
+ def n_controls(self) -> int:
442
+ if self._input_data.control_variable is None:
443
+ return 0
444
+ return len(self._input_data.control_variable)
445
+
446
+ @property
447
+ def n_non_media_channels(self) -> int:
448
+ if self._input_data.non_media_channel is None:
449
+ return 0
450
+ return len(self._input_data.non_media_channel)
451
+
452
+ @property
453
+ def n_times(self) -> int:
454
+ return len(self._input_data.time)
455
+
456
+ @property
457
+ def n_media_times(self) -> int:
458
+ return len(self._input_data.media_time)
459
+
460
+ @property
461
+ def is_national(self) -> bool:
462
+ return self.n_geos == 1
463
+
464
+ @functools.cached_property
465
+ def knot_info(self) -> knots.KnotInfo:
466
+ return knots.get_knot_info(
467
+ n_times=self.n_times,
468
+ knots=self._model_spec.knots,
469
+ enable_aks=self._model_spec.enable_aks,
470
+ data=self._input_data,
471
+ is_national=self.is_national,
472
+ )
473
+
474
+ @functools.cached_property
475
+ def controls_transformer(
476
+ self,
477
+ ) -> transformers.CenteringAndScalingTransformer | None:
478
+ """Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
479
+ if self.controls is None:
480
+ return None
481
+
482
+ if self._model_spec.control_population_scaling_id is not None:
483
+ controls_population_scaling_id = backend.to_tensor(
484
+ self._model_spec.control_population_scaling_id, dtype=backend.bool_
485
+ )
486
+ else:
487
+ controls_population_scaling_id = None
488
+
489
+ return transformers.CenteringAndScalingTransformer(
490
+ tensor=self.controls,
491
+ population=self.population,
492
+ population_scaling_id=controls_population_scaling_id,
493
+ )
494
+
495
+ @functools.cached_property
496
+ def non_media_transformer(
497
+ self,
498
+ ) -> transformers.CenteringAndScalingTransformer | None:
499
+ """Returns a `CenteringAndScalingTransformer` for non-media treatments."""
500
+ if self.non_media_treatments is None:
501
+ return None
502
+ if self._model_spec.non_media_population_scaling_id is not None:
503
+ non_media_population_scaling_id = backend.to_tensor(
504
+ self._model_spec.non_media_population_scaling_id, dtype=backend.bool_
505
+ )
506
+ else:
507
+ non_media_population_scaling_id = None
508
+
509
+ return transformers.CenteringAndScalingTransformer(
510
+ tensor=self.non_media_treatments,
511
+ population=self.population,
512
+ population_scaling_id=non_media_population_scaling_id,
513
+ )
514
+
515
+ @functools.cached_property
516
+ def kpi_transformer(self) -> transformers.KpiTransformer:
517
+ return transformers.KpiTransformer(self.kpi, self.population)
518
+
519
+ @functools.cached_property
520
+ def controls_scaled(self) -> backend.Tensor | None:
521
+ if self.controls is not None:
522
+ # If `controls` is defined, then `controls_transformer` is also defined.
523
+ return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
524
+ else:
525
+ return None
526
+
527
+ @functools.cached_property
528
+ def non_media_treatments_normalized(self) -> backend.Tensor | None:
529
+ """Normalized non-media treatments.
530
+
531
+ The non-media treatments values are scaled by population (for channels where
532
+ `non_media_population_scaling_id` is `True`) and normalized by centering and
533
+ scaling with means and standard deviations.
534
+ """
535
+ if self.non_media_transformer is not None:
536
+ return self.non_media_transformer.forward(
537
+ self.non_media_treatments
538
+ ) # pytype: disable=attribute-error
539
+ else:
540
+ return None
541
+
542
+ @functools.cached_property
543
+ def kpi_scaled(self) -> backend.Tensor:
544
+ return self.kpi_transformer.forward(self.kpi)
545
+
546
+ @functools.cached_property
547
+ def media_effects_dist(self) -> str:
548
+ if self.is_national:
549
+ return constants.NATIONAL_MODEL_SPEC_ARGS[constants.MEDIA_EFFECTS_DIST]
550
+ else:
551
+ return self._model_spec.media_effects_dist
552
+
553
+ @functools.cached_property
554
+ def unique_sigma_for_each_geo(self) -> bool:
555
+ if self.is_national:
556
+ # Should evaluate to False.
557
+ return constants.NATIONAL_MODEL_SPEC_ARGS[
558
+ constants.UNIQUE_SIGMA_FOR_EACH_GEO
559
+ ]
560
+ else:
561
+ return self._model_spec.unique_sigma_for_each_geo
562
+
563
+ @functools.cached_property
564
+ def baseline_geo_idx(self) -> int:
565
+ """Returns the index of the baseline geo."""
566
+ if isinstance(self._model_spec.baseline_geo, int):
567
+ if (
568
+ self._model_spec.baseline_geo < 0
569
+ or self._model_spec.baseline_geo >= self.n_geos
570
+ ):
571
+ raise ValueError(
572
+ f"Baseline geo index {self._model_spec.baseline_geo} out of range"
573
+ f" [0, {self.n_geos - 1}]."
574
+ )
575
+ return self._model_spec.baseline_geo
576
+ elif isinstance(self._model_spec.baseline_geo, str):
577
+ # np.where returns a 1-D tuple, its first element is an array of found
578
+ # elements.
579
+ index = np.where(self._input_data.geo == self._model_spec.baseline_geo)[0]
580
+ if index.size == 0:
581
+ raise ValueError(
582
+ f"Baseline geo '{self._model_spec.baseline_geo}' not found."
583
+ )
584
+ # Geos are unique, so index is a 1-element array.
585
+ return index[0]
586
+ else:
587
+ return backend.argmax(self.population)
588
+
589
+ @functools.cached_property
590
+ def holdout_id(self) -> backend.Tensor | None:
591
+ if self._model_spec.holdout_id is None:
592
+ return None
593
+ tensor = backend.to_tensor(self._model_spec.holdout_id, dtype=backend.bool_)
594
+ return tensor[backend.newaxis, ...] if self.is_national else tensor
595
+
596
+ def _warn_setting_ignored_priors(self):
597
+ """Raises a warning if ignored priors are set."""
598
+ default_distribution = prior_distribution.PriorDistribution()
599
+ for ignored_priors_dict, prior_type, prior_type_name in (
600
+ (
601
+ constants.IGNORED_PRIORS_MEDIA,
602
+ self._model_spec.effective_media_prior_type,
603
+ "media_prior_type",
604
+ ),
605
+ (
606
+ constants.IGNORED_PRIORS_RF,
607
+ self._model_spec.effective_rf_prior_type,
608
+ "rf_prior_type",
609
+ ),
610
+ ):
611
+ ignored_custom_priors = []
612
+ for prior in ignored_priors_dict.get(prior_type, []):
613
+ self_prior = getattr(self._model_spec.prior, prior)
614
+ default_prior = getattr(default_distribution, prior)
615
+ if not prior_distribution.distributions_are_equal(
616
+ self_prior, default_prior
617
+ ):
618
+ ignored_custom_priors.append(prior)
619
+ if ignored_custom_priors:
620
+ ignored_priors_str = ", ".join(ignored_custom_priors)
621
+ warnings.warn(
622
+ f"Custom prior(s) `{ignored_priors_str}` are ignored when"
623
+ f' `{prior_type_name}` is set to "{prior_type}".'
624
+ )
625
+
626
+ def _validate_mroi_priors_non_revenue(self):
627
+ """Validates mroi priors in the non-revenue outcome case."""
628
+ if (
629
+ self._input_data.kpi_type == constants.NON_REVENUE
630
+ and self._input_data.revenue_per_kpi is None
631
+ ):
632
+ default_distribution = prior_distribution.PriorDistribution()
633
+ if (
634
+ self.n_media_channels > 0
635
+ and (
636
+ self._model_spec.effective_media_prior_type
637
+ == constants.TREATMENT_PRIOR_TYPE_MROI
638
+ )
639
+ and prior_distribution.distributions_are_equal(
640
+ self._model_spec.prior.mroi_m, default_distribution.mroi_m
641
+ )
642
+ ):
643
+ raise ValueError(
644
+ f"Custom priors should be set on `{constants.MROI_M}` when"
645
+ ' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
646
+ " kpi data is missing."
647
+ )
648
+ if (
649
+ self.n_rf_channels > 0
650
+ and (
651
+ self._model_spec.effective_rf_prior_type
652
+ == constants.TREATMENT_PRIOR_TYPE_MROI
653
+ )
654
+ and prior_distribution.distributions_are_equal(
655
+ self._model_spec.prior.mroi_rf, default_distribution.mroi_rf
656
+ )
657
+ ):
658
+ raise ValueError(
659
+ f"Custom priors should be set on `{constants.MROI_RF}` when"
660
+ ' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
661
+ " data is missing."
662
+ )
663
+
664
+ def _validate_roi_priors_non_revenue(self):
665
+ """Validates roi priors in the non-revenue outcome case."""
666
+ if (
667
+ self._input_data.kpi_type == constants.NON_REVENUE
668
+ and self._input_data.revenue_per_kpi is None
669
+ ):
670
+ default_distribution = prior_distribution.PriorDistribution()
671
+ default_roi_m_used = (
672
+ self._model_spec.effective_media_prior_type
673
+ == constants.TREATMENT_PRIOR_TYPE_ROI
674
+ and prior_distribution.distributions_are_equal(
675
+ self._model_spec.prior.roi_m, default_distribution.roi_m
676
+ )
677
+ )
678
+ default_roi_rf_used = (
679
+ self._model_spec.effective_rf_prior_type
680
+ == constants.TREATMENT_PRIOR_TYPE_ROI
681
+ and prior_distribution.distributions_are_equal(
682
+ self._model_spec.prior.roi_rf, default_distribution.roi_rf
683
+ )
684
+ )
685
+ # If ROI priors are used with the default prior distribution for all paid
686
+ # channels (media and RF), then use the "total paid media contribution
687
+ # prior" procedure.
688
+ if (
689
+ (default_roi_m_used and default_roi_rf_used)
690
+ or (self.n_media_channels == 0 and default_roi_rf_used)
691
+ or (self.n_rf_channels == 0 and default_roi_m_used)
692
+ ):
693
+ self._set_total_media_contribution_prior = True
694
+ warnings.warn(
695
+ "Consider setting custom ROI priors, as kpi_type was specified as"
696
+ " `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
697
+ " total media contribution prior will be used with"
698
+ f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
699
+ " documentation available at "
700
+ " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
701
+ )
702
+ elif self.n_media_channels > 0 and default_roi_m_used:
703
+ raise ValueError(
704
+ f"Custom priors should be set on `{constants.ROI_M}` when"
705
+ ' `media_prior_type` is "roi", custom priors are assigned on'
706
+ ' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
707
+ " non-revenue and revenue per kpi data is missing."
708
+ )
709
+ elif self.n_rf_channels > 0 and default_roi_rf_used:
710
+ raise ValueError(
711
+ f"Custom priors should be set on `{constants.ROI_RF}` when"
712
+ ' `rf_prior_type` is "roi", custom priors are assigned on'
713
+ ' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
714
+ " non-revenue and revenue per kpi data is missing."
715
+ )
716
+
717
+ def _check_media_prior_support(self) -> None:
718
+ """Checks ROI, mROI, and Contribution prior support when random effects are log-normal.
719
+
720
+ Priors for ROI, mROI, and Contribution can only have negative support if the
721
+ random effects follow a normal distribution. This check enforces that priors
722
+ have non-negative support when random effects follow a log-normal
723
+ distribution. This check only applies to geo-level models with log-normal
724
+ random effects since national models do not have random effects.
725
+ """
726
+ prior = self._model_spec.prior
727
+ if self.n_media_channels > 0:
728
+ self._check_for_negative_support(
729
+ prior.roi_m,
730
+ self.media_effects_dist,
731
+ constants.TREATMENT_PRIOR_TYPE_ROI,
732
+ )
733
+ self._check_for_negative_support(
734
+ prior.mroi_m,
735
+ self.media_effects_dist,
736
+ constants.TREATMENT_PRIOR_TYPE_MROI,
737
+ )
738
+ self._check_for_negative_support(
739
+ prior.contribution_m,
740
+ self.media_effects_dist,
741
+ constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
742
+ )
743
+ if self.n_rf_channels > 0:
744
+ self._check_for_negative_support(
745
+ prior.roi_rf,
746
+ self.media_effects_dist,
747
+ constants.TREATMENT_PRIOR_TYPE_ROI,
748
+ )
749
+ self._check_for_negative_support(
750
+ prior.mroi_rf,
751
+ self.media_effects_dist,
752
+ constants.TREATMENT_PRIOR_TYPE_MROI,
753
+ )
754
+ self._check_for_negative_support(
755
+ prior.contribution_rf,
756
+ self.media_effects_dist,
757
+ constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
758
+ )
759
+ if self.n_organic_media_channels > 0:
760
+ self._check_for_negative_support(
761
+ prior.contribution_om,
762
+ self.media_effects_dist,
763
+ constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
764
+ )
765
+ if self.n_organic_rf_channels > 0:
766
+ self._check_for_negative_support(
767
+ prior.contribution_orf,
768
+ self.media_effects_dist,
769
+ constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
770
+ )
771
+
772
def _check_for_negative_support(
    self,
    dist: backend.tfd.Distribution,
    media_effects_dist: str,
    prior_type: str,
) -> None:
  """Validates that a media prior puts no probability mass at or below zero.

  The check only applies when `prior_type` equals the model spec's
  `media_prior_type` and `media_effects_dist` is `MEDIA_EFFECTS_LOG_NORMAL`.
  When it applies, a `ValueError` is raised if any entry of `dist.cdf(0)` is
  positive, i.e. if the prior assigns some probability to values <= 0.

  Args:
    dist: Prior distribution to validate.
    media_effects_dist: Name of the media effects distribution in use.
    prior_type: Prior type that `dist` corresponds to (e.g. ROI, mROI or
      contribution).

  Raises:
    ValueError: If the check applies and `dist` has support at or below zero.
  """
  # Only validate the prior matching the spec's configured media prior type.
  if prior_type != self._model_spec.media_prior_type:
    return
  # Negative support is only rejected under log-normal media effects.
  if media_effects_dist != constants.MEDIA_EFFECTS_LOG_NORMAL:
    return
  # The CDF at zero is P(X <= 0); the (comparatively expensive) distribution
  # evaluation happens only after the cheap string checks have passed.
  mass_at_or_below_zero = dist.cdf(0)
  if not np.any(mass_at_or_below_zero > 0):
    return
  raise ValueError(
      "Media priors must have non-negative support when"
      f' `media_effects_dist`="{media_effects_dist}". Found negative prior'
      f" distribution support for {dist.name}."
  )
+
805
@functools.cached_property
def prior_broadcast(self) -> prior_distribution.PriorDistribution:
  """Returns the model spec's prior broadcast to this model's dimensions.

  Total spend from the input data may be 1-, 2- or 3-dimensional. It is
  reduced to one value per channel before being passed to
  `PriorDistribution.broadcast`:
    * 1-D input is passed through unchanged (already per channel);
    * 2-D input is summed over axis 0;
    * higher-rank input is summed over axes 0 and 1.
  The KPI passed to the broadcast is the sum of all input KPI values.
  """
  spend = self._input_data.get_total_spend()
  n_dims = len(spend.shape)
  if n_dims == 1:
    per_channel_spend = spend
  else:
    leading_axes = (0,) if n_dims == 2 else (0, 1)
    per_channel_spend = np.sum(spend, axis=leading_axes)

  prior = self._model_spec.prior
  return prior.broadcast(
      n_geos=self.n_geos,
      n_media_channels=self.n_media_channels,
      n_rf_channels=self.n_rf_channels,
      n_organic_media_channels=self.n_organic_media_channels,
      n_organic_rf_channels=self.n_organic_rf_channels,
      n_controls=self.n_controls,
      n_non_media_channels=self.n_non_media_channels,
      unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
      n_knots=self.knot_info.n_knots,
      is_national=self.is_national,
      set_total_media_contribution_prior=self._set_total_media_contribution_prior,
      kpi=np.sum(self._input_data.kpi.values),
      total_spend=per_channel_spend,
  )
+
834
@functools.cached_property
def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
  """Returns an `AdstockDecaySpec` with decay functions mapped to channels.

  A string spec from the model spec is applied uniformly via
  `AdstockDecaySpec.from_consistent_type`. Any other spec is treated as a
  channel-name mapping and resolved per channel.

  Raises:
    ValueError: If the mapping contains keys that are not recognized channel
      names (surfaced from the `KeyError` raised during resolution).
  """
  spec = self._model_spec.adstock_decay_spec
  if isinstance(spec, str):
    return adstock_hill.AdstockDecaySpec.from_consistent_type(spec)

  try:
    decay_spec = self._create_adstock_decay_functions_from_channel_map(spec)
  except KeyError as err:
    # Re-raise as a user-facing error listing both the offending keys and the
    # accepted channel names.
    raise ValueError(
        "Unrecognized channel names found in `adstock_decay_spec` keys"
        f" {tuple(spec.keys())}. Keys should"
        " either contain only channel_names"
        f" {tuple(self._input_data.get_all_adstock_hill_channels().tolist())} or"
        " be one or more of {'media', 'rf', 'organic_media',"
        " 'organic_rf'}."
    ) from err
  return decay_spec
+
856
def _create_adstock_decay_functions_from_channel_map(
    self, channel_function_map: Mapping[str, str]
) -> adstock_hill.AdstockDecaySpec:
  """Creates an `AdstockDecaySpec` from a channel -> decay function mapping.

  Handles four channel groups: paid media, paid RF, organic media and organic
  RF.
    * A group present in the input data (its channel attribute is not `None`)
      gets its value from that group's argument builder. The builder is filled
      from `channel_function_map`, with `GEOMETRIC_DECAY` as the default.
    * A group absent from the input data gets `GEOMETRIC_DECAY`.

  Args:
    channel_function_map: Mapping from channel name to adstock decay function
      name.

  Returns:
    An `AdstockDecaySpec` with `media`, `rf`, `organic_media` and
    `organic_rf` entries.

  Raises:
    KeyError: If a key of `channel_function_map` is not among the input
      data's adstock/hill channels.
  """
  input_data = self._input_data

  # Fetch the known channels once rather than once per mapping key.
  known_channels = input_data.get_all_adstock_hill_channels()
  for channel in channel_function_map:
    if channel not in known_channels:
      raise KeyError(f"Channel {channel} not found in data.")

  def _decay_functions_for(channel_names, get_argument_builder):
    """Resolves one channel group's decay functions (default if absent)."""
    if channel_names is None:
      return constants.GEOMETRIC_DECAY
    builder = get_argument_builder().with_default_value(
        constants.GEOMETRIC_DECAY
    )
    return builder(**channel_function_map)

  # Keyword arguments are evaluated left to right, so the groups are resolved
  # in the same order as before: media, rf, organic_media, organic_rf.
  return adstock_hill.AdstockDecaySpec(
      media=_decay_functions_for(
          input_data.media_channel,
          input_data.get_paid_media_channels_argument_builder,
      ),
      rf=_decay_functions_for(
          input_data.rf_channel,
          input_data.get_paid_rf_channels_argument_builder,
      ),
      organic_media=_decay_functions_for(
          input_data.organic_media_channel,
          input_data.get_organic_media_channels_argument_builder,
      ),
      organic_rf=_decay_functions_for(
          input_data.organic_rf_channel,
          input_data.get_organic_rf_channels_argument_builder,
      ),
  )
+
908
def populate_cached_properties(self):
  """Forces evaluation of every `functools.cached_property` on this object.

  Reading each cached property once stores its value on the instance, so the
  object's state is frozen ahead of time. This matters when the object is
  captured in the closure of a `tf.function`: mutating internal state while
  the computation graph is being built is problematic.
  """
  owner_type = self.__class__

  def _is_cached_property(name):
    # Inspect the class attribute: reading it on the instance would evaluate it.
    return isinstance(
        getattr(owner_type, name, owner_type), functools.cached_property
    )

  # Collect every matching name first, then evaluate them all.
  pending = list(filter(_is_cached_property, dir(self)))
  for name in pending:
    getattr(self, name)