google-meridian 1.0.5__py3-none-any.whl → 1.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,6 +27,7 @@ import numpy as np
27
27
  import pandas as pd
28
28
  import tensorflow as tf
29
29
  import tensorflow_probability as tfp
30
+ from typing_extensions import Self
30
31
  import xarray as xr
31
32
 
32
33
  __all__ = [
@@ -36,8 +37,9 @@ __all__ = [
36
37
  ]
37
38
 
38
39
 
40
+ # TODO: Refactor the related unit tests to be under DataTensors.
39
41
  class DataTensors(tf.experimental.ExtensionType):
40
- """Container for data variables arguments of Analyzer methods.
42
+ """Container for data variable arguments of Analyzer methods.
41
43
 
42
44
  Attributes:
43
45
  media: Optional tensor with dimensions `(n_geos, T, n_media_channels)` for
@@ -59,20 +61,78 @@ class DataTensors(tf.experimental.ExtensionType):
59
61
  non_media_treatments: Optional tensor with dimensions `(n_geos, T,
60
62
  n_non_media_channels)` for any time dimension `T`.
61
63
  controls: Optional tensor with dimensions `(n_geos, n_times, n_controls)`.
62
- revenue_per_kpi: Optional tensor with dimensions `(n_geos, n_times)`.
64
+ revenue_per_kpi: Optional tensor with dimensions `(n_geos, T)` for any time
65
+ dimension `T`.
63
66
  """
64
67
 
65
- media: Optional[tf.Tensor] = None
66
- media_spend: Optional[tf.Tensor] = None
67
- reach: Optional[tf.Tensor] = None
68
- frequency: Optional[tf.Tensor] = None
69
- rf_spend: Optional[tf.Tensor] = None
70
- organic_media: Optional[tf.Tensor] = None
71
- organic_reach: Optional[tf.Tensor] = None
72
- organic_frequency: Optional[tf.Tensor] = None
73
- non_media_treatments: Optional[tf.Tensor] = None
74
- controls: Optional[tf.Tensor] = None
75
- revenue_per_kpi: Optional[tf.Tensor] = None
68
+ media: Optional[tf.Tensor]
69
+ media_spend: Optional[tf.Tensor]
70
+ reach: Optional[tf.Tensor]
71
+ frequency: Optional[tf.Tensor]
72
+ rf_spend: Optional[tf.Tensor]
73
+ organic_media: Optional[tf.Tensor]
74
+ organic_reach: Optional[tf.Tensor]
75
+ organic_frequency: Optional[tf.Tensor]
76
+ non_media_treatments: Optional[tf.Tensor]
77
+ controls: Optional[tf.Tensor]
78
+ revenue_per_kpi: Optional[tf.Tensor]
79
+
80
+ def __init__(
81
+ self,
82
+ media: Optional[tf.Tensor] = None,
83
+ media_spend: Optional[tf.Tensor] = None,
84
+ reach: Optional[tf.Tensor] = None,
85
+ frequency: Optional[tf.Tensor] = None,
86
+ rf_spend: Optional[tf.Tensor] = None,
87
+ organic_media: Optional[tf.Tensor] = None,
88
+ organic_reach: Optional[tf.Tensor] = None,
89
+ organic_frequency: Optional[tf.Tensor] = None,
90
+ non_media_treatments: Optional[tf.Tensor] = None,
91
+ controls: Optional[tf.Tensor] = None,
92
+ revenue_per_kpi: Optional[tf.Tensor] = None,
93
+ ):
94
+ self.media = tf.cast(media, tf.float32) if media is not None else None
95
+ self.media_spend = (
96
+ tf.cast(media_spend, tf.float32) if media_spend is not None else None
97
+ )
98
+ self.reach = tf.cast(reach, tf.float32) if reach is not None else None
99
+ self.frequency = (
100
+ tf.cast(frequency, tf.float32) if frequency is not None else None
101
+ )
102
+ self.rf_spend = (
103
+ tf.cast(rf_spend, tf.float32) if rf_spend is not None else None
104
+ )
105
+ self.organic_media = (
106
+ tf.cast(organic_media, tf.float32)
107
+ if organic_media is not None
108
+ else None
109
+ )
110
+ self.organic_reach = (
111
+ tf.cast(organic_reach, tf.float32)
112
+ if organic_reach is not None
113
+ else None
114
+ )
115
+ self.organic_frequency = (
116
+ tf.cast(organic_frequency, tf.float32)
117
+ if organic_frequency is not None
118
+ else None
119
+ )
120
+ self.non_media_treatments = (
121
+ tf.cast(non_media_treatments, tf.float32)
122
+ if non_media_treatments is not None
123
+ else None
124
+ )
125
+ self.controls = (
126
+ tf.cast(controls, tf.float32) if controls is not None else None
127
+ )
128
+ self.revenue_per_kpi = (
129
+ tf.cast(revenue_per_kpi, tf.float32)
130
+ if revenue_per_kpi is not None
131
+ else None
132
+ )
133
+
134
+ def __validate__(self):
135
+ self._validate_n_dims()
76
136
 
77
137
  def total_spend(self) -> tf.Tensor | None:
78
138
  """Returns the total spend tensor.
@@ -89,6 +149,248 @@ class DataTensors(tf.experimental.ExtensionType):
89
149
  spend_tensors.append(self.rf_spend)
90
150
  return tf.concat(spend_tensors, axis=-1) if spend_tensors else None
91
151
 
152
+ def get_modified_times(self, meridian: model.Meridian) -> int | None:
153
+ """Returns `n_times` of any tensor where `n_times` has been modified.
154
+
155
+ This method compares the time dimensions of the attributes in the
156
+ `DataTensors` object with the corresponding tensors in the `meridian`
157
+ object. If any of the time dimensions are different, then this method
158
+ returns the modified number of time periods of the tensor in the
159
+ `DataTensors` object. If all time dimensions are the same, returns `None`.
160
+
161
+ Args:
162
+ meridian: A Meridian object to validate against and get the original data
163
+ tensors from.
164
+
165
+ Returns:
166
+ The `n_times` of any tensor where `n_times` is different from the times
167
+ of the corresponding tensor in the `meridian` object. If all time
168
+ dimensions are the same, returns `None`.
169
+ """
170
+ for field in self._tf_extension_type_fields():
171
+ new_tensor = getattr(self, field.name)
172
+ old_tensor = getattr(meridian.input_data, field.name)
173
+ # The time dimension is always the second dimension, except for when spend
174
+ # data is provided with only one dimension of (n_channels).
175
+ if (
176
+ new_tensor is not None
177
+ and old_tensor is not None
178
+ and new_tensor.ndim > 1
179
+ and new_tensor.shape[1] != old_tensor.shape[1]
180
+ ):
181
+ return new_tensor.shape[1]
182
+ return None
183
+
184
+ def validate_and_fill_missing_data(
185
+ self,
186
+ required_tensors_names: Sequence[str],
187
+ meridian: model.Meridian,
188
+ allow_modified_times: bool = True,
189
+ ) -> Self:
190
+ """Fills missing data tensors with their original values from the model.
191
+
192
+ This method uses the collection of data tensors set in the DataTensors class
193
+ and fills in the missing tensors with their original values from the
194
+ Meridian object that is passed in. For example, if `required_tensors_names =
195
+ ["media", "reach", "frequency"]` and only `media` is set in the DataTensors
196
+ class, then this method will output a new DataTensors object with the
197
+ `media` value in this object plus the values of the `reach` and `frequency`
198
+ from the `meridian` object.
199
+
200
+ Args:
201
+ required_tensors_names: A sequence of data tensors names to validate and
202
+ fill in with the original values from the `meridian` object.
203
+ meridian: The Meridian object to validate against and get the original
204
+ data tensors from.
205
+ allow_modified_times: A boolean flag indicating whether to allow modified
206
+ time dimensions in the new data tensors. If False, an error will be
207
+ raised if the time dimension of any tensor is modified.
208
+
209
+ Returns:
210
+ A `DataTensors` container with the original values from the Meridian
211
+ object filled in for the missing data tensors.
212
+ """
213
+ self._validate_correct_variables_filled(required_tensors_names, meridian)
214
+ self._validate_geo_dims(required_tensors_names, meridian)
215
+ self._validate_channel_dims(required_tensors_names, meridian)
216
+ if allow_modified_times:
217
+ self._validate_time_dims_flexible_times(required_tensors_names, meridian)
218
+ else:
219
+ self._validate_time_dims(required_tensors_names, meridian)
220
+
221
+ return self._fill_default_values(required_tensors_names, meridian)
222
+
223
+ def _validate_n_dims(self):
224
+ """Raises an error if the tensors have the wrong number of dimensions."""
225
+ for field in self._tf_extension_type_fields():
226
+ tensor = getattr(self, field.name)
227
+ if tensor is None:
228
+ continue
229
+ if field.name == constants.REVENUE_PER_KPI:
230
+ _check_n_dims(tensor, field.name, 2)
231
+ elif field.name in [constants.MEDIA_SPEND, constants.RF_SPEND]:
232
+ if tensor.ndim not in [1, 3]:
233
+ raise ValueError(
234
+ f"New `{field.name}` must have 1 or 3 dimensions. Found"
235
+ f" {tensor.ndim} dimensions."
236
+ )
237
+ else:
238
+ _check_n_dims(tensor, field.name, 3)
239
+
240
+ def _validate_correct_variables_filled(
241
+ self, required_variables: Sequence[str], meridian: model.Meridian
242
+ ):
243
+ """Validates that the correct variables are filled.
244
+
245
+ Args:
246
+ required_variables: A sequence of data tensors names that are required to
247
+ be filled in.
248
+ meridian: The Meridian object to validate against.
249
+
250
+ Raises:
251
+ ValueError: If an attribute exists in the `DataTensors` object that is not
252
+ in the `meridian` object, it is not allowed to be used in analysis.
253
+ Warning: If an attribute exists in the `DataTensors` object that is not in
254
+ the `required_variables` list, it will be ignored.
255
+ """
256
+ for field in self._tf_extension_type_fields():
257
+ tensor = getattr(self, field.name)
258
+ if tensor is None:
259
+ continue
260
+ if field.name not in required_variables:
261
+ warnings.warn(
262
+ f"A `{field.name}` value was passed in the `new_data` argument. "
263
+ "This is not supported and will be ignored."
264
+ )
265
+ if field.name in required_variables:
266
+ if getattr(meridian.input_data, field.name) is None:
267
+ raise ValueError(
268
+ f"New `{field.name}` is not allowed because the input data to the"
269
+ f" Meridian model does not contain `{field.name}`."
270
+ )
271
+
272
+ def _validate_geo_dims(
273
+ self, required_fields: Sequence[str], meridian: model.Meridian
274
+ ):
275
+ """Validates the geo dimension of the specified data variables."""
276
+ for var_name in required_fields:
277
+ new_tensor = getattr(self, var_name)
278
+ if new_tensor is not None and new_tensor.shape[0] != meridian.n_geos:
279
+ # Skip spend data with only 1 dimension of (n_channels).
280
+ if new_tensor.ndim == 1:
281
+ continue
282
+ raise ValueError(
283
+ f"New `{var_name}` is expected to have {meridian.n_geos}"
284
+ f" geos. Found {new_tensor.shape[0]} geos."
285
+ )
286
+
287
+ def _validate_channel_dims(
288
+ self, required_fields: Sequence[str], meridian: model.Meridian
289
+ ):
290
+ """Validates the channel dimension of the specified data variables."""
291
+ for var_name in required_fields:
292
+ if var_name == constants.REVENUE_PER_KPI:
293
+ continue
294
+ new_tensor = getattr(self, var_name)
295
+ old_tensor = getattr(meridian.input_data, var_name)
296
+ if new_tensor is not None:
297
+ assert old_tensor is not None
298
+ if new_tensor.shape[-1] != old_tensor.shape[-1]:
299
+ raise ValueError(
300
+ f"New `{var_name}` is expected to have {old_tensor.shape[-1]}"
301
+ f" channels. Found {new_tensor.shape[-1]} channels."
302
+ )
303
+
304
+ def _validate_time_dims(
305
+ self, required_fields: Sequence[str], meridian: model.Meridian
306
+ ):
307
+ """Validates the time dimension of the specified data variables."""
308
+ for var_name in required_fields:
309
+ new_tensor = getattr(self, var_name)
310
+ old_tensor = getattr(meridian.input_data, var_name)
311
+
312
+ # Skip spend data with only 1 dimension of (n_channels).
313
+ if new_tensor is not None and new_tensor.ndim == 1:
314
+ continue
315
+
316
+ if new_tensor is not None:
317
+ assert old_tensor is not None
318
+ if new_tensor.shape[1] != old_tensor.shape[1]:
319
+ raise ValueError(
320
+ f"New `{var_name}` is expected to have {old_tensor.shape[1]}"
321
+ f" time periods. Found {new_tensor.shape[1]} time periods."
322
+ )
323
+
324
+ def _validate_time_dims_flexible_times(
325
+ self, required_fields: Sequence[str], meridian: model.Meridian
326
+ ):
327
+ """Validates the time dimension for the flexible times case."""
328
+ new_n_times = self.get_modified_times(meridian)
329
+ # If no times were modified, then there is nothing more to validate.
330
+ if new_n_times is None:
331
+ return
332
+
333
+ missing_params = []
334
+ for var_name in required_fields:
335
+ new_tensor = getattr(self, var_name)
336
+ old_tensor = getattr(meridian.input_data, var_name)
337
+
338
+ if old_tensor is None:
339
+ continue
340
+ # Skip spend data with only 1 dimension of (n_channels).
341
+ if new_tensor is not None and new_tensor.ndim == 1:
342
+ continue
343
+
344
+ if new_tensor is None:
345
+ missing_params.append(var_name)
346
+ elif new_tensor.shape[1] != new_n_times:
347
+ raise ValueError(
348
+ "If the time dimension of any variable in `new_data` is "
349
+ "modified, then all variables must be provided with the same "
350
+ f"number of time periods. `{var_name}` has {new_tensor.shape[1]} "
351
+ "time periods, which does not match the modified number of time "
352
+ f"periods, {new_n_times}.",
353
+ )
354
+
355
+ if missing_params:
356
+ raise ValueError(
357
+ "If the time dimension of a variable in `new_data` is modified,"
358
+ " then all variables must be provided in `new_data`."
359
+ f" The following variables are missing: `{missing_params}`."
360
+ )
361
+
362
+ def _fill_default_values(
363
+ self, required_fields: Sequence[str], meridian: model.Meridian
364
+ ) -> Self:
365
+ """Fills default values and returns a new DataTensors object."""
366
+ output = {}
367
+ for field in self._tf_extension_type_fields():
368
+ var_name = field.name
369
+ if var_name not in required_fields:
370
+ continue
371
+
372
+ if hasattr(meridian.media_tensors, var_name):
373
+ old_tensor = getattr(meridian.media_tensors, var_name)
374
+ elif hasattr(meridian.rf_tensors, var_name):
375
+ old_tensor = getattr(meridian.rf_tensors, var_name)
376
+ elif hasattr(meridian.organic_media_tensors, var_name):
377
+ old_tensor = getattr(meridian.organic_media_tensors, var_name)
378
+ elif hasattr(meridian.organic_rf_tensors, var_name):
379
+ old_tensor = getattr(meridian.organic_rf_tensors, var_name)
380
+ elif var_name == constants.NON_MEDIA_TREATMENTS:
381
+ old_tensor = meridian.non_media_treatments
382
+ elif var_name == constants.CONTROLS:
383
+ old_tensor = meridian.controls
384
+ elif var_name == constants.REVENUE_PER_KPI:
385
+ old_tensor = meridian.revenue_per_kpi
386
+ else:
387
+ continue
388
+
389
+ new_tensor = getattr(self, var_name)
390
+ output[var_name] = new_tensor if new_tensor is not None else old_tensor
391
+
392
+ return DataTensors(**output)
393
+
92
394
 
93
395
  class DistributionTensors(tf.experimental.ExtensionType):
94
396
  """Container for parameters distributions arguments of Analyzer methods."""
@@ -209,60 +511,6 @@ def _check_n_dims(tensor: tf.Tensor, name: str, n_dims: int):
209
511
  )
210
512
 
211
513
 
212
- def _check_shape_matches(
213
- t1: tf.Tensor | None = None,
214
- t1_name: str = "",
215
- t2: tf.Tensor | None = None,
216
- t2_name: str = "",
217
- t2_shape: tf.TensorShape | None = None,
218
- ):
219
- """Raises an error if dimensions of a tensor don't match the correct shape.
220
-
221
- When `t2_shape` is provided, the dimensions are assumed to be `(n_geos,
222
- n_times, n_channels)` or `(n_geos, n_times)`.
223
-
224
- Args:
225
- t1: The first tensor to check.
226
- t1_name: The name of the first tensor to check.
227
- t2: Optional second tensor to check. If None, `t2_shape` must be provided.
228
- t2_name: The name of the second tensor to check.
229
- t2_shape: Optional shape of the second tensor to check. If None, `t2` must
230
- be provided.
231
- """
232
- if t1 is not None and t2 is not None and t1.shape != t2.shape:
233
- raise ValueError(f"{t1_name}.shape must match {t2_name}.shape.")
234
- if t1 is not None and t2_shape is not None and t1.shape != t2_shape:
235
- _check_n_dims(t1, t1_name, t2_shape.rank)
236
- if t1.shape[0] != t2_shape[0]:
237
- raise ValueError(
238
- f"New `{t1_name}` is expected to have {t2_shape[0]} geos. "
239
- f"Found {t1.shape[0]} geos."
240
- )
241
- if t1.shape[1] != t2_shape[1]:
242
- raise ValueError(
243
- f"New `{t1_name}` is expected to have {t2_shape[1]} time periods. "
244
- f"Found {t1.shape[1]} time periods."
245
- )
246
- if t1.ndim == 3 and t1.shape[2] != t2_shape[2]:
247
- raise ValueError(
248
- f"New `{t1_name}` is expected to have third dimension of size "
249
- f"{t2_shape[2]}. Actual size is {t1.shape[2]}."
250
- )
251
-
252
-
253
- def _check_spend_shape_matches(
254
- spend: tf.Tensor,
255
- spend_name: str,
256
- shapes: Sequence[tf.TensorShape],
257
- ):
258
- """Raises an error if dimensions of spend don't match expected shape."""
259
- if spend is not None and spend.shape not in shapes:
260
- raise ValueError(
261
- f"{spend_name}.shape: {spend.shape} must match either {shapes[0]} or"
262
- + f" {shapes[1]}."
263
- )
264
-
265
-
266
514
  def _is_bool_list(l: Sequence[Any]) -> bool:
267
515
  """Returns True if the list contains only booleans."""
268
516
  return all(isinstance(item, bool) for item in l)
@@ -280,7 +528,21 @@ def _validate_selected_times(
280
528
  arg_name: str,
281
529
  comparison_arg_name: str,
282
530
  ):
283
- """Raises an error if selected_times is invalid."""
531
+ """Raises an error if selected_times is invalid.
532
+
533
+ This checks that the `selected_times` argument is a list of strings or a list
534
+ of booleans. If it is a list of strings, then each string must match the name
535
+ of a time period coordinate in `input_times`. If it is a list of booleans,
536
+ then it must have the same number of elements as `n_times`.
537
+
538
+ Args:
539
+ selected_times: Optional list of times to validate.
540
+ input_times: Time dimension coordinates from `InputData.time` or
541
+ `InputData.media_time`.
542
+ n_times: The number of time periods in the tensor.
543
+ arg_name: The name of the argument being validated.
544
+ comparison_arg_name: The name of the argument being compared to.
545
+ """
284
546
  if not selected_times:
285
547
  return
286
548
  if _is_bool_list(selected_times):
@@ -301,59 +563,23 @@ def _validate_selected_times(
301
563
  )
302
564
 
303
565
 
304
- def _validate_new_data_dims(
305
- new_data: DataTensors, used_treatment_var_names: Sequence[str]
306
- ):
307
- for var_name in used_treatment_var_names:
308
- tensor = getattr(new_data, var_name)
309
- if tensor is not None:
310
- _check_n_dims(tensor, var_name, 3)
311
- if new_data.revenue_per_kpi is not None:
312
- _check_n_dims(new_data.revenue_per_kpi, constants.REVENUE_PER_KPI, 2)
313
-
314
-
315
- def _is_flexible_time_scenario(
316
- new_data: DataTensors,
317
- used_var_names: Sequence[str],
318
- used_vars_names_and_times: Sequence[tuple[tf.Tensor | None, str, int]],
319
- ):
320
- """Checks if the time dimension of a variable in new_data is modified."""
321
- if any(
322
- (var is not None and var.shape[1] != time_dim)
323
- for var, _, time_dim in used_vars_names_and_times
324
- ):
325
- missing_vars = []
326
- for var, var_name, _ in used_vars_names_and_times:
327
- if var is None:
328
- missing_vars.append(var_name)
329
- if missing_vars:
330
- raise ValueError(
331
- "If the time dimension of a variable in `new_data` is modified,"
332
- " then all variables must be provided in `new_data`."
333
- f" The following variables are missing: `{missing_vars}`."
334
- )
335
- new_n_media_times = getattr(new_data, used_var_names[0]).shape[1] # pytype: disable=attribute-error
336
- for var, var_name, _ in used_vars_names_and_times:
337
- # pytype: disable=attribute-error
338
- if var.shape[1] != new_n_media_times:
339
- raise ValueError(
340
- "If the time dimension of any variable in `new_data` is"
341
- " modified, then all variables must be provided with the same"
342
- f" number of time periods. `new_data.{var_name}` has"
343
- f" {var.shape[1]} time periods and does not match"
344
- f" `new_data.{used_var_names[0]}` which has"
345
- f" {new_n_media_times} time periods."
346
- )
347
- # pytype: enable=attribute-error
348
- return True
349
- else:
350
- return False
351
-
352
-
353
- def _validate_selected_times_flexible_time(
566
+ def _validate_flexible_selected_times(
354
567
  selected_times: Sequence[str] | Sequence[bool] | None,
568
+ media_selected_times: Sequence[str] | Sequence[bool] | None,
355
569
  new_n_media_times: int,
356
570
  ):
571
+ """Raises an error if selected times or media selected times is invalid.
572
+
573
+ This checks that the `selected_times` and `media_selected_times` arguments
574
+ are lists of booleans with the same number of elements as `new_n_media_times`.
575
+ This is only relevant if the time dimension of any of the variables in
576
+ `new_data` used in the analysis is modified.
577
+
578
+ Args:
579
+ selected_times: Optional list of times to validate.
580
+ media_selected_times: Optional list of media times to validate.
581
+ new_n_media_times: The number of time periods in the new data.
582
+ """
357
583
  if selected_times and (
358
584
  not _is_bool_list(selected_times)
359
585
  or len(selected_times) != new_n_media_times
@@ -367,11 +593,6 @@ def _validate_selected_times_flexible_time(
367
593
  " the new data."
368
594
  )
369
595
 
370
-
371
- def _validate_media_selected_times_flexible_time(
372
- media_selected_times: Sequence[str] | Sequence[bool] | None,
373
- new_n_media_times: int,
374
- ):
375
596
  if media_selected_times and (
376
597
  not _is_bool_list(media_selected_times)
377
598
  or len(media_selected_times) != new_n_media_times
@@ -390,18 +611,26 @@ def _scale_tensors_by_multiplier(
390
611
  data: DataTensors,
391
612
  multiplier: float,
392
613
  by_reach: bool,
614
+ non_media_treatments_baseline: tf.Tensor | None = None,
393
615
  ) -> DataTensors:
394
616
  """Get scaled tensors for incremental outcome calculation.
395
617
 
396
618
  Args:
397
- data: DataTensors object containing the optional tensors to scale: `media`,
398
- `reach`, `frequency`.
619
+ data: DataTensors object containing the optional tensors to scale. Only
620
+ `media`, `reach`, `frequency`, `organic_media`, `organic_reach`,
621
+ `organic_frequency`, `non_media_treatments` are scaled. The other tensors
622
+ remain unchanged.
399
623
  multiplier: Float indicating the factor to scale tensors by.
400
624
  by_reach: Boolean indicating whether to scale reach or frequency when rf
401
625
  data is available.
626
+ non_media_treatments_baseline: Optional tensor to overwrite
627
+ `data.non_media_treatments` in the output. Used to compute the
628
+ counterfactual values for incremental outcome calculation. If not used, the
629
+ unmodified `data.non_media_treatments` tensor is returned in the output.
402
630
 
403
631
  Returns:
404
- A `DataTensors` object containing scaled tensor parameters.
632
+ A `DataTensors` object containing scaled tensor parameters. The original
633
+ tensors that should not be scaled remain unchanged.
405
634
  """
406
635
  incremented_data = {}
407
636
  if data.media is not None:
@@ -413,6 +642,32 @@ def _scale_tensors_by_multiplier(
413
642
  else:
414
643
  incremented_data[constants.REACH] = data.reach
415
644
  incremented_data[constants.FREQUENCY] = data.frequency * multiplier
645
+ if data.organic_media is not None:
646
+ incremented_data[constants.ORGANIC_MEDIA] = data.organic_media * multiplier
647
+ if data.organic_reach is not None and data.organic_frequency is not None:
648
+ if by_reach:
649
+ incremented_data[constants.ORGANIC_REACH] = (
650
+ data.organic_reach * multiplier
651
+ )
652
+ incremented_data[constants.ORGANIC_FREQUENCY] = data.organic_frequency
653
+ else:
654
+ incremented_data[constants.ORGANIC_REACH] = data.organic_reach
655
+ incremented_data[constants.ORGANIC_FREQUENCY] = (
656
+ data.organic_frequency * multiplier
657
+ )
658
+ if non_media_treatments_baseline is not None:
659
+ incremented_data[constants.NON_MEDIA_TREATMENTS] = (
660
+ non_media_treatments_baseline
661
+ )
662
+ else:
663
+ incremented_data[constants.NON_MEDIA_TREATMENTS] = data.non_media_treatments
664
+
665
+ # Include the original data that does not get scaled.
666
+ incremented_data[constants.MEDIA_SPEND] = data.media_spend
667
+ incremented_data[constants.RF_SPEND] = data.rf_spend
668
+ incremented_data[constants.CONTROLS] = data.controls
669
+ incremented_data[constants.REVENUE_PER_KPI] = data.revenue_per_kpi
670
+
416
671
  return DataTensors(**incremented_data)
417
672
 
418
673
 
@@ -540,43 +795,6 @@ class Analyzer:
540
795
  # states mutation before those graphs execute.
541
796
  self._meridian.populate_cached_properties()
542
797
 
543
- def _validate_new_data_geo_dims(
544
- self,
545
- new_data: DataTensors,
546
- var_names: Sequence[str],
547
- ):
548
- """Validates the geo dimension of the chosen variables from `new_data`."""
549
- for var_name in var_names:
550
- var = getattr(new_data, var_name)
551
- if var is not None:
552
- if var.shape[0] != self._meridian.n_geos:
553
- raise ValueError(
554
- f"New `{var_name}` is expected to have"
555
- f" {self._meridian.n_geos} geos. Found {var.shape[0]} geos."
556
- )
557
-
558
- def _validate_new_data_n_channels(
559
- self,
560
- new_data: DataTensors,
561
- var_names: Sequence[str],
562
- ):
563
- """Validates number of channels in the chosen variables from `new_data`."""
564
- for var_name in var_names:
565
- if getattr(new_data, var_name) is not None:
566
- new_data_var = getattr(new_data, var_name)
567
- input_data_var = getattr(self._meridian.input_data, var_name)
568
- if input_data_var is None:
569
- raise ValueError(
570
- f"`new_data.{var_name}` not allowed because the input data does"
571
- f" not have `{var_name}`."
572
- )
573
- elif new_data_var.shape[-1] != input_data_var.shape[-1]:
574
- raise ValueError(
575
- f"New `{var_name}` is expected to have"
576
- f" {input_data_var.shape[-1]} channels. Found"
577
- f" {new_data_var.shape[-1]} channels."
578
- )
579
-
580
798
  @tf.function(jit_compile=True)
581
799
  def _get_kpi_means(
582
800
  self,
@@ -741,102 +959,6 @@ class Analyzer:
741
959
  .reset_index()
742
960
  )
743
961
 
744
- def _fill_missing_data_tensors(
745
- self,
746
- new_data: DataTensors | None,
747
- required_tensors_names: Sequence[str],
748
- ) -> DataTensors:
749
- """Fills missing data tensors with their original values.
750
-
751
- This method takes a collection of new data tensors set by the user and
752
- fills in the missing tensors with their original values from the Meridian
753
- object. For example, if `required_tensors_names = ["media", "reach",
754
- "frequency"]` and the user sets only `new_data.media`, then this method will
755
- output `new_data.media` and the values of the `reach` and `frequency` from
756
- the Meridian object.
757
-
758
- Args:
759
- new_data: A `DataTensors` container with optional tensors set by the user.
760
- required_tensors_names: A sequence of data tensors names to fill in
761
- `new_data` with their original values from the Meridian object.
762
-
763
- Returns:
764
- A `DataTensors` container. For every tensor from the
765
- `required_tensors_names` list, the output contains the tensor from
766
- `new_data` if it is not `None`, otherwise the corresponding tensor from
767
- the Meridian object.
768
- """
769
- if new_data is None:
770
- new_data = DataTensors()
771
- output = {}
772
- if constants.MEDIA in required_tensors_names:
773
- output[constants.MEDIA] = (
774
- new_data.media
775
- if new_data.media is not None
776
- else self._meridian.media_tensors.media
777
- )
778
- if constants.MEDIA_SPEND in required_tensors_names:
779
- output[constants.MEDIA_SPEND] = (
780
- new_data.media_spend
781
- if new_data.media_spend is not None
782
- else self._meridian.media_tensors.media_spend
783
- )
784
- if constants.REACH in required_tensors_names:
785
- output[constants.REACH] = (
786
- new_data.reach
787
- if new_data.reach is not None
788
- else self._meridian.rf_tensors.reach
789
- )
790
- if constants.FREQUENCY in required_tensors_names:
791
- output[constants.FREQUENCY] = (
792
- new_data.frequency
793
- if new_data.frequency is not None
794
- else self._meridian.rf_tensors.frequency
795
- )
796
- if constants.RF_SPEND in required_tensors_names:
797
- output[constants.RF_SPEND] = (
798
- new_data.rf_spend
799
- if new_data.rf_spend is not None
800
- else self._meridian.rf_tensors.rf_spend
801
- )
802
- if constants.ORGANIC_MEDIA in required_tensors_names:
803
- output[constants.ORGANIC_MEDIA] = (
804
- new_data.organic_media
805
- if new_data.organic_media is not None
806
- else self._meridian.organic_media_tensors.organic_media
807
- )
808
- if constants.ORGANIC_REACH in required_tensors_names:
809
- output[constants.ORGANIC_REACH] = (
810
- new_data.organic_reach
811
- if new_data.organic_reach is not None
812
- else self._meridian.organic_rf_tensors.organic_reach
813
- )
814
- if constants.ORGANIC_FREQUENCY in required_tensors_names:
815
- output[constants.ORGANIC_FREQUENCY] = (
816
- new_data.organic_frequency
817
- if new_data.organic_frequency is not None
818
- else self._meridian.organic_rf_tensors.organic_frequency
819
- )
820
- if constants.NON_MEDIA_TREATMENTS in required_tensors_names:
821
- output[constants.NON_MEDIA_TREATMENTS] = (
822
- new_data.non_media_treatments
823
- if new_data.non_media_treatments is not None
824
- else self._meridian.non_media_treatments
825
- )
826
- if constants.CONTROLS in required_tensors_names:
827
- output[constants.CONTROLS] = (
828
- new_data.controls
829
- if new_data.controls is not None
830
- else self._meridian.controls
831
- )
832
- if constants.REVENUE_PER_KPI in required_tensors_names:
833
- output[constants.REVENUE_PER_KPI] = (
834
- new_data.revenue_per_kpi
835
- if new_data.revenue_per_kpi is not None
836
- else self._meridian.revenue_per_kpi
837
- )
838
- return DataTensors(**output)
839
-
840
962
  def _get_scaled_data_tensors(
841
963
  self,
842
964
  new_data: DataTensors | None = None,
@@ -1330,55 +1452,24 @@ class Analyzer:
1330
1452
  f"sample_{dist_type}() must be called prior to calling"
1331
1453
  " `expected_outcome()`."
1332
1454
  )
1333
- if new_data is not None:
1334
- if new_data.revenue_per_kpi is not None:
1335
- warnings.warn(
1336
- "A `revenue_per_kpi` value was passed in the `new_data` argument to"
1337
- " the `expected_outcome()` method. This is currently not supported"
1338
- " and will be ignored."
1339
- )
1340
- _check_shape_matches(
1341
- new_data.controls, "new_controls", self._meridian.controls, "controls"
1342
- )
1343
- _check_shape_matches(
1344
- new_data.media,
1345
- "new_media",
1346
- self._meridian.media_tensors.media,
1347
- "media",
1348
- )
1349
- _check_shape_matches(
1350
- new_data.reach, "new_reach", self._meridian.rf_tensors.reach, "reach"
1351
- )
1352
- _check_shape_matches(
1353
- new_data.frequency,
1354
- "new_frequency",
1355
- self._meridian.rf_tensors.frequency,
1356
- "frequency",
1357
- )
1358
- _check_shape_matches(
1359
- new_data.organic_media,
1360
- "new_organic_media",
1361
- self._meridian.organic_media_tensors.organic_media,
1362
- "organic_media",
1363
- )
1364
- _check_shape_matches(
1365
- new_data.organic_reach,
1366
- "new_organic_reach",
1367
- self._meridian.organic_rf_tensors.organic_reach,
1368
- "organic_reach",
1369
- )
1370
- _check_shape_matches(
1371
- new_data.organic_frequency,
1372
- "new_organic_frequency",
1373
- self._meridian.organic_rf_tensors.organic_frequency,
1374
- "organic_frequency",
1375
- )
1376
- _check_shape_matches(
1377
- new_data.non_media_treatments,
1378
- "new_non_media_treatments",
1379
- self._meridian.non_media_treatments,
1380
- "non_media_treatments",
1381
- )
1455
+ if new_data is None:
1456
+ new_data = DataTensors()
1457
+
1458
+ required_fields = [
1459
+ constants.CONTROLS,
1460
+ constants.MEDIA,
1461
+ constants.REACH,
1462
+ constants.FREQUENCY,
1463
+ constants.ORGANIC_MEDIA,
1464
+ constants.ORGANIC_REACH,
1465
+ constants.ORGANIC_FREQUENCY,
1466
+ constants.NON_MEDIA_TREATMENTS,
1467
+ ]
1468
+ filled_tensors = new_data.validate_and_fill_missing_data(
1469
+ required_tensors_names=required_fields,
1470
+ meridian=self._meridian,
1471
+ allow_modified_times=False,
1472
+ )
1382
1473
 
1383
1474
  params = (
1384
1475
  self._meridian.inference_data.posterior
@@ -1388,7 +1479,7 @@ class Analyzer:
1388
1479
  # We always compute the expected outcome of all channels, including non-paid
1389
1480
  # channels.
1390
1481
  data_tensors = self._get_scaled_data_tensors(
1391
- new_data=new_data,
1482
+ new_data=filled_tensors,
1392
1483
  include_non_paid_channels=True,
1393
1484
  )
1394
1485
 
@@ -1663,6 +1754,7 @@ class Analyzer:
1663
1754
  aggregate_times: bool = True,
1664
1755
  inverse_transform_outcome: bool = True,
1665
1756
  use_kpi: bool = False,
1757
+ by_reach: bool = True,
1666
1758
  include_non_paid_channels: bool = True,
1667
1759
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
1668
1760
  ) -> tf.Tensor:
@@ -1678,10 +1770,11 @@ class Analyzer:
1678
1770
  by `media_selected_times`. Similarly, `Media_0` means that media execution
1679
1771
  is multiplied by `scaling_factor0` (0.0 by default) for these time periods.
1680
1772
 
1681
- For channels with reach and frequency data, the frequency is held fixed
1682
- while the reach is scaled. "Outcome" refers to either `revenue` if
1683
- `use_kpi=False`, or `kpi` if `use_kpi=True`. When `revenue_per_kpi` is not
1684
- defined, `use_kpi` cannot be False.
1773
+ For channels with reach and frequency data, either reach or frequency is
1774
+ held fixed while the other is scaled, depending on the `by_reach` argument.
1775
+ "Outcome" refers to either `revenue` if `use_kpi=False`, or `kpi` if
1776
+ `use_kpi=True`. When `revenue_per_kpi` is not defined, `use_kpi` cannot be
1777
+ False.
1685
1778
 
1686
1779
  If `new_data=None`, this method computes incremental outcome using `media`,
1687
1780
  `reach`, `frequency`, `organic_media`, `organic_reach`, `organic_frequency`,
@@ -1738,7 +1831,7 @@ class Analyzer:
1738
1831
  default, all geos are included.
1739
1832
  selected_times: Optional list containing either a subset of dates to
1740
1833
  include or booleans with length equal to the number of time periods in
1741
- the `new_XXX` args, if provided. The incremental outcome corresponds to
1834
+ the `new_data` args, if provided. The incremental outcome corresponds to
1742
1835
  incremental KPI generated during the `selected_times` arg by media
1743
1836
  executed during the `media_selected_times` arg. Note that if
1744
1837
  `use_kpi=False`, then `selected_times` can only include the time periods
@@ -1771,6 +1864,10 @@ class Analyzer:
1771
1864
  otherwise the expected revenue `(kpi * revenue_per_kpi)` is calculated.
1772
1865
  It is required that `use_kpi = True` if `revenue_per_kpi` data is not
1773
1866
  available or if `inverse_transform_outcome = False`.
1867
+ by_reach: Boolean. If `True`, then the incremental outcome is calculated
1868
+ by scaling the reach and holding the frequency constant. If `False`,
1869
+ then the incremental outcome is calculated by scaling the frequency and
1870
+ holding the reach constant. Only used for channels with RF data.
1774
1871
  include_non_paid_channels: Boolean. If `True`, then non-media treatments
1775
1872
  and organic effects are included in the calculation. If `False`, then
1776
1873
  only the paid media and RF effects are included.
@@ -1825,176 +1922,61 @@ class Analyzer:
1825
1922
  if new_data is None:
1826
1923
  new_data = DataTensors()
1827
1924
 
1828
- if new_data.controls is not None:
1829
- warnings.warn(
1830
- "A `controls` value was passed in the `new_data` argument to the"
1831
- " `incremental_outcome()` method. This has no effect on the output"
1832
- " and will be ignored."
1833
- )
1834
-
1835
- all_media_time_var_names = [
1925
+ required_params = [
1836
1926
  constants.MEDIA,
1837
1927
  constants.REACH,
1838
1928
  constants.FREQUENCY,
1839
- ]
1840
- if include_non_paid_channels:
1841
- all_media_time_var_names += [
1842
- constants.ORGANIC_MEDIA,
1843
- constants.ORGANIC_REACH,
1844
- constants.ORGANIC_FREQUENCY,
1845
- ]
1846
-
1847
- all_var_names = all_media_time_var_names + [
1929
+ constants.ORGANIC_MEDIA,
1930
+ constants.ORGANIC_REACH,
1931
+ constants.ORGANIC_FREQUENCY,
1932
+ constants.NON_MEDIA_TREATMENTS,
1848
1933
  constants.REVENUE_PER_KPI,
1849
1934
  ]
1850
- if include_non_paid_channels:
1851
- all_var_names += [
1852
- constants.NON_MEDIA_TREATMENTS,
1853
- ]
1854
-
1855
- used_media_time_var_names = [
1856
- var_name
1857
- for var_name in all_media_time_var_names
1858
- if getattr(mmm.input_data, var_name) is not None
1859
- ]
1860
-
1861
- used_treatment_var_names = used_media_time_var_names.copy()
1862
- if (
1863
- include_non_paid_channels
1864
- and mmm.input_data.non_media_treatments is not None
1865
- ):
1866
- used_treatment_var_names += [
1867
- constants.NON_MEDIA_TREATMENTS,
1868
- ]
1869
-
1870
- used_var_names = (
1871
- used_treatment_var_names + [constants.REVENUE_PER_KPI]
1872
- if mmm.input_data.revenue_per_kpi is not None
1873
- else used_treatment_var_names
1935
+ data_tensors = new_data.validate_and_fill_missing_data(
1936
+ required_tensors_names=required_params, meridian=self._meridian
1874
1937
  )
1938
+ new_n_media_times = data_tensors.get_modified_times(self._meridian)
1875
1939
 
1876
- for var_name in all_var_names:
1877
- if (
1878
- getattr(new_data, var_name) is not None
1879
- and var_name not in used_var_names
1880
- ):
1881
- raise ValueError(
1882
- f"The `new_data` argument contains a variable `{var_name}` that is"
1883
- " not used by the model. Please remove this variable from"
1884
- " `new_data`."
1885
- )
1886
-
1887
- _validate_new_data_dims(
1888
- new_data=new_data, used_treatment_var_names=used_treatment_var_names
1889
- )
1890
- self._validate_new_data_geo_dims(
1891
- new_data=new_data, var_names=used_var_names
1892
- )
1893
- self._validate_new_data_n_channels(
1894
- new_data=new_data,
1895
- var_names=used_treatment_var_names,
1896
- )
1897
-
1898
- used_vars_names_and_times = [
1899
- (getattr(new_data, var_name), var_name, mmm.n_media_times)
1900
- for var_name in used_media_time_var_names
1901
- ]
1902
- if mmm.input_data.non_media_treatments is not None:
1903
- used_vars_names_and_times.append(
1904
- (
1905
- new_data.non_media_treatments,
1906
- constants.NON_MEDIA_TREATMENTS,
1907
- mmm.n_times,
1908
- ),
1940
+ if new_n_media_times is None:
1941
+ new_n_media_times = mmm.n_media_times
1942
+ _validate_selected_times(
1943
+ selected_times=selected_times,
1944
+ input_times=mmm.input_data.time,
1945
+ n_times=mmm.n_times,
1946
+ arg_name="selected_times",
1947
+ comparison_arg_name="the input data",
1909
1948
  )
1910
- if mmm.input_data.revenue_per_kpi is not None:
1911
- used_vars_names_and_times.append(
1912
- (
1913
- new_data.revenue_per_kpi,
1914
- constants.REVENUE_PER_KPI,
1915
- mmm.n_times,
1916
- ),
1949
+ _validate_selected_times(
1950
+ selected_times=media_selected_times,
1951
+ input_times=mmm.input_data.media_time,
1952
+ n_times=mmm.n_media_times,
1953
+ arg_name="media_selected_times",
1954
+ comparison_arg_name="the media tensors",
1917
1955
  )
1918
-
1919
- use_flexible_time = _is_flexible_time_scenario(
1920
- new_data=new_data,
1921
- used_var_names=used_var_names,
1922
- used_vars_names_and_times=used_vars_names_and_times,
1923
- )
1924
- new_n_media_times = (
1925
- getattr(new_data, used_var_names[0]).shape[1] # pytype: disable=attribute-error
1926
- if use_flexible_time
1927
- else mmm.n_media_times
1928
- )
1929
-
1930
- if use_flexible_time:
1931
- _validate_selected_times_flexible_time(
1956
+ else:
1957
+ _validate_flexible_selected_times(
1932
1958
  selected_times=selected_times,
1933
- new_n_media_times=new_n_media_times,
1934
- )
1935
- _validate_media_selected_times_flexible_time(
1936
1959
  media_selected_times=media_selected_times,
1937
1960
  new_n_media_times=new_n_media_times,
1938
1961
  )
1939
-
1940
- # Set default values for optional media arguments.
1941
- data_tensors = self._fill_missing_data_tensors(
1942
- new_data,
1943
- [
1944
- constants.MEDIA,
1945
- constants.REACH,
1946
- constants.FREQUENCY,
1947
- constants.ORGANIC_MEDIA,
1948
- constants.ORGANIC_REACH,
1949
- constants.ORGANIC_FREQUENCY,
1950
- constants.NON_MEDIA_TREATMENTS,
1951
- constants.REVENUE_PER_KPI,
1952
- ],
1953
- )
1954
1962
  if media_selected_times is None:
1955
1963
  media_selected_times = [True] * new_n_media_times
1956
1964
  else:
1957
- _validate_selected_times(
1958
- selected_times=media_selected_times,
1959
- input_times=mmm.input_data.media_time,
1960
- n_times=new_n_media_times,
1961
- arg_name="media_selected_times",
1962
- comparison_arg_name="the media tensors",
1963
- )
1964
1965
  if all(isinstance(time, str) for time in media_selected_times):
1965
1966
  media_selected_times = [
1966
1967
  x in media_selected_times for x in mmm.input_data.media_time
1967
1968
  ]
1968
1969
  non_media_selected_times = media_selected_times[-mmm.n_times :]
1969
1970
 
1970
- # Set counterfactual media and reach tensors based on the scaling factors
1971
- # and the media selected times.
1971
+ # Set counterfactual tensors based on the scaling factors and the media
1972
+ # selected times.
1972
1973
  counterfactual0 = (
1973
1974
  1 + (scaling_factor0 - 1) * np.array(media_selected_times)
1974
1975
  )[:, None]
1975
1976
  counterfactual1 = (
1976
1977
  1 + (scaling_factor1 - 1) * np.array(media_selected_times)
1977
1978
  )[:, None]
1978
- new_media0 = (
1979
- None
1980
- if data_tensors.media is None
1981
- else data_tensors.media * counterfactual0
1982
- )
1983
- new_reach0 = (
1984
- None
1985
- if data_tensors.reach is None
1986
- else data_tensors.reach * counterfactual0
1987
- )
1988
- new_organic_media0 = (
1989
- None
1990
- if data_tensors.organic_media is None
1991
- else data_tensors.organic_media * counterfactual0
1992
- )
1993
- new_organic_reach0 = (
1994
- None
1995
- if data_tensors.organic_reach is None
1996
- else data_tensors.organic_reach * counterfactual0
1997
- )
1979
+
1998
1980
  if data_tensors.non_media_treatments is not None:
1999
1981
  new_non_media_treatments0 = _compute_non_media_baseline(
2000
1982
  non_media_treatments=data_tensors.non_media_treatments,
@@ -2003,55 +1985,23 @@ class Analyzer:
2003
1985
  )
2004
1986
  else:
2005
1987
  new_non_media_treatments0 = None
2006
- new_media1 = (
2007
- None
2008
- if data_tensors.media is None
2009
- else data_tensors.media * counterfactual1
2010
- )
2011
- new_reach1 = (
2012
- None
2013
- if data_tensors.reach is None
2014
- else data_tensors.reach * counterfactual1
2015
- )
2016
- new_organic_media1 = (
2017
- None
2018
- if data_tensors.organic_media is None
2019
- else data_tensors.organic_media * counterfactual1
2020
- )
2021
- new_organic_reach1 = (
2022
- None
2023
- if data_tensors.organic_reach is None
2024
- else data_tensors.organic_reach * counterfactual1
1988
+
1989
+ incremented_data0 = _scale_tensors_by_multiplier(
1990
+ data=data_tensors,
1991
+ multiplier=counterfactual0,
1992
+ by_reach=by_reach,
1993
+ non_media_treatments_baseline=new_non_media_treatments0,
2025
1994
  )
2026
- new_non_media_treatments1 = (
2027
- None
2028
- if data_tensors.non_media_treatments is None
2029
- else data_tensors.non_media_treatments
1995
+ incremented_data1 = _scale_tensors_by_multiplier(
1996
+ data=data_tensors, multiplier=counterfactual1, by_reach=by_reach
2030
1997
  )
1998
+
2031
1999
  data_tensors0 = self._get_scaled_data_tensors(
2032
- new_data=DataTensors(
2033
- media=new_media0,
2034
- reach=new_reach0,
2035
- frequency=data_tensors.frequency,
2036
- organic_media=new_organic_media0,
2037
- organic_reach=new_organic_reach0,
2038
- organic_frequency=data_tensors.organic_frequency,
2039
- non_media_treatments=new_non_media_treatments0,
2040
- revenue_per_kpi=data_tensors.revenue_per_kpi,
2041
- ),
2000
+ new_data=incremented_data0,
2042
2001
  include_non_paid_channels=include_non_paid_channels,
2043
2002
  )
2044
2003
  data_tensors1 = self._get_scaled_data_tensors(
2045
- new_data=DataTensors(
2046
- media=new_media1,
2047
- reach=new_reach1,
2048
- frequency=data_tensors.frequency,
2049
- organic_media=new_organic_media1,
2050
- organic_reach=new_organic_reach1,
2051
- organic_frequency=data_tensors.organic_frequency,
2052
- non_media_treatments=new_non_media_treatments1,
2053
- revenue_per_kpi=data_tensors.revenue_per_kpi,
2054
- ),
2004
+ new_data=incremented_data1,
2055
2005
  include_non_paid_channels=include_non_paid_channels,
2056
2006
  )
2057
2007
 
@@ -2101,41 +2051,26 @@ class Analyzer:
2101
2051
  )
2102
2052
  return tf.concat(incremental_outcome_temps, axis=1)
2103
2053
 
2104
- def _validate_and_fill_roi_analysis_arguments(
2054
+ def _validate_geo_and_time_granularity(
2105
2055
  self,
2106
- new_data: DataTensors,
2107
2056
  selected_geos: Sequence[str] | None = None,
2108
2057
  selected_times: Sequence[str] | None = None,
2109
2058
  aggregate_geos: bool = True,
2110
- aggregate_times: bool = True,
2111
- ) -> DataTensors:
2112
- """Validates dimensions of arguments for ROI analysis methods.
2113
-
2114
- Validates dimensionality requirements for `new_data` and other arguments for
2115
- ROI analysis methods.
2059
+ ):
2060
+ """Validates the geo and time granularity arguments for ROI analysis.
2116
2061
 
2117
2062
  Args:
2118
- new_data: DataTensors containing optional `media`, `media_spend`, `reach`,
2119
- `frequency`, and `rf_spend` data.
2120
2063
  selected_geos: Optional. Contains a subset of geos to include. By default,
2121
2064
  all geos are included.
2122
2065
  selected_times: Optional. Contains a subset of times to include. By
2123
2066
  default, all time periods are included.
2124
2067
  aggregate_geos: If `True`, then expected revenue is summed over all
2125
2068
  regions.
2126
- aggregate_times: If `True`, then expected revenue is summed over all time
2127
- periods.
2128
-
2129
- Returns:
2130
- `DataTensors` containing the new data tensors, filled using the original
2131
- data tensors if the new data tensors are `None`.
2132
2069
 
2133
2070
  Raises:
2134
- ValueError: If the dimensions of the arguments do not match the
2135
- dimensions of the corresponding tensors in the Meridian object or if
2136
- the other arguments are invalid.
2071
+ ValueError: If the geo or time granularity arguments are not valid for the
2072
+ ROI analysis.
2137
2073
  """
2138
-
2139
2074
  if self._meridian.is_national:
2140
2075
  _warn_if_geo_arg_in_kwargs(
2141
2076
  aggregate_geos=aggregate_geos,
@@ -2147,126 +2082,48 @@ class Analyzer:
2147
2082
  and not self._meridian.input_data.media_spend_has_geo_dimension
2148
2083
  ):
2149
2084
  raise ValueError(
2150
- "aggregate_geos=False not allowed because Meridian media_spend data"
2151
- " does not have geo dimension."
2085
+ "`selected_geos` and `aggregate_geos=False` are not allowed because"
2086
+ " Meridian `media_spend` data does not have a geo dimension."
2152
2087
  )
2153
2088
  if (
2154
2089
  self._meridian.rf_tensors.rf_spend is not None
2155
2090
  and not self._meridian.input_data.rf_spend_has_geo_dimension
2156
2091
  ):
2157
2092
  raise ValueError(
2158
- "aggregate_geos=False not allowed because Meridian rf_spend data"
2159
- " does not have geo dimension."
2093
+ "`selected_geos` and `aggregate_geos=False` are not allowed because"
2094
+ " Meridian `rf_spend` data does not have a geo dimension."
2160
2095
  )
2161
2096
 
2162
- if selected_times is not None or not aggregate_times:
2097
+ if selected_times is not None:
2163
2098
  if (
2164
2099
  self._meridian.media_tensors.media_spend is not None
2165
2100
  and not self._meridian.input_data.media_spend_has_time_dimension
2166
2101
  ):
2167
2102
  raise ValueError(
2168
- "aggregate_times=False not allowed because Meridian media_spend"
2169
- " data does not have time dimension."
2103
+ "`selected_times` is not allowed because Meridian `media_spend`"
2104
+ " data does not have a time dimension."
2170
2105
  )
2171
2106
  if (
2172
2107
  self._meridian.rf_tensors.rf_spend is not None
2173
2108
  and not self._meridian.input_data.rf_spend_has_time_dimension
2174
2109
  ):
2175
2110
  raise ValueError(
2176
- "aggregate_times=False not allowed because Meridian rf_spend data"
2177
- " does not have time dimension."
2111
+ "`selected_times` is not allowed because Meridian `rf_spend` data"
2112
+ " does not have a time dimension."
2178
2113
  )
2179
2114
 
2180
- _check_shape_matches(
2181
- new_data.media,
2182
- f"{constants.NEW_DATA}.{constants.MEDIA}",
2183
- self._meridian.media_tensors.media,
2184
- constants.MEDIA,
2185
- )
2186
- _check_spend_shape_matches(
2187
- new_data.media_spend,
2188
- f"{constants.NEW_DATA}.{constants.MEDIA_SPEND}",
2189
- (
2190
- tf.TensorShape((self._meridian.n_media_channels)),
2191
- tf.TensorShape((
2192
- self._meridian.n_geos,
2193
- self._meridian.n_times,
2194
- self._meridian.n_media_channels,
2195
- )),
2196
- ),
2197
- )
2198
- _check_shape_matches(
2199
- new_data.reach,
2200
- f"{constants.NEW_DATA}.{constants.REACH}",
2201
- self._meridian.rf_tensors.reach,
2202
- constants.REACH,
2203
- )
2204
- _check_shape_matches(
2205
- new_data.frequency,
2206
- f"{constants.NEW_DATA}.{constants.FREQUENCY}",
2207
- self._meridian.rf_tensors.frequency,
2208
- constants.FREQUENCY,
2209
- )
2210
- _check_spend_shape_matches(
2211
- new_data.rf_spend,
2212
- f"{constants.NEW_DATA}.{constants.RF_SPEND}",
2213
- (
2214
- tf.TensorShape((self._meridian.n_rf_channels)),
2215
- tf.TensorShape((
2216
- self._meridian.n_geos,
2217
- self._meridian.n_times,
2218
- self._meridian.n_rf_channels,
2219
- )),
2220
- ),
2221
- )
2222
-
2223
- new_media = (
2224
- self._meridian.media_tensors.media
2225
- if new_data.media is None
2226
- else new_data.media
2227
- )
2228
- new_reach = (
2229
- self._meridian.rf_tensors.reach
2230
- if new_data.reach is None
2231
- else new_data.reach
2232
- )
2233
- new_frequency = (
2234
- self._meridian.rf_tensors.frequency
2235
- if new_data.frequency is None
2236
- else new_data.frequency
2237
- )
2238
-
2239
- new_media_spend = (
2240
- self._meridian.media_tensors.media_spend
2241
- if new_data.media_spend is None
2242
- else new_data.media_spend
2243
- )
2244
- new_rf_spend = (
2245
- self._meridian.rf_tensors.rf_spend
2246
- if new_data.rf_spend is None
2247
- else new_data.rf_spend
2248
- )
2249
-
2250
- return DataTensors(
2251
- media=new_media,
2252
- reach=new_reach,
2253
- frequency=new_frequency,
2254
- media_spend=new_media_spend,
2255
- rf_spend=new_rf_spend,
2256
- )
2257
-
2258
2115
  def marginal_roi(
2259
2116
  self,
2260
2117
  incremental_increase: float = 0.01,
2261
2118
  use_posterior: bool = True,
2262
2119
  new_data: DataTensors | None = None,
2263
2120
  selected_geos: Sequence[str] | None = None,
2264
- selected_times: Sequence[str] | None = None,
2121
+ selected_times: Sequence[str] | Sequence[bool] | None = None,
2265
2122
  aggregate_geos: bool = True,
2266
2123
  by_reach: bool = True,
2267
2124
  use_kpi: bool = False,
2268
2125
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
2269
- ) -> tf.Tensor | None:
2126
+ ) -> tf.Tensor:
2270
2127
  """Calculates the marginal ROI prior or posterior distribution.
2271
2128
 
2272
2129
  The marginal ROI (mROI) numerator is the change in expected outcome (`kpi`
@@ -2277,7 +2134,7 @@ class Analyzer:
2277
2134
  If `new_data=None`, this method calculates marginal ROI conditional on the
2278
2135
  values of the paid media variables that the Meridian object was initialized
2279
2136
  with. The user can also override this historical data through the `new_data`
2280
- argument, as long as the new tensors` dimensions match. For example,
2137
+ argument. For example,
2281
2138
 
2282
2139
  ```python
2283
2140
  new_data = DataTensors(media=new_media, frequency=new_frequency)
@@ -2302,14 +2159,16 @@ class Analyzer:
2302
2159
  `reach`, `frequency`, `rf_spend` and `revenue_per_kpi` data. If
2303
2160
  provided, the marginal ROI is calculated using the values of the tensors
2304
2161
  passed in `new_data` and the original values of all the remaining
2305
- tensors. The new tensors' dimensions must match the dimensions of the
2306
- corresponding original tensors from `meridian.input_data`. If `None`,
2307
- the marginal ROI is calculated using the original values of all the
2308
- tensors.
2162
+ tensors. If `None`, the marginal ROI is calculated using the original
2163
+ values of all the tensors. If any of the tensors in `new_data` is
2164
+ provided with a different number of time periods than in `InputData`,
2165
+ then all tensors must be provided with the same number of time periods.
2309
2166
  selected_geos: Optional. Contains a subset of geos to include. By default,
2310
2167
  all geos are included.
2311
- selected_times: Optional. Contains a subset of times to include. By
2312
- default, all time periods are included.
2168
+ selected_times: Optional list containing either a subset of dates to
2169
+ include or booleans with length equal to the number of time periods in
2170
+ the `new_data` args, if provided. By default, all time periods are
2171
+ included.
2313
2172
  aggregate_geos: If `True`, the expected revenue is summed over all of the
2314
2173
  regions.
2315
2174
  by_reach: Used for a channel with reach and frequency. If `True`, returns
@@ -2327,49 +2186,55 @@ class Analyzer:
2327
2186
  (n_media_channels + n_rf_channels))`. The `n_geos` dimension is dropped if
2328
2187
  `aggregate_geos=True`.
2329
2188
  """
2330
- self._check_revenue_data_exists(use_kpi)
2331
2189
  dim_kwargs = {
2332
2190
  "selected_geos": selected_geos,
2333
2191
  "selected_times": selected_times,
2334
2192
  "aggregate_geos": aggregate_geos,
2335
- "aggregate_times": True,
2336
- }
2337
- incremental_outcome_kwargs = {
2338
- "inverse_transform_outcome": True,
2339
- "use_posterior": use_posterior,
2340
- "use_kpi": use_kpi,
2341
- "batch_size": batch_size,
2342
- "include_non_paid_channels": False,
2343
2193
  }
2344
- filled_data = self._validate_and_fill_roi_analysis_arguments(
2345
- new_data=new_data or DataTensors(),
2346
- **dim_kwargs,
2194
+ self._check_revenue_data_exists(use_kpi)
2195
+ self._validate_geo_and_time_granularity(**dim_kwargs)
2196
+ required_values = [
2197
+ constants.MEDIA,
2198
+ constants.MEDIA_SPEND,
2199
+ constants.REACH,
2200
+ constants.FREQUENCY,
2201
+ constants.RF_SPEND,
2202
+ constants.REVENUE_PER_KPI,
2203
+ ]
2204
+ if not new_data:
2205
+ new_data = DataTensors()
2206
+ filled_data = new_data.validate_and_fill_missing_data(
2207
+ required_tensors_names=required_values,
2208
+ meridian=self._meridian,
2347
2209
  )
2348
- incremental_outcome = self.incremental_outcome(
2210
+ numerator = self.incremental_outcome(
2349
2211
  new_data=filled_data,
2350
- **incremental_outcome_kwargs,
2351
- **dim_kwargs,
2352
- )
2353
- incremented_data = _scale_tensors_by_multiplier(
2354
- data=filled_data,
2355
- multiplier=incremental_increase + 1,
2212
+ scaling_factor0=1,
2213
+ scaling_factor1=1 + incremental_increase,
2214
+ inverse_transform_outcome=True,
2215
+ use_posterior=use_posterior,
2216
+ use_kpi=use_kpi,
2356
2217
  by_reach=by_reach,
2218
+ batch_size=batch_size,
2219
+ include_non_paid_channels=False,
2220
+ aggregate_times=True,
2221
+ **dim_kwargs,
2357
2222
  )
2358
- incremental_outcome_with_multiplier = self.incremental_outcome(
2359
- new_data=incremented_data, **dim_kwargs, **incremental_outcome_kwargs
2360
- )
2361
- numerator = incremental_outcome_with_multiplier - incremental_outcome
2362
2223
  spend_inc = filled_data.total_spend() * incremental_increase
2363
2224
  if spend_inc is not None and spend_inc.ndim == 3:
2364
2225
  denominator = self.filter_and_aggregate_geos_and_times(
2365
- spend_inc, **dim_kwargs
2226
+ spend_inc,
2227
+ aggregate_times=True,
2228
+ flexible_time_dim=True,
2229
+ has_media_dim=True,
2230
+ **dim_kwargs,
2366
2231
  )
2367
2232
  else:
2368
2233
  if not aggregate_geos:
2369
2234
  # This check should not be reachable. It is here to protect against
2370
- # future changes to self._validate_and_fill_roi_analysis_arguments. If
2235
+ # future changes to self._validate_geo_and_time_granularity. If
2371
2236
  # spend_inc.ndim is not 3 and `aggregate_geos` is `False`, then
2372
- # self._validate_and_fill_roi_analysis_arguments should raise an error.
2237
+ # self._validate_geo_and_time_granularity should raise an error.
2373
2238
  raise ValueError(
2374
2239
  "aggregate_geos must be True if spend does not have a geo "
2375
2240
  "dimension."
@@ -2382,7 +2247,7 @@ class Analyzer:
2382
2247
  use_posterior: bool = True,
2383
2248
  new_data: DataTensors | None = None,
2384
2249
  selected_geos: Sequence[str] | None = None,
2385
- selected_times: Sequence[str] | None = None,
2250
+ selected_times: Sequence[str] | Sequence[bool] | None = None,
2386
2251
  aggregate_geos: bool = True,
2387
2252
  use_kpi: bool = False,
2388
2253
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
@@ -2396,8 +2261,8 @@ class Analyzer:
2396
2261
 
2397
2262
  If `new_data=None`, this method calculates ROI conditional on the values of
2398
2263
  the paid media variables that the Meridian object was initialized with. The
2399
- user can also override this historical data through the `new_data` argument,
2400
- as long as the new tensors' dimensions match. For example,
2264
+ user can also override this historical data through the `new_data` argument.
2265
+ For example,
2401
2266
 
2402
2267
  ```python
2403
2268
  new_data = DataTensors(media=new_media, frequency=new_frequency)
@@ -2417,14 +2282,17 @@ class Analyzer:
2417
2282
  new_data: Optional. DataTensors containing `media`, `media_spend`,
2418
2283
  `reach`, `frequency`, `rf_spend`, and `revenue_per_kpi` data. If
2419
2284
  provided, the ROI is calculated using the values of the tensors passed
2420
- in `new_data` and the original values of all the remaining tensors. The
2421
- new tensors' dimensions must match the dimensions of the corresponding
2422
- original tensors from `meridian.input_data`. If `None`, the ROI is
2423
- calculated using the original values of all the tensors.
2424
- selected_geos: Optional list containing a subset of geos to include. By
2425
- default, all geos are included.
2426
- selected_times: Optional list containing a subset of times to include. By
2427
- default, all time periods are included.
2285
+ in `new_data` and the original values of all the remaining tensors. If
2286
+ `None`, the ROI is calculated using the original values of all the
2287
+ tensors. If any of the tensors in `new_data` is provided with a
2288
+ different number of time periods than in `InputData`, then all tensors
2289
+ must be provided with the same number of time periods.
2290
+ selected_geos: Optional. Contains a subset of geos to include. By default,
2291
+ all geos are included.
2292
+ selected_times: Optional list containing either a subset of dates to
2293
+ include or booleans with length equal to the number of time periods in
2294
+ the `new_data` args, if provided. By default, all time periods are
2295
+ included.
2428
2296
  aggregate_geos: Boolean. If `True`, the expected revenue is summed over
2429
2297
  all of the regions.
2430
2298
  use_kpi: If `False`, then revenue is used to calculate the ROI numerator.
@@ -2439,12 +2307,10 @@ class Analyzer:
2439
2307
  (n_media_channels + n_rf_channels))`. The `n_geos` dimension is dropped if
2440
2308
  `aggregate_geos=True`.
2441
2309
  """
2442
- self._check_revenue_data_exists(use_kpi)
2443
2310
  dim_kwargs = {
2444
2311
  "selected_geos": selected_geos,
2445
2312
  "selected_times": selected_times,
2446
2313
  "aggregate_geos": aggregate_geos,
2447
- "aggregate_times": True,
2448
2314
  }
2449
2315
  incremental_outcome_kwargs = {
2450
2316
  "inverse_transform_outcome": True,
@@ -2452,10 +2318,23 @@ class Analyzer:
2452
2318
  "use_kpi": use_kpi,
2453
2319
  "batch_size": batch_size,
2454
2320
  "include_non_paid_channels": False,
2321
+ "aggregate_times": True,
2455
2322
  }
2456
- filled_data = self._validate_and_fill_roi_analysis_arguments(
2457
- new_data=new_data or DataTensors(),
2458
- **dim_kwargs,
2323
+ self._check_revenue_data_exists(use_kpi)
2324
+ self._validate_geo_and_time_granularity(**dim_kwargs)
2325
+ required_values = [
2326
+ constants.MEDIA,
2327
+ constants.MEDIA_SPEND,
2328
+ constants.REACH,
2329
+ constants.FREQUENCY,
2330
+ constants.RF_SPEND,
2331
+ constants.REVENUE_PER_KPI,
2332
+ ]
2333
+ if not new_data:
2334
+ new_data = DataTensors()
2335
+ filled_data = new_data.validate_and_fill_missing_data(
2336
+ required_tensors_names=required_values,
2337
+ meridian=self._meridian,
2459
2338
  )
2460
2339
  incremental_outcome = self.incremental_outcome(
2461
2340
  new_data=filled_data,
@@ -2466,15 +2345,19 @@ class Analyzer:
2466
2345
  spend = filled_data.total_spend()
2467
2346
  if spend is not None and spend.ndim == 3:
2468
2347
  denominator = self.filter_and_aggregate_geos_and_times(
2469
- spend, **dim_kwargs
2348
+ spend,
2349
+ aggregate_times=True,
2350
+ flexible_time_dim=True,
2351
+ has_media_dim=True,
2352
+ **dim_kwargs,
2470
2353
  )
2471
2354
  else:
2472
2355
  if not aggregate_geos:
2473
2356
  # This check should not be reachable. It is here to protect against
2474
- # future changes to self._validate_and_fill_roi_analysis_arguments. If
2357
+ # future changes to self._validate_geo_and_time_granularity. If
2475
2358
  # spend.ndim is not 3 and either of `aggregate_geos` or
2476
2359
  # `aggregate_times` is `False`, then
2477
- # self._validate_and_fill_roi_analysis_arguments should raise an error.
2360
+ # self._validate_geo_and_time_granularity should raise an error.
2478
2361
  raise ValueError(
2479
2362
  "aggregate_geos must be True if spend does not have a geo "
2480
2363
  "dimension."
@@ -2487,7 +2370,7 @@ class Analyzer:
2487
2370
  use_posterior: bool = True,
2488
2371
  new_data: DataTensors | None = None,
2489
2372
  selected_geos: Sequence[str] | None = None,
2490
- selected_times: Sequence[str] | None = None,
2373
+ selected_times: Sequence[str] | Sequence[bool] | None = None,
2491
2374
  aggregate_geos: bool = True,
2492
2375
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
2493
2376
  ) -> tf.Tensor:
@@ -2499,8 +2382,8 @@ class Analyzer:
2499
2382
 
2500
2383
  If `new_data=None`, this method calculates CPIK conditional on the values of
2501
2384
  the paid media variables that the Meridian object was initialized with. The
2502
- user can also override this historical data through the `new_data` argument,
2503
- as long as the new tensors' dimensions match. For example,
2385
+ user can also override this historical data through the `new_data` argument.
2386
+ For example,
2504
2387
 
2505
2388
  ```python
2506
2389
  new_data = DataTensors(media=new_media, frequency=new_frequency)
@@ -2510,9 +2393,8 @@ class Analyzer:
2510
2393
  numerator is the total spend during the selected geos and time periods. An
2511
2394
  exception will be thrown if the spend of the InputData used to train the
2512
2395
  model does not have geo and time dimensions. (If the `new_data.media_spend`
2513
- and
2514
- `new_data.rf_spend` arguments are used with different dimensions than the
2515
- InputData spend, then an exception will be thrown since this is a likely
2396
+ and `new_data.rf_spend` arguments are used with different dimensions than
2397
+ the InputData spend, then an exception will be thrown since this is a likely
2516
2398
  user error.)
2517
2399
 
2518
2400
  Note that CPIK is simply 1/ROI, where ROI is obtained from a call to the
@@ -2524,14 +2406,17 @@ class Analyzer:
2524
2406
  new_data: Optional. DataTensors containing `media`, `media_spend`,
2525
2407
  `reach`, `frequency`, `rf_spend` and `revenue_per_kpi` data. If
2526
2408
  provided, the cpik is calculated using the values of the tensors passed
2527
- in `new_data` and the original values of all the remaining tensors. The
2528
- new tensors' dimensions must match the dimensions of the corresponding
2529
- original tensors from `meridian.input_data`. If `None`, the cpik is
2530
- calculated using the original values of all the tensors.
2531
- selected_geos: Optional list containing a subset of geos to include. By
2532
- default, all geos are included.
2533
- selected_times: Optional list containing a subset of times to include. By
2534
- default, all time periods are included.
2409
+ in `new_data` and the original values of all the remaining tensors. If
2410
+ `None`, the cpik is calculated using the original values of all the
2411
+ tensors. If any of the tensors in `new_data` is provided with a
2412
+ different number of time periods than in `InputData`, then all tensors
2413
+ must be provided with the same number of time periods.
2414
+ selected_geos: Optional. Contains a subset of geos to include. By default,
2415
+ all geos are included.
2416
+ selected_times: Optional list containing either a subset of dates to
2417
+ include or booleans with length equal to the number of time periods in
2418
+ the `new_data` args, if provided. By default, all time periods are
2419
+ included.
2535
2420
  aggregate_geos: Boolean. If `True`, the expected KPI is summed over all of
2536
2421
  the regions.
2537
2422
  batch_size: Integer representing the maximum draws per chain in each
@@ -2881,7 +2766,7 @@ class Analyzer:
2881
2766
  marginal_roi_by_reach: bool = True,
2882
2767
  marginal_roi_incremental_increase: float = 0.01,
2883
2768
  selected_geos: Sequence[str] | None = None,
2884
- selected_times: Sequence[str] | None = None,
2769
+ selected_times: Sequence[str] | Sequence[bool] | None = None,
2885
2770
  aggregate_geos: bool = True,
2886
2771
  aggregate_times: bool = True,
2887
2772
  optimal_frequency: Sequence[float] | None = None,
@@ -2896,7 +2781,9 @@ class Analyzer:
2896
2781
  If `new_data=None`, this method calculates all the metrics conditional on
2897
2782
  the values of the data variables that the Meridian object was initialized
2898
2783
  with. The user can also override this historical data through the `new_data`
2899
- argument, as long as the new tensors` dimensions match. For example,
2784
+ argument. For example, to override the media, frequency, and non-media
2785
+ treatments data variables, the user can pass the following `new_data`
2786
+ argument:
2900
2787
 
2901
2788
  ```python
2902
2789
  new_data = DataTensors(
@@ -2905,19 +2792,24 @@ class Analyzer:
2905
2792
  non_media_treatments=new_non_media_treatments)
2906
2793
  ```
2907
2794
 
2795
+ Note that if `new_data` is provided with a different number of time periods
2796
+ than in `InputData`, `pct_of_contribution` is not defined because
2797
+ `expected_outcome()` is not defined for new time periods.
2798
+
2908
2799
  Note that `mroi` and `effectiveness` metrics are not defined (`math.nan`)
2909
2800
  for the aggregate `"All Paid Channels"` channel dimension.
2910
2801
 
2911
2802
  Args:
2912
2803
  new_data: Optional `DataTensors` object with optional new tensors:
2913
- `media`, `reach`, `frequency`, `organic_media`, `organic_reach`,
2914
- `organic_frequency`, `non_media_treatments`, `controls`,
2915
- `revenue_per_kpi`. If provided, the summary metrics are calculated using
2916
- the values of the tensors passed in `new_data` and the original values
2917
- of all the remaining tensors. The new tensors' dimensions must match the
2918
- dimensions of the corresponding original tensors from
2919
- `meridian.input_data`. If `None`, the summary metrics are calculated
2920
- using the original values of all the tensors.
2804
+ `media`, `media_spend`, `reach`, `frequency`, `rf_spend`,
2805
+ `organic_media`, `organic_reach`, `organic_frequency`,
2806
+ `non_media_treatments`, `controls`, `revenue_per_kpi`. If provided, the
2807
+ summary metrics are calculated using the values of the tensors passed in
2808
+ `new_data` and the original values of all the remaining tensors. If
2809
+ `None`, the summary metrics are calculated using the original values of
2810
+ all the tensors. If `new_data` is provided with a different number of
2811
+ time periods than in `InputData`, then all tensors, except `controls`,
2812
+ must have the same number of time periods.
2921
2813
  marginal_roi_by_reach: Boolean. Marginal ROI (mROI) is defined as the
2922
2814
  return on the next dollar spent. If this argument is `True`, the
2923
2815
  assumption is that the next dollar spent only impacts reach, holding
@@ -2930,8 +2822,10 @@ class Analyzer:
2930
2822
  when `include_non_paid_channels` is `False`.
2931
2823
  selected_geos: Optional list containing a subset of geos to include. By
2932
2824
  default, all geos are included.
2933
- selected_times: Optional list containing a subset of times to include. By
2934
- default, all time periods are included.
2825
+ selected_times: Optional list containing either a subset of dates to
2826
+ include or booleans with length equal to the number of time periods in
2827
+ the tensors in the `new_data` argument, if provided. By default, all
2828
+ time periods are included.
2935
2829
  aggregate_geos: Boolean. If `True`, the expected outcome is summed over
2936
2830
  all of the regions.
2937
2831
  aggregate_times: Boolean. If `True`, the expected outcome is summed over
@@ -2982,12 +2876,8 @@ class Analyzer:
2982
2876
  "aggregate_geos": aggregate_geos,
2983
2877
  "aggregate_times": aggregate_times,
2984
2878
  }
2985
- dim_kwargs_wo_agg_times = {
2986
- "selected_geos": selected_geos,
2987
- "selected_times": selected_times,
2988
- "aggregate_geos": aggregate_geos,
2989
- }
2990
2879
  batched_kwargs = {"batch_size": batch_size}
2880
+ new_data = new_data or DataTensors()
2991
2881
  aggregated_impressions = self.get_aggregated_impressions(
2992
2882
  new_data=new_data,
2993
2883
  optimal_frequency=optimal_frequency,
@@ -3020,19 +2910,31 @@ class Analyzer:
3020
2910
  **dim_kwargs,
3021
2911
  **batched_kwargs,
3022
2912
  )
3023
- expected_outcome_prior = self.expected_outcome(
2913
+ incremental_outcome_mroi_prior = self.compute_incremental_outcome_aggregate(
3024
2914
  use_posterior=False,
3025
2915
  new_data=new_data,
3026
2916
  use_kpi=use_kpi,
2917
+ by_reach=marginal_roi_by_reach,
2918
+ scaling_factor0=1,
2919
+ scaling_factor1=1 + marginal_roi_incremental_increase,
2920
+ include_non_paid_channels=include_non_paid_channels,
2921
+ non_media_baseline_values=non_media_baseline_values,
3027
2922
  **dim_kwargs,
3028
2923
  **batched_kwargs,
3029
2924
  )
3030
- expected_outcome_posterior = self.expected_outcome(
3031
- use_posterior=True,
3032
- new_data=new_data,
3033
- use_kpi=use_kpi,
3034
- **dim_kwargs,
3035
- **batched_kwargs,
2925
+ incremental_outcome_mroi_posterior = (
2926
+ self.compute_incremental_outcome_aggregate(
2927
+ use_posterior=True,
2928
+ new_data=new_data,
2929
+ use_kpi=use_kpi,
2930
+ by_reach=marginal_roi_by_reach,
2931
+ scaling_factor0=1,
2932
+ scaling_factor1=1 + marginal_roi_incremental_increase,
2933
+ include_non_paid_channels=include_non_paid_channels,
2934
+ non_media_baseline_values=non_media_baseline_values,
2935
+ **dim_kwargs,
2936
+ **batched_kwargs,
2937
+ )
3036
2938
  )
3037
2939
 
3038
2940
  xr_dims = (
@@ -3059,11 +2961,20 @@ class Analyzer:
3059
2961
  )
3060
2962
  xr_coords[constants.GEO] = ([constants.GEO], geo_dims)
3061
2963
  if not aggregate_times:
3062
- time_dims = (
3063
- self._meridian.input_data.time.data
3064
- if selected_times is None
3065
- else selected_times
3066
- )
2964
+ # Get the time coordinates for flexible time dimensions.
2965
+ modified_times = new_data.get_modified_times(self._meridian)
2966
+ if modified_times is None:
2967
+ times = self._meridian.input_data.time.data
2968
+ else:
2969
+ times = np.arange(modified_times)
2970
+
2971
+ if selected_times is None:
2972
+ time_dims = times
2973
+ elif _is_bool_list(selected_times):
2974
+ indices = np.where(selected_times)
2975
+ time_dims = times[indices]
2976
+ else:
2977
+ time_dims = selected_times
3067
2978
  xr_coords[constants.TIME] = ([constants.TIME], time_dims)
3068
2979
  xr_dims_with_ci_and_distribution = xr_dims + (
3069
2980
  constants.METRIC,
@@ -3094,15 +3005,6 @@ class Analyzer:
3094
3005
  confidence_level=confidence_level,
3095
3006
  include_median=True,
3096
3007
  )
3097
- pct_of_contribution = self._compute_pct_of_contribution(
3098
- incremental_outcome_prior=incremental_outcome_prior,
3099
- incremental_outcome_posterior=incremental_outcome_posterior,
3100
- expected_outcome_prior=expected_outcome_prior,
3101
- expected_outcome_posterior=expected_outcome_posterior,
3102
- xr_dims=xr_dims_with_ci_and_distribution,
3103
- xr_coords=xr_coords_with_ci_and_distribution,
3104
- confidence_level=confidence_level,
3105
- )
3106
3008
  effectiveness = self._compute_effectiveness_aggregate(
3107
3009
  incremental_outcome_prior=incremental_outcome_prior,
3108
3010
  incremental_outcome_posterior=incremental_outcome_posterior,
@@ -3117,6 +3019,33 @@ class Analyzer:
3117
3019
  # channels.
3118
3020
  ).where(lambda ds: ds.channel != constants.ALL_CHANNELS)
3119
3021
 
3022
+ if new_data.get_modified_times(self._meridian) is None:
3023
+ expected_outcome_prior = self.expected_outcome(
3024
+ use_posterior=False,
3025
+ new_data=new_data,
3026
+ use_kpi=use_kpi,
3027
+ **dim_kwargs,
3028
+ **batched_kwargs,
3029
+ )
3030
+ expected_outcome_posterior = self.expected_outcome(
3031
+ use_posterior=True,
3032
+ new_data=new_data,
3033
+ use_kpi=use_kpi,
3034
+ **dim_kwargs,
3035
+ **batched_kwargs,
3036
+ )
3037
+ pct_of_contribution = self._compute_pct_of_contribution(
3038
+ incremental_outcome_prior=incremental_outcome_prior,
3039
+ incremental_outcome_posterior=incremental_outcome_posterior,
3040
+ expected_outcome_prior=expected_outcome_prior,
3041
+ expected_outcome_posterior=expected_outcome_posterior,
3042
+ xr_dims=xr_dims_with_ci_and_distribution,
3043
+ xr_coords=xr_coords_with_ci_and_distribution,
3044
+ confidence_level=confidence_level,
3045
+ )
3046
+ else:
3047
+ pct_of_contribution = xr.Dataset()
3048
+
3120
3049
  if include_non_paid_channels:
3121
3050
  # If non-paid channels are included, return only the non-paid metrics.
3122
3051
  if not aggregate_times:
@@ -3141,8 +3070,10 @@ class Analyzer:
3141
3070
  # If non-paid channels are not included, return all metrics, paid and
3142
3071
  # non-paid.
3143
3072
  spend_list = []
3144
- new_spend_tensors = self._fill_missing_data_tensors(
3145
- new_data, [constants.MEDIA_SPEND, constants.RF_SPEND]
3073
+ if new_data is None:
3074
+ new_data = DataTensors()
3075
+ new_spend_tensors = new_data.validate_and_fill_missing_data(
3076
+ [constants.MEDIA_SPEND, constants.RF_SPEND], self._meridian
3146
3077
  )
3147
3078
  if self._meridian.n_media_channels > 0:
3148
3079
  spend_list.append(new_spend_tensors.media_spend)
@@ -3150,7 +3081,9 @@ class Analyzer:
3150
3081
  spend_list.append(new_spend_tensors.rf_spend)
3151
3082
  # TODO Add support for 1-dimensional spend.
3152
3083
  aggregated_spend = self.filter_and_aggregate_geos_and_times(
3153
- tensor=tf.concat(spend_list, axis=-1), **dim_kwargs
3084
+ tensor=tf.concat(spend_list, axis=-1),
3085
+ flexible_time_dim=True,
3086
+ **dim_kwargs,
3154
3087
  )
3155
3088
  spend_with_total = tf.concat(
3156
3089
  [aggregated_spend, tf.reduce_sum(aggregated_spend, -1, keepdims=True)],
@@ -3185,19 +3118,14 @@ class Analyzer:
3185
3118
  confidence_level=confidence_level,
3186
3119
  spend_with_total=spend_with_total,
3187
3120
  )
3188
- mroi = self._compute_marginal_roi_aggregate(
3189
- marginal_roi_by_reach=marginal_roi_by_reach,
3190
- marginal_roi_incremental_increase=marginal_roi_incremental_increase,
3191
- expected_revenue_prior=expected_outcome_prior,
3192
- expected_revenue_posterior=expected_outcome_posterior,
3121
+ mroi = self._compute_roi_aggregate(
3122
+ incremental_outcome_prior=incremental_outcome_mroi_prior,
3123
+ incremental_outcome_posterior=incremental_outcome_mroi_posterior,
3193
3124
  xr_dims=xr_dims_with_ci_and_distribution,
3194
3125
  xr_coords=xr_coords_with_ci_and_distribution,
3195
3126
  confidence_level=confidence_level,
3196
- spend_with_total=spend_with_total,
3197
- new_data=new_data,
3198
- use_kpi=use_kpi,
3199
- **dim_kwargs_wo_agg_times,
3200
- **batched_kwargs,
3127
+ spend_with_total=spend_with_total * marginal_roi_incremental_increase,
3128
+ metric_name=constants.MROI,
3201
3129
  # Drop mROI metric values in the Dataset's data_vars for the
3202
3130
  # aggregated "All Paid Channels" channel dimension value.
3203
3131
  # "Marginal ROI" calculation must arbitrarily assume how the
@@ -3241,7 +3169,7 @@ class Analyzer:
3241
3169
  self,
3242
3170
  new_data: DataTensors | None = None,
3243
3171
  selected_geos: Sequence[str] | None = None,
3244
- selected_times: Sequence[str] | None = None,
3172
+ selected_times: Sequence[str] | Sequence[bool] | None = None,
3245
3173
  aggregate_geos: bool = True,
3246
3174
  aggregate_times: bool = True,
3247
3175
  optimal_frequency: Sequence[float] | None = None,
@@ -3255,14 +3183,14 @@ class Analyzer:
3255
3183
  `organic_frequency`, and `non_media_treatments` tensors. If `new_data`
3256
3184
  argument is used, then the aggregated impressions are computed using the
3257
3185
  values of the tensors passed in the `new_data` argument and the original
3258
- values of all the remaining tensors. The new tensors' dimensions must
3259
- match the dimensions of the corresponding original tensors from
3260
- `meridian.input_data`. If `None`, the existing tensors from the Meridian
3261
- object are used.
3186
+ values of all the remaining tensors. If `None`, the existing tensors
3187
+ from the Meridian object are used.
3262
3188
  selected_geos: Optional list containing a subset of geos to include. By
3263
3189
  default, all geos are included.
3264
- selected_times: Optional list containing a subset of times to include. By
3265
- default, all time periods are included.
3190
+ selected_times: Optional list containing either a subset of dates to
3191
+ include or booleans with length equal to the number of time periods in
3192
+ the tensors in the `new_data` argument, if provided. By default, all
3193
+ time periods are included.
3266
3194
  aggregate_geos: Boolean. If `True`, the expected outcome is summed over
3267
3195
  all of the regions.
3268
3196
  aggregate_times: Boolean. If `True`, the expected outcome is summed over
@@ -3291,12 +3219,18 @@ class Analyzer:
3291
3219
  constants.ORGANIC_FREQUENCY,
3292
3220
  constants.NON_MEDIA_TREATMENTS,
3293
3221
  ])
3294
- data_tensors = self._fill_missing_data_tensors(new_data, tensor_names_list)
3222
+ if new_data is None:
3223
+ new_data = DataTensors()
3224
+ data_tensors = new_data.validate_and_fill_missing_data(
3225
+ tensor_names_list, self._meridian
3226
+ )
3227
+ n_times = (
3228
+ data_tensors.get_modified_times(self._meridian)
3229
+ or self._meridian.n_times
3230
+ )
3295
3231
  impressions_list = []
3296
3232
  if self._meridian.n_media_channels > 0:
3297
- impressions_list.append(
3298
- data_tensors.media[:, -self._meridian.n_times :, :]
3299
- )
3233
+ impressions_list.append(data_tensors.media[:, -n_times:, :])
3300
3234
 
3301
3235
  if self._meridian.n_rf_channels > 0:
3302
3236
  if optimal_frequency is None:
@@ -3304,15 +3238,12 @@ class Analyzer:
3304
3238
  else:
3305
3239
  new_frequency = tf.ones_like(data_tensors.frequency) * optimal_frequency
3306
3240
  impressions_list.append(
3307
- data_tensors.reach[:, -self._meridian.n_times :, :]
3308
- * new_frequency[:, -self._meridian.n_times :, :]
3241
+ data_tensors.reach[:, -n_times:, :] * new_frequency[:, -n_times:, :]
3309
3242
  )
3310
3243
 
3311
3244
  if include_non_paid_channels:
3312
3245
  if self._meridian.n_organic_media_channels > 0:
3313
- impressions_list.append(
3314
- data_tensors.organic_media[:, -self._meridian.n_times :, :]
3315
- )
3246
+ impressions_list.append(data_tensors.organic_media[:, -n_times:, :])
3316
3247
  if self._meridian.n_organic_rf_channels > 0:
3317
3248
  if optimal_frequency is None:
3318
3249
  new_organic_frequency = data_tensors.organic_frequency
@@ -3321,8 +3252,8 @@ class Analyzer:
3321
3252
  tf.ones_like(data_tensors.organic_frequency) * optimal_frequency
3322
3253
  )
3323
3254
  impressions_list.append(
3324
- data_tensors.organic_reach[:, -self._meridian.n_times :, :]
3325
- * new_organic_frequency[:, -self._meridian.n_times :, :]
3255
+ data_tensors.organic_reach[:, -n_times:, :]
3256
+ * new_organic_frequency[:, -n_times:, :]
3326
3257
  )
3327
3258
  if self._meridian.n_non_media_channels > 0:
3328
3259
  impressions_list.append(data_tensors.non_media_treatments)
@@ -3333,6 +3264,7 @@ class Analyzer:
3333
3264
  selected_times=selected_times,
3334
3265
  aggregate_geos=aggregate_geos,
3335
3266
  aggregate_times=aggregate_times,
3267
+ flexible_time_dim=True,
3336
3268
  )
3337
3269
 
3338
3270
  def baseline_summary_metrics(
@@ -4544,85 +4476,13 @@ class Analyzer:
4544
4476
  xr_coords: Mapping[str, tuple[Sequence[str], Sequence[str]]],
4545
4477
  spend_with_total: tf.Tensor,
4546
4478
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
4479
+ metric_name: str = constants.ROI,
4547
4480
  ) -> xr.Dataset:
4548
4481
  # TODO: Support calibration_period_bool.
4549
4482
  return _central_tendency_and_ci_by_prior_and_posterior(
4550
4483
  prior=incremental_outcome_prior / spend_with_total,
4551
4484
  posterior=incremental_outcome_posterior / spend_with_total,
4552
- metric_name=constants.ROI,
4553
- xr_dims=xr_dims,
4554
- xr_coords=xr_coords,
4555
- confidence_level=confidence_level,
4556
- include_median=True,
4557
- )
4558
-
4559
- def _compute_marginal_roi_aggregate(
4560
- self,
4561
- marginal_roi_by_reach: bool,
4562
- marginal_roi_incremental_increase: float,
4563
- expected_revenue_prior: tf.Tensor,
4564
- expected_revenue_posterior: tf.Tensor,
4565
- xr_dims: Sequence[str],
4566
- xr_coords: Mapping[str, tuple[Sequence[str], Sequence[str]]],
4567
- spend_with_total: tf.Tensor,
4568
- new_data: DataTensors | None = None,
4569
- use_kpi: bool = False,
4570
- confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
4571
- **roi_kwargs,
4572
- ) -> xr.Dataset:
4573
- data_tensors = self._fill_missing_data_tensors(
4574
- new_data, [constants.MEDIA, constants.REACH, constants.FREQUENCY]
4575
- )
4576
- mroi_prior = self.marginal_roi(
4577
- use_posterior=False,
4578
- new_data=data_tensors,
4579
- by_reach=marginal_roi_by_reach,
4580
- incremental_increase=marginal_roi_incremental_increase,
4581
- use_kpi=use_kpi,
4582
- **roi_kwargs,
4583
- )
4584
- mroi_posterior = self.marginal_roi(
4585
- use_posterior=True,
4586
- new_data=data_tensors,
4587
- by_reach=marginal_roi_by_reach,
4588
- incremental_increase=marginal_roi_incremental_increase,
4589
- use_kpi=use_kpi,
4590
- **roi_kwargs,
4591
- )
4592
- incremented_data = _scale_tensors_by_multiplier(
4593
- data=data_tensors,
4594
- multiplier=(1 + marginal_roi_incremental_increase),
4595
- by_reach=marginal_roi_by_reach,
4596
- )
4597
-
4598
- mroi_prior_total = (
4599
- self.expected_outcome(
4600
- use_posterior=False,
4601
- new_data=incremented_data,
4602
- use_kpi=use_kpi,
4603
- **roi_kwargs,
4604
- )
4605
- - expected_revenue_prior
4606
- ) / (marginal_roi_incremental_increase * spend_with_total[..., -1])
4607
- mroi_posterior_total = (
4608
- self.expected_outcome(
4609
- use_posterior=True,
4610
- new_data=incremented_data,
4611
- use_kpi=use_kpi,
4612
- **roi_kwargs,
4613
- )
4614
- - expected_revenue_posterior
4615
- ) / (marginal_roi_incremental_increase * spend_with_total[..., -1])
4616
- mroi_prior_concat = tf.concat(
4617
- [mroi_prior, mroi_prior_total[..., None]], axis=-1
4618
- )
4619
- mroi_posterior_concat = tf.concat(
4620
- [mroi_posterior, mroi_posterior_total[..., None]], axis=-1
4621
- )
4622
- return _central_tendency_and_ci_by_prior_and_posterior(
4623
- prior=mroi_prior_concat,
4624
- posterior=mroi_posterior_concat,
4625
- metric_name=constants.MROI,
4485
+ metric_name=metric_name,
4626
4486
  xr_dims=xr_dims,
4627
4487
  xr_coords=xr_coords,
4628
4488
  confidence_level=confidence_level,