google-meridian 1.0.8__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
@@ -16,6 +16,7 @@
16
16
 
17
17
  from collections.abc import Mapping, Sequence
18
18
  import itertools
19
+ import numbers
19
20
  from typing import Any, Optional
20
21
  import warnings
21
22
 
@@ -37,6 +38,20 @@ __all__ = [
37
38
  ]
38
39
 
39
40
 
41
+ def _validate_non_media_baseline_values_numbers(
42
+ non_media_baseline_values: Sequence[str | float] | None,
43
+ ):
44
+ if non_media_baseline_values is None:
45
+ return
46
+
47
+ for value in non_media_baseline_values:
48
+ if not isinstance(value, numbers.Number):
49
+ raise ValueError(
50
+ f"Invalid `non_media_baseline_values` value: '{value}'. Only float"
51
+ " numbers are supported."
52
+ )
53
+
54
+
40
55
  # TODO: Refactor the related unit tests to be under DataTensors.
41
56
  class DataTensors(tf.experimental.ExtensionType):
42
57
  """Container for data variable arguments of Analyzer methods.
@@ -63,6 +78,8 @@ class DataTensors(tf.experimental.ExtensionType):
63
78
  controls: Optional tensor with dimensions `(n_geos, n_times, n_controls)`.
64
79
  revenue_per_kpi: Optional tensor with dimensions `(n_geos, T)` for any time
65
80
  dimension `T`.
81
+ time: Optional tensor of time coordinates in the "YYYY-mm-dd" string format
82
+ for time dimension `T`.
66
83
  """
67
84
 
68
85
  media: Optional[tf.Tensor]
@@ -76,6 +93,7 @@ class DataTensors(tf.experimental.ExtensionType):
76
93
  non_media_treatments: Optional[tf.Tensor]
77
94
  controls: Optional[tf.Tensor]
78
95
  revenue_per_kpi: Optional[tf.Tensor]
96
+ time: Optional[tf.Tensor]
79
97
 
80
98
  def __init__(
81
99
  self,
@@ -90,6 +108,7 @@ class DataTensors(tf.experimental.ExtensionType):
90
108
  non_media_treatments: Optional[tf.Tensor] = None,
91
109
  controls: Optional[tf.Tensor] = None,
92
110
  revenue_per_kpi: Optional[tf.Tensor] = None,
111
+ time: Optional[Sequence[str] | tf.Tensor] = None,
93
112
  ):
94
113
  self.media = tf.cast(media, tf.float32) if media is not None else None
95
114
  self.media_spend = (
@@ -130,6 +149,7 @@ class DataTensors(tf.experimental.ExtensionType):
130
149
  if revenue_per_kpi is not None
131
150
  else None
132
151
  )
152
+ self.time = tf.cast(time, tf.string) if time is not None else None
133
153
 
134
154
  def __validate__(self):
135
155
  self._validate_n_dims()
@@ -176,6 +196,7 @@ class DataTensors(tf.experimental.ExtensionType):
176
196
  new_tensor is not None
177
197
  and old_tensor is not None
178
198
  and new_tensor.ndim > 1
199
+ and old_tensor.ndim > 1
179
200
  and new_tensor.shape[1] != old_tensor.shape[1]
180
201
  ):
181
202
  return new_tensor.shape[1]
@@ -241,6 +262,8 @@ class DataTensors(tf.experimental.ExtensionType):
241
262
  f"New `{field.name}` must have 1 or 3 dimensions. Found"
242
263
  f" {tensor.ndim} dimensions."
243
264
  )
265
+ elif field.name == constants.TIME:
266
+ _check_n_dims(tensor, field.name, 1)
244
267
  else:
245
268
  _check_n_dims(tensor, field.name, 3)
246
269
 
@@ -283,7 +306,7 @@ class DataTensors(tf.experimental.ExtensionType):
283
306
  for var_name in required_fields:
284
307
  new_tensor = getattr(self, var_name)
285
308
  if new_tensor is not None and new_tensor.shape[0] != meridian.n_geos:
286
- # Skip spend data with only 1 dimension of (n_channels).
309
+ # Skip spend and time data with only 1 dimension.
287
310
  if new_tensor.ndim == 1:
288
311
  continue
289
312
  raise ValueError(
@@ -296,7 +319,7 @@ class DataTensors(tf.experimental.ExtensionType):
296
319
  ):
297
320
  """Validates the channel dimension of the specified data variables."""
298
321
  for var_name in required_fields:
299
- if var_name == constants.REVENUE_PER_KPI:
322
+ if var_name in [constants.REVENUE_PER_KPI, constants.TIME]:
300
323
  continue
301
324
  new_tensor = getattr(self, var_name)
302
325
  old_tensor = getattr(meridian.input_data, var_name)
@@ -317,12 +340,24 @@ class DataTensors(tf.experimental.ExtensionType):
317
340
  old_tensor = getattr(meridian.input_data, var_name)
318
341
 
319
342
  # Skip spend data with only 1 dimension of (n_channels).
320
- if new_tensor is not None and new_tensor.ndim == 1:
343
+ if (
344
+ var_name in [constants.MEDIA_SPEND, constants.RF_SPEND]
345
+ and new_tensor is not None
346
+ and new_tensor.ndim == 1
347
+ ):
321
348
  continue
322
349
 
323
350
  if new_tensor is not None:
324
351
  assert old_tensor is not None
325
- if new_tensor.shape[1] != old_tensor.shape[1]:
352
+ if (
353
+ var_name == constants.TIME
354
+ and new_tensor.shape[0] != old_tensor.shape[0]
355
+ ):
356
+ raise ValueError(
357
+ f"New `{var_name}` is expected to have {old_tensor.shape[0]}"
358
+ f" time periods. Found {new_tensor.shape[0]} time periods."
359
+ )
360
+ elif new_tensor.ndim > 1 and new_tensor.shape[1] != old_tensor.shape[1]:
326
361
  raise ValueError(
327
362
  f"New `{var_name}` is expected to have {old_tensor.shape[1]}"
328
363
  f" time periods. Found {new_tensor.shape[1]} time periods."
@@ -345,12 +380,24 @@ class DataTensors(tf.experimental.ExtensionType):
345
380
  if old_tensor is None:
346
381
  continue
347
382
  # Skip spend data with only 1 dimension of (n_channels).
348
- if new_tensor is not None and new_tensor.ndim == 1:
383
+ if (
384
+ var_name in [constants.MEDIA_SPEND, constants.RF_SPEND]
385
+ and new_tensor is not None
386
+ and new_tensor.ndim == 1
387
+ ):
349
388
  continue
350
389
 
351
390
  if new_tensor is None:
352
391
  missing_params.append(var_name)
353
- elif new_tensor.shape[1] != new_n_times:
392
+ elif var_name == constants.TIME and new_tensor.shape[0] != new_n_times:
393
+ raise ValueError(
394
+ "If the time dimension of any variable in `new_data` is "
395
+ "modified, then all variables must be provided with the same "
396
+ f"number of time periods. `{var_name}` has {new_tensor.shape[1]} "
397
+ "time periods, which does not match the modified number of time "
398
+ f"periods, {new_n_times}.",
399
+ )
400
+ elif new_tensor.ndim > 1 and new_tensor.shape[1] != new_n_times:
354
401
  raise ValueError(
355
402
  "If the time dimension of any variable in `new_data` is "
356
403
  "modified, then all variables must be provided with the same "
@@ -390,6 +437,10 @@ class DataTensors(tf.experimental.ExtensionType):
390
437
  old_tensor = meridian.controls
391
438
  elif var_name == constants.REVENUE_PER_KPI:
392
439
  old_tensor = meridian.revenue_per_kpi
440
+ elif var_name == constants.TIME:
441
+ old_tensor = tf.convert_to_tensor(
442
+ meridian.input_data.time.values.tolist(), dtype=tf.string
443
+ )
393
444
  else:
394
445
  continue
395
446
 
@@ -618,22 +669,16 @@ def _scale_tensors_by_multiplier(
618
669
  data: DataTensors,
619
670
  multiplier: float,
620
671
  by_reach: bool,
621
- non_media_treatments_baseline: tf.Tensor | None = None,
622
672
  ) -> DataTensors:
623
673
  """Get scaled tensors for incremental outcome calculation.
624
674
 
625
675
  Args:
626
676
  data: DataTensors object containing the optional tensors to scale. Only
627
- `media`, `reach`, `frequency`, `organic_media`, `organic_reach`,
628
- `organic_frequency`, `non_media_treatments` are scaled. The other tensors
629
- remain unchanged.
677
+ `media`, `reach`, `frequency`, `organic_media`, `organic_reach`, and
678
+ `organic_frequency` are scaled. The other tensors remain unchanged.
630
679
  multiplier: Float indicating the factor to scale tensors by.
631
680
  by_reach: Boolean indicating whether to scale reach or frequency when rf
632
681
  data is available.
633
- non_media_treatments_baseline: Optional tensor to overwrite
634
- `data.non_media_treatments` in the output. Used to compute the
635
- conterfactual values for incremental outcome calculation. If not used, the
636
- unmodified `data.non_media_treatments` tensor is returned in the output.
637
682
 
638
683
  Returns:
639
684
  A `DataTensors` object containing scaled tensor parameters. The original
@@ -662,14 +707,9 @@ def _scale_tensors_by_multiplier(
662
707
  incremented_data[constants.ORGANIC_FREQUENCY] = (
663
708
  data.organic_frequency * multiplier
664
709
  )
665
- if non_media_treatments_baseline is not None:
666
- incremented_data[constants.NON_MEDIA_TREATMENTS] = (
667
- non_media_treatments_baseline
668
- )
669
- else:
670
- incremented_data[constants.NON_MEDIA_TREATMENTS] = data.non_media_treatments
671
710
 
672
711
  # Include the original data that does not get scaled.
712
+ incremented_data[constants.NON_MEDIA_TREATMENTS] = data.non_media_treatments
673
713
  incremented_data[constants.MEDIA_SPEND] = data.media_spend
674
714
  incremented_data[constants.RF_SPEND] = data.rf_spend
675
715
  incremented_data[constants.CONTROLS] = data.controls
@@ -719,79 +759,6 @@ def _central_tendency_and_ci_by_prior_and_posterior(
719
759
  return xr.Dataset(data_vars=xr_data, coords=xr_coords)
720
760
 
721
761
 
722
- def _compute_non_media_baseline(
723
- non_media_treatments: tf.Tensor,
724
- non_media_baseline_values: Sequence[float | str] | None = None,
725
- non_media_selected_times: Sequence[bool] | None = None,
726
- ) -> tf.Tensor:
727
- """Computes the baseline for each non-media treatment channel.
728
-
729
- Args:
730
- non_media_treatments: The non-media treatment input data.
731
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
732
- Each element is either a float (which means that the fixed value will be
733
- used as baseline for the given channel) or one of the strings "min" or
734
- "max" (which mean that the global minimum or maximum value will be used as
735
- baseline for the values of the given non_media treatment channel). If
736
- None, the minimum value is used as baseline for each non_media treatment
737
- channel.
738
- non_media_selected_times: Optional list of shape (n_times,). Each element is
739
- a boolean indicating whether the corresponding time period should be
740
- included in the baseline computation.
741
-
742
- Returns:
743
- A tensor of shape (n_geos, n_times, n_non_media_channels) containing the
744
- baseline values for each non-media treatment channel.
745
- """
746
-
747
- if non_media_selected_times is None:
748
- non_media_selected_times = [True] * non_media_treatments.shape[-2]
749
-
750
- if non_media_baseline_values is None:
751
- # If non_media_baseline_values is not provided, use the minimum value for
752
- # each non_media treatment channel as the baseline.
753
- non_media_baseline_values_filled = [
754
- constants.NON_MEDIA_BASELINE_MIN
755
- ] * non_media_treatments.shape[-1]
756
- else:
757
- non_media_baseline_values_filled = non_media_baseline_values
758
-
759
- if non_media_treatments.shape[-1] != len(non_media_baseline_values_filled):
760
- raise ValueError(
761
- "The number of non-media channels"
762
- f" ({non_media_treatments.shape[-1]}) does not match the number"
763
- f" of baseline types ({len(non_media_baseline_values_filled)})."
764
- )
765
-
766
- baseline_list = []
767
- for channel in range(non_media_treatments.shape[-1]):
768
- baseline_value = non_media_baseline_values_filled[channel]
769
-
770
- if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
771
- baseline_for_channel = tf.reduce_min(
772
- non_media_treatments[..., channel], axis=[0, 1]
773
- )
774
- elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
775
- baseline_for_channel = tf.reduce_max(
776
- non_media_treatments[..., channel], axis=[0, 1]
777
- )
778
- elif isinstance(baseline_value, float):
779
- baseline_for_channel = tf.cast(baseline_value, tf.float32)
780
- else:
781
- raise ValueError(
782
- f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
783
- " float numbers and strings 'min' and 'max' are supported."
784
- )
785
-
786
- baseline_list.append(
787
- baseline_for_channel
788
- * tf.ones_like(non_media_treatments[..., channel])
789
- * non_media_selected_times
790
- )
791
-
792
- return tf.stack(baseline_list, axis=-1)
793
-
794
-
795
762
  class Analyzer:
796
763
  """Runs calculations to analyze the raw data after fitting the model."""
797
764
 
@@ -818,7 +785,7 @@ class Analyzer:
818
785
  `media`, `reach`, `frequency`, `organic_media`, `organic_reach`,
819
786
  `organic_frequency`, `non_media_treatments`, `controls`. The `media`,
820
787
  `reach`, `organic_media`, `organic_reach` and `non_media_treatments`
821
- tensors are assumed to be scaled by their corresponding transformers.
788
+ tensors are expected to be scaled by their corresponding transformers.
822
789
  dist_tensors: A `DistributionTensors` container with the distribution
823
790
  tensors for media, RF, organic media, organic RF, non-media treatments,
824
791
  and controls.
@@ -1029,7 +996,7 @@ class Analyzer:
1029
996
  organic_media=self._meridian.organic_media_tensors.organic_media_scaled,
1030
997
  organic_reach=self._meridian.organic_rf_tensors.organic_reach_scaled,
1031
998
  organic_frequency=self._meridian.organic_rf_tensors.organic_frequency,
1032
- non_media_treatments=self._meridian.non_media_treatments_scaled,
999
+ non_media_treatments=self._meridian.non_media_treatments_normalized,
1033
1000
  controls=self._meridian.controls_scaled,
1034
1001
  revenue_per_kpi=self._meridian.revenue_per_kpi,
1035
1002
  )
@@ -1078,10 +1045,10 @@ class Analyzer:
1078
1045
  if new_data.organic_frequency is not None
1079
1046
  else self._meridian.organic_rf_tensors.organic_frequency
1080
1047
  )
1081
- non_media_treatments_scaled = _transformed_new_or_scaled(
1048
+ non_media_treatments_normalized = _transformed_new_or_scaled(
1082
1049
  new_variable=new_data.non_media_treatments,
1083
1050
  transformer=self._meridian.non_media_transformer,
1084
- scaled_variable=self._meridian.non_media_treatments_scaled,
1051
+ scaled_variable=self._meridian.non_media_treatments_normalized,
1085
1052
  )
1086
1053
  return DataTensors(
1087
1054
  media=media_scaled,
@@ -1090,7 +1057,7 @@ class Analyzer:
1090
1057
  organic_media=organic_media_scaled,
1091
1058
  organic_reach=organic_reach_scaled,
1092
1059
  organic_frequency=organic_frequency,
1093
- non_media_treatments=non_media_treatments_scaled,
1060
+ non_media_treatments=non_media_treatments_normalized,
1094
1061
  controls=controls_scaled,
1095
1062
  revenue_per_kpi=revenue_per_kpi,
1096
1063
  )
@@ -1559,7 +1526,7 @@ class Analyzer:
1559
1526
  self,
1560
1527
  data_tensors: DataTensors,
1561
1528
  dist_tensors: DistributionTensors,
1562
- non_media_baseline_values: Sequence[float | str] | None = None,
1529
+ non_media_treatments_baseline_normalized: Sequence[float] | None = None,
1563
1530
  ) -> tf.Tensor:
1564
1531
  """Computes incremental KPI distribution.
1565
1532
 
@@ -1573,17 +1540,26 @@ class Analyzer:
1573
1540
  dist_tensors: A `DistributionTensors` container with the distribution
1574
1541
  tensors for media, RF, organic media, organic RF and non-media
1575
1542
  treatments channels.
1576
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
1577
- Each element is either a float (which means that the fixed value will be
1578
- used as baseline for the given channel) or one of the strings "min" or
1579
- "max" (which mean that the global minimum or maximum value will be used
1580
- as baseline for the scaled values of the given non_media treatments
1581
- channel). If None, the minimum value is used as baseline for each
1582
- non_media treatments channel.
1543
+ non_media_treatments_baseline_normalized: Optional list of shape
1544
+ `(n_non_media_channels,)`. Each element is a float that will be used as
1545
+ baseline for the given channel. The values are expected to be scaled by
1546
+ population for channels where
1547
+ `model_spec.non_media_population_scaling_id` is `True` and normalized by
1548
+ centering and scaling using means and standard deviations. This argument
1549
+ is required if the data contains non-media treatments.
1583
1550
 
1584
1551
  Returns:
1585
1552
  Tensor of incremental KPI distribution.
1586
1553
  """
1554
+ if (
1555
+ data_tensors.non_media_treatments is not None
1556
+ and non_media_treatments_baseline_normalized is None
1557
+ ):
1558
+ raise ValueError(
1559
+ "`non_media_treatments_baseline_normalized` must be passed to"
1560
+ " `_get_incremental_kpi` when `non_media_treatments` data is"
1561
+ " present."
1562
+ )
1587
1563
  n_media_times = self._meridian.n_media_times
1588
1564
  if data_tensors.media is not None:
1589
1565
  n_times = data_tensors.media.shape[1] # pytype: disable=attribute-error
@@ -1606,13 +1582,10 @@ class Analyzer:
1606
1582
  combined_beta,
1607
1583
  )
1608
1584
  if data_tensors.non_media_treatments is not None:
1609
- non_media_scaled_baseline = _compute_non_media_baseline(
1610
- non_media_treatments=data_tensors.non_media_treatments,
1611
- non_media_baseline_values=non_media_baseline_values,
1612
- )
1613
1585
  non_media_kpi = tf.einsum(
1614
1586
  "gtn,...gn->...gtn",
1615
- data_tensors.non_media_treatments - non_media_scaled_baseline,
1587
+ data_tensors.non_media_treatments
1588
+ - non_media_treatments_baseline_normalized,
1616
1589
  dist_tensors.gamma_gn,
1617
1590
  )
1618
1591
  return tf.concat([combined_media_kpi, non_media_kpi], axis=-1)
@@ -1662,7 +1635,7 @@ class Analyzer:
1662
1635
  self,
1663
1636
  data_tensors: DataTensors,
1664
1637
  dist_tensors: DistributionTensors,
1665
- non_media_baseline_values: Sequence[float | str] | None = None,
1638
+ non_media_treatments_baseline_normalized: Sequence[float] | None = None,
1666
1639
  inverse_transform_outcome: bool | None = None,
1667
1640
  use_kpi: bool | None = None,
1668
1641
  selected_geos: Sequence[str] | None = None,
@@ -1687,20 +1660,21 @@ class Analyzer:
1687
1660
  poulation. Shape (n_geos x T x n_organic_rf_channels), for any time
1688
1661
  dimension T. `organic_frequency`: `organic frequency data` with shape
1689
1662
  (n_geos x T x n_organic_rf_channels), for any time dimension T.
1690
- `non_media_treatments`: `non_media_treatments` data with shape (n_geos x
1691
- T x n_non_media_channels), for any time dimension T. `revenue_per_kpi`:
1692
- Contains revenue per kpi data with shape `(n_geos x T)`, for any time
1693
- dimension `T`.
1694
- dist_tensors: A `DistributionTensors` container with the distribution
1695
- tensors for media, RF, organic media, organic RF and non-media treatments
1696
- channels.
1697
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
1698
- Each element is either a float (which means that the fixed value will be
1699
- used as baseline for the given channel) or one of the strings "min" or
1700
- "max" (which mean that the global minimum or maximum value will be used
1701
- as baseline for the scaled values of the given non_media treatments
1702
- channel). If None, the minimum value is used as baseline for each
1703
- non_media treatments channel.
1663
+ `non_media_treatments`: `non_media_treatments` data scaled by population
1664
+ for the selected channels and normalized by means and standard
1665
+ deviations with shape (n_geos x T x n_non_media_channels), for any time
1666
+ dimension T. `revenue_per_kpi`: Contains revenue per kpi data with shape
1667
+ `(n_geos x T)`, for any time dimension `T`.
1668
+ dist_tensors: A `DistributionTensors` container with the distribution
1669
+ tensors for media, RF, organic media, organic RF and non-media
1670
+ treatments channels.
1671
+ non_media_treatments_baseline_normalized: Optional list of shape
1672
+ `(n_non_media_channels,)`. Each element is a float that will be used as
1673
+ baseline for the given channel. The values are expected to be scaled by
1674
+ population for channels where
1675
+ `model_spec.non_media_population_scaling_id` is `True` and normalized by
1676
+ centering and scaling using means and standard deviations. This argument
1677
+ is required if the data contains non-media treatments.
1704
1678
  inverse_transform_outcome: Boolean. If `True`, returns the expected
1705
1679
  outcome in the original KPI or revenue (depending on what is passed to
1706
1680
  `use_kpi`), as it was passed to `InputData`. If False, returns the
@@ -1725,10 +1699,20 @@ class Analyzer:
1725
1699
  Tensor containing the incremental outcome distribution.
1726
1700
  """
1727
1701
  self._check_revenue_data_exists(use_kpi)
1702
+ if (
1703
+ data_tensors.non_media_treatments is not None
1704
+ and non_media_treatments_baseline_normalized is None
1705
+ ):
1706
+ raise ValueError(
1707
+ "`non_media_treatments_baseline_normalized` must be passed to"
1708
+ " `_incremental_outcome_impl` when `non_media_treatments` data is"
1709
+ " present."
1710
+ )
1711
+
1728
1712
  transformed_outcome = self._get_incremental_kpi(
1729
1713
  data_tensors=data_tensors,
1730
1714
  dist_tensors=dist_tensors,
1731
- non_media_baseline_values=non_media_baseline_values,
1715
+ non_media_treatments_baseline_normalized=non_media_treatments_baseline_normalized,
1732
1716
  )
1733
1717
  if inverse_transform_outcome:
1734
1718
  incremental_outcome = self._inverse_outcome(
@@ -1752,7 +1736,7 @@ class Analyzer:
1752
1736
  self,
1753
1737
  use_posterior: bool = True,
1754
1738
  new_data: DataTensors | None = None,
1755
- non_media_baseline_values: Sequence[float | str] | None = None,
1739
+ non_media_baseline_values: Sequence[float] | None = None,
1756
1740
  scaling_factor0: float = 0.0,
1757
1741
  scaling_factor1: float = 1.0,
1758
1742
  selected_geos: Sequence[str] | None = None,
@@ -1771,15 +1755,26 @@ class Analyzer:
1771
1755
  This calculates the media outcome of each media channel for each posterior
1772
1756
  or prior parameter draw. Incremental outcome is defined as:
1773
1757
 
1774
- `E(Outcome|Media_1, Controls)` minus `E(Outcome|Media_0, Controls)`
1758
+ `E(Outcome|Treatment_1, Controls)` minus `E(Outcome|Treatment_0, Controls)`
1759
+
1760
+ For paid & organic channels (without reach and frequency data),
1761
+ `Treatment_1` means that media execution for a given channel is multiplied
1762
+ by
1763
+ `scaling_factor1` (1.0 by default) for the set of time periods specified
1764
+ by `media_selected_times`. Similarly, `Treatment_0` means that media
1765
+ execution is multiplied by `scaling_factor0` (0.0 by default) for these time
1766
+ periods.
1767
+
1768
+ For paid & organic channels with reach and frequency data, either reach or
1769
+ frequency is held fixed while the other is scaled, depending on the
1770
+ `by_reach` argument.
1775
1771
 
1776
- Here, `Media_1` means that media execution for a given channel is multiplied
1777
- by `scaling_factor1` (1.0 by default) for the set of time periods specified
1778
- by `media_selected_times`. Similarly, `Media_0` means that media execution
1779
- is multiplied by `scaling_factor0` (0.0 by default) for these time periods.
1772
+ For non-media treatments, `Treatment_1` means that the variable is set to
1773
+ historical values. `Treatment_0` means that the variable is set to its
1774
+ baseline value for all geos and time periods. Note that the scaling factors
1775
+ (`scaling_factor0` and `scaling_factor1`) are not applicable to non-media
1776
+ treatments.
1780
1777
 
1781
- For channels with reach and frequency data, either reach or frequency is
1782
- held fixed while the other is scaled, depending on the `by_reach` argument.
1783
1778
  "Outcome" refers to either `revenue` if `use_kpi=False`, or `kpi` if
1784
1779
  `use_kpi=True`. When `revenue_per_kpi` is not defined, `use_kpi` cannot be
1785
1780
  False.
@@ -1821,13 +1816,13 @@ class Analyzer:
1821
1816
  any of the tensors in `new_data` is provided with a different number of
1822
1817
  time periods than in `InputData`, then all tensors must be provided with
1823
1818
  the same number of time periods.
1824
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
1825
- Each element is either a float (which means that the fixed value will be
1826
- used as baseline for the given channel) or one of the strings "min" or
1827
- "max" (which mean that the global minimum or maximum value will be used
1828
- as baseline for the scaled values of the given non_media treatments
1829
- channel). If not provided, the minimum value is used as the baseline for
1830
- each non_media treatments channel.
1819
+ non_media_baseline_values: Optional list of shape
1820
+ `(n_non_media_channels,)`. Each element is a float which means that the
1821
+ fixed value will be used as baseline for the given channel. It is
1822
+ expected that they are scaled by population for the channels where
1823
+ `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
1824
+ `model_spec.non_media_baseline_values` is used, which defaults to the
1825
+ minimum value for each non_media treatment channel.
1831
1826
  scaling_factor0: Float. The factor by which to scale the counterfactual
1832
1827
  scenario "Media_0" during the time periods specified in
1833
1828
  `media_selected_times`. Must be non-negative and less than
@@ -1909,6 +1904,7 @@ class Analyzer:
1909
1904
  aggregate_geos=aggregate_geos,
1910
1905
  selected_geos=selected_geos,
1911
1906
  )
1907
+ _validate_non_media_baseline_values_numbers(non_media_baseline_values)
1912
1908
  dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR
1913
1909
 
1914
1910
  if dist_type not in mmm.inference_data.groups():
@@ -1967,7 +1963,6 @@ class Analyzer:
1967
1963
  media_selected_times = [
1968
1964
  x in media_selected_times for x in mmm.input_data.media_time
1969
1965
  ]
1970
- non_media_selected_times = media_selected_times[-mmm.n_times :]
1971
1966
 
1972
1967
  # Set counterfactual tensors based on the scaling factors and the media
1973
1968
  # selected times.
@@ -1979,28 +1974,52 @@ class Analyzer:
1979
1974
  )[:, None]
1980
1975
 
1981
1976
  if data_tensors.non_media_treatments is not None:
1982
- new_non_media_treatments0 = _compute_non_media_baseline(
1983
- non_media_treatments=data_tensors.non_media_treatments,
1984
- non_media_baseline_values=non_media_baseline_values,
1985
- non_media_selected_times=non_media_selected_times,
1977
+ non_media_treatments_baseline_scaled = (
1978
+ self._meridian.compute_non_media_treatments_baseline(
1979
+ non_media_baseline_values=non_media_baseline_values,
1980
+ )
1981
+ )
1982
+ non_media_treatments_baseline_normalized = self._meridian.non_media_transformer.forward( # pytype: disable=attribute-error
1983
+ non_media_treatments_baseline_scaled,
1984
+ apply_population_scaling=False,
1985
+ )
1986
+ non_media_treatments0 = tf.broadcast_to(
1987
+ tf.constant(
1988
+ non_media_treatments_baseline_normalized, dtype=tf.float32
1989
+ )[tf.newaxis, tf.newaxis, :],
1990
+ self._meridian.non_media_treatments.shape, # pytype: disable=attribute-error
1986
1991
  )
1987
1992
  else:
1988
- new_non_media_treatments0 = None
1993
+ non_media_treatments_baseline_normalized = None
1994
+ non_media_treatments0 = None
1989
1995
 
1990
1996
  incremented_data0 = _scale_tensors_by_multiplier(
1991
1997
  data=data_tensors,
1992
1998
  multiplier=counterfactual0,
1993
1999
  by_reach=by_reach,
1994
- non_media_treatments_baseline=new_non_media_treatments0,
1995
2000
  )
1996
2001
  incremented_data1 = _scale_tensors_by_multiplier(
1997
2002
  data=data_tensors, multiplier=counterfactual1, by_reach=by_reach
1998
2003
  )
1999
2004
 
2000
- data_tensors0 = self._get_scaled_data_tensors(
2005
+ scaled_data0 = self._get_scaled_data_tensors(
2001
2006
  new_data=incremented_data0,
2002
2007
  include_non_paid_channels=include_non_paid_channels,
2003
2008
  )
2009
+ # TODO: b/415198977 - Verify the computation of outcome of non-media
2010
+ # treatments with `media_selected_times` and scale factors.
2011
+
2012
+ data_tensors0 = DataTensors(
2013
+ media=scaled_data0.media,
2014
+ reach=scaled_data0.reach,
2015
+ frequency=scaled_data0.frequency,
2016
+ organic_media=scaled_data0.organic_media,
2017
+ organic_reach=scaled_data0.organic_reach,
2018
+ organic_frequency=scaled_data0.organic_frequency,
2019
+ revenue_per_kpi=scaled_data0.revenue_per_kpi,
2020
+ non_media_treatments=non_media_treatments0,
2021
+ )
2022
+
2004
2023
  data_tensors1 = self._get_scaled_data_tensors(
2005
2024
  new_data=incremented_data1,
2006
2025
  include_non_paid_channels=include_non_paid_channels,
@@ -2027,7 +2046,9 @@ class Analyzer:
2027
2046
  incremental_outcome_kwargs = {
2028
2047
  "inverse_transform_outcome": inverse_transform_outcome,
2029
2048
  "use_kpi": use_kpi,
2030
- "non_media_baseline_values": non_media_baseline_values,
2049
+ "non_media_treatments_baseline_normalized": (
2050
+ non_media_treatments_baseline_normalized
2051
+ ),
2031
2052
  }
2032
2053
  for i, start_index in enumerate(batch_starting_indices):
2033
2054
  stop_index = np.min([n_draws, start_index + batch_size])
@@ -2503,7 +2524,7 @@ class Analyzer:
2503
2524
  aggregate_geos: bool = False,
2504
2525
  aggregate_times: bool = False,
2505
2526
  split_by_holdout_id: bool = False,
2506
- non_media_baseline_values: Sequence[str | float] | None = None,
2527
+ non_media_baseline_values: Sequence[float] | None = None,
2507
2528
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
2508
2529
  ) -> xr.Dataset:
2509
2530
  """Calculates the data for the expected versus actual outcome over time.
@@ -2515,19 +2536,20 @@ class Analyzer:
2515
2536
  are summed over all of the time periods.
2516
2537
  split_by_holdout_id: Boolean. If `True` and `holdout_id` exists, the data
2517
2538
  is split into `'Train'`, `'Test'`, and `'All Data'` subsections.
2518
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
2519
- Each element is either a float (which means that the fixed value will be
2520
- used as baseline for the given channel) or one of the strings "min" or
2521
- "max" (which mean that the global minimum or maximum value will be used
2522
- as baseline for the values of the given non_media treatment channel). If
2523
- None, the minimum value is used as baseline for each non_media treatment
2524
- channel.
2539
+ non_media_baseline_values: Optional list of shape
2540
+ `(n_non_media_channels,)`. Each element is a float which means that the
2541
+ fixed value will be used as baseline for the given channel. It is
2542
+ expected that they are scaled by population for the channels where
2543
+ `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
2544
+ `model_spec.non_media_baseline_values` is used, which defaults to the
2545
+ minimum value for each non_media treatment channel.
2525
2546
  confidence_level: Confidence level for expected outcome credible
2526
2547
  intervals, represented as a value between zero and one. Default: `0.9`.
2527
2548
 
2528
2549
  Returns:
2529
2550
  A dataset with the expected, baseline, and actual outcome metrics.
2530
2551
  """
2552
+ _validate_non_media_baseline_values_numbers(non_media_baseline_values)
2531
2553
  mmm = self._meridian
2532
2554
  use_kpi = self._meridian.input_data.revenue_per_kpi is None
2533
2555
  can_split_by_holdout = self._can_split_by_holdout_id(split_by_holdout_id)
@@ -2597,7 +2619,7 @@ class Analyzer:
2597
2619
 
2598
2620
  def _calculate_baseline_expected_outcome(
2599
2621
  self,
2600
- non_media_baseline_values: Sequence[str | float] | None = None,
2622
+ non_media_baseline_values: Sequence[float] | None = None,
2601
2623
  **expected_outcome_kwargs,
2602
2624
  ) -> tf.Tensor:
2603
2625
  """Calculates either the posterior or prior expected outcome of baseline.
@@ -2609,20 +2631,19 @@ class Analyzer:
2609
2631
  3) `new_organic_media` is set to all zeros
2610
2632
  4) `new_organic_reach` is set to all zeros
2611
2633
  5) `new_non_media_treatments` is set to the counterfactual values
2612
- according to the
2613
- `non_media_baseline_values` argument
2634
+ according to the `non_media_baseline_values` argument
2614
2635
  6) `new_controls` are set to historical values
2615
2636
 
2616
2637
  All other arguments of `expected_outcome` can be passed to this method.
2617
2638
 
2618
2639
  Args:
2619
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
2620
- Each element is either a float (which means that the fixed value will be
2621
- used as baseline for the given channel) or one of the strings "min" or
2622
- "max" (which mean that the global minimum or maximum value will be used
2623
- as baseline for the values of the given non_media treatment channel). If
2624
- None, the minimum value is used as baseline for each non_media treatment
2625
- channel.
2640
+ non_media_baseline_values: Optional list of shape
2641
+ `(n_non_media_channels,)`. Each element is a float which means that the
2642
+ fixed value will be used as baseline for the given channel. It is
2643
+ expected that they are scaled by population for the channels where
2644
+ `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
2645
+ `model_spec.non_media_baseline_values` is used, which defaults to the
2646
+ minimum value for each non_media treatment channel.
2626
2647
  **expected_outcome_kwargs: kwargs to pass to `expected_outcome`, which
2627
2648
  could contain use_posterior, selected_geos, selected_times,
2628
2649
  aggregate_geos, aggregate_times, inverse_transform_outcome, use_kpi,
@@ -2655,10 +2676,27 @@ class Analyzer:
2655
2676
  else None
2656
2677
  )
2657
2678
  if self._meridian.non_media_treatments is not None:
2658
- new_non_media_treatments = _compute_non_media_baseline(
2659
- non_media_treatments=self._meridian.non_media_treatments,
2679
+ if self._meridian.model_spec.non_media_population_scaling_id is not None:
2680
+ scaling_factors = tf.where(
2681
+ self._meridian.model_spec.non_media_population_scaling_id,
2682
+ self._meridian.population[:, tf.newaxis, tf.newaxis],
2683
+ tf.ones_like(self._meridian.population)[:, tf.newaxis, tf.newaxis],
2684
+ )
2685
+ else:
2686
+ scaling_factors = tf.ones_like(self._meridian.population)[
2687
+ :, tf.newaxis, tf.newaxis
2688
+ ]
2689
+
2690
+ baseline = self._meridian.compute_non_media_treatments_baseline(
2660
2691
  non_media_baseline_values=non_media_baseline_values,
2661
2692
  )
2693
+ new_non_media_treatments_population_scaled = tf.broadcast_to(
2694
+ tf.constant(baseline, dtype=tf.float32)[tf.newaxis, tf.newaxis, :],
2695
+ self._meridian.non_media_treatments.shape,
2696
+ )
2697
+ new_non_media_treatments = (
2698
+ new_non_media_treatments_population_scaled * scaling_factors
2699
+ )
2662
2700
  else:
2663
2701
  new_non_media_treatments = None
2664
2702
  new_controls = self._meridian.controls
@@ -2679,7 +2717,7 @@ class Analyzer:
2679
2717
  new_data: DataTensors | None = None,
2680
2718
  use_kpi: bool | None = None,
2681
2719
  include_non_paid_channels: bool = True,
2682
- non_media_baseline_values: Sequence[str | float] | None = None,
2720
+ non_media_baseline_values: Sequence[float] | None = None,
2683
2721
  **kwargs,
2684
2722
  ) -> tf.Tensor:
2685
2723
  """Aggregates the incremental outcome of the media channels.
@@ -2707,13 +2745,13 @@ class Analyzer:
2707
2745
  include_non_paid_channels: Boolean. If `True`, then non-media treatments
2708
2746
  and organic effects are included in the calculation. If `False`, then
2709
2747
  only the paid media and RF effects are included.
2710
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
2711
- Each element is either a float (which means that the fixed value will be
2712
- used as baseline for the given channel) or one of the strings "min" or
2713
- "max" (which mean that the global minimum or maximum value will be used
2714
- as baseline for the scaled values of the given non_media treatments
2715
- channel). If not provided, the minimum value is used as the baseline for
2716
- each non_media treatments channel.
2748
+ non_media_baseline_values: Optional list of shape
2749
+ `(n_non_media_channels,)`. Each element is a float which means that the
2750
+ fixed value will be used as baseline for the given channel. It is
2751
+ expected that they are scaled by population for the channels where
2752
+ `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
2753
+ `model_spec.non_media_baseline_values` is used, which defaults to the
2754
+ minimum value for each non_media treatment channel.
2717
2755
  **kwargs: kwargs to pass to `incremental_outcome`, which could contain
2718
2756
  selected_geos, selected_times, aggregate_geos, aggregate_times,
2719
2757
  batch_size.
@@ -2723,6 +2761,7 @@ class Analyzer:
2723
2761
  of the channel dimension is incremented by one, with the new component at
2724
2762
  the end containing the total incremental outcome of all channels.
2725
2763
  """
2764
+ _validate_non_media_baseline_values_numbers(non_media_baseline_values)
2726
2765
  use_kpi = use_kpi or self._meridian.input_data.revenue_per_kpi is None
2727
2766
  incremental_outcome_m = self.incremental_outcome(
2728
2767
  use_posterior=use_posterior,
@@ -2755,7 +2794,7 @@ class Analyzer:
2755
2794
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
2756
2795
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
2757
2796
  include_non_paid_channels: bool = False,
2758
- non_media_baseline_values: Sequence[str | float] | None = None,
2797
+ non_media_baseline_values: Sequence[float] | None = None,
2759
2798
  ) -> xr.Dataset:
2760
2799
  """Returns summary metrics.
2761
2800
 
@@ -2831,13 +2870,13 @@ class Analyzer:
2831
2870
  reported. If `False`, only the paid channels (media, reach and
2832
2871
  frequency) are included but the summary contains also the metrics
2833
2872
  dependent on spend. Default: `False`.
2834
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
2835
- Each element is either a float (which means that the fixed value will be
2836
- used as baseline for the given channel) or one of the strings "min" or
2837
- "max" (which mean that the global minimum or maximum value will be used
2838
- as baseline for the values of the given non_media treatment channel). If
2839
- None, the minimum value is used as baseline for each non_media treatment
2840
- channel.
2873
+ non_media_baseline_values: Optional list of shape
2874
+ `(n_non_media_channels,)`. Each element is a float which means that the
2875
+ fixed value will be used as baseline for the given channel. It is
2876
+ expected that they are scaled by population for the channels where
2877
+ `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
2878
+ `model_spec.non_media_baseline_values` is used, which defaults to the
2879
+ minimum value for each non_media treatment channel.
2841
2880
 
2842
2881
  Returns:
2843
2882
  An `xr.Dataset` with coordinates: `channel`, `metric` (`mean`, `median`,
@@ -2851,6 +2890,7 @@ class Analyzer:
2851
2890
  when `aggregate_times=False` because they do not have a clear
2852
2891
  interpretation by time period.
2853
2892
  """
2893
+ _validate_non_media_baseline_values_numbers(non_media_baseline_values)
2854
2894
  dim_kwargs = {
2855
2895
  "selected_geos": selected_geos,
2856
2896
  "selected_times": selected_times,
@@ -3239,7 +3279,7 @@ class Analyzer:
3239
3279
  selected_times: Sequence[str] | None = None,
3240
3280
  aggregate_geos: bool = True,
3241
3281
  aggregate_times: bool = True,
3242
- non_media_baseline_values: Sequence[float | str] | None = None,
3282
+ non_media_baseline_values: Sequence[float] | None = None,
3243
3283
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
3244
3284
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
3245
3285
  ) -> xr.Dataset:
@@ -3254,13 +3294,13 @@ class Analyzer:
3254
3294
  all of the regions.
3255
3295
  aggregate_times: Boolean. If `True`, the expected outcome is summed over
3256
3296
  all of the time periods.
3257
- non_media_baseline_values: Optional list of shape (n_non_media_channels,).
3258
- Each element is either a float (which means that the fixed value will be
3259
- used as baseline for the given channel) or one of the strings "min" or
3260
- "max" (which mean that the global minimum or maximum value will be used
3261
- as baseline for the values of the given non_media treatment channel). If
3262
- None, the minimum value is used as baseline for each non_media treatment
3263
- channel.
3297
+ non_media_baseline_values: Optional list of shape
3298
+ `(n_non_media_channels,)`. Each element is a float which means that the
3299
+ fixed value will be used as baseline for the given channel. It is
3300
+ expected that they are scaled by population for the channels where
3301
+ `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
3302
+ `model_spec.non_media_baseline_values` is used, which defaults to the
3303
+ minimum value for each non_media treatment channel.
3264
3304
  confidence_level: Confidence level for media summary metrics credible
3265
3305
  intervals, represented as a value between zero and one.
3266
3306
  batch_size: Integer representing the maximum draws per chain in each
@@ -3273,6 +3313,7 @@ class Analyzer:
3273
3313
  `ci_low`,`ci_high`),`distribution` (prior, posterior) and contains the
3274
3314
  following data variables: `baseline_outcome`, `pct_of_contribution`.
3275
3315
  """
3316
+ _validate_non_media_baseline_values_numbers(non_media_baseline_values)
3276
3317
  # TODO: Change "pct_of_contribution" to a more accurate term.
3277
3318
 
3278
3319
  use_kpi = self._meridian.input_data.revenue_per_kpi is None
@@ -4663,11 +4704,11 @@ class Analyzer:
4663
4704
 
4664
4705
  def get_historical_spend(
4665
4706
  self,
4666
- selected_times: Sequence[str] | None,
4707
+ selected_times: Sequence[str] | None = None,
4667
4708
  include_media: bool = True,
4668
4709
  include_rf: bool = True,
4669
4710
  ) -> xr.DataArray:
4670
- """Gets the aggregated historical spend based on the time period.
4711
+ """Deprecated. Gets the aggregated historical spend based on the time.
4671
4712
 
4672
4713
  Args:
4673
4714
  selected_times: The time period to get the historical spends. If None, the
@@ -4681,6 +4722,51 @@ class Analyzer:
4681
4722
  An `xr.DataArray` with the coordinate `channel` and contains the data
4682
4723
  variable `spend`.
4683
4724
 
4725
+ Raises:
4726
+ ValueError: A ValueError is raised when `include_media` and `include_rf`
4727
+ are both False.
4728
+ """
4729
+ warnings.warn(
4730
+ "`get_historical_spend` is deprecated. Please use "
4731
+ "`get_aggregated_spend` with `new_data=None` instead.",
4732
+ DeprecationWarning,
4733
+ stacklevel=2,
4734
+ )
4735
+ return self.get_aggregated_spend(
4736
+ selected_times=selected_times,
4737
+ include_media=include_media,
4738
+ include_rf=include_rf,
4739
+ )
4740
+
4741
+ def get_aggregated_spend(
4742
+ self,
4743
+ new_data: DataTensors | None = None,
4744
+ selected_times: Sequence[str] | Sequence[bool] | None = None,
4745
+ include_media: bool = True,
4746
+ include_rf: bool = True,
4747
+ ) -> xr.DataArray:
4748
+ """Gets the aggregated spend based on the selected time.
4749
+
4750
+ Args:
4751
+ new_data: An optional `DataTensors` object containing the new `media`,
4752
+ `media_spend`, `reach`, `frequency`, `rf_spend` tensors. If `None`, the
4753
+ existing tensors from the Meridian object are used. If `new_data`
4754
+ argument is used, then the aggregated spend is computed using the values
4755
+ of the tensors passed in the `new_data` argument and the original values
4756
+ of all the remaining tensors. If any of the tensors in `new_data` is
4757
+ provided with a different number of time periods than in `InputData`,
4758
+ then all tensors must be provided with the same number of time periods.
4759
+ selected_times: The time period to get the aggregated spends. If None, the
4760
+ spend will be aggregated over all time periods.
4761
+ include_media: Whether to include spends for paid media channels that do
4762
+ not have R&F data.
4763
+ include_rf: Whether to include spends for paid media channels with R&F
4764
+ data.
4765
+
4766
+ Returns:
4767
+ An `xr.DataArray` with the coordinate `channel` and contains the data
4768
+ variable `spend`.
4769
+
4684
4770
  Raises:
4685
4771
  ValueError: A ValueError is raised when `include_media` and `include_rf`
4686
4772
  are both False.
@@ -4689,6 +4775,11 @@ class Analyzer:
4689
4775
  raise ValueError(
4690
4776
  "At least one of include_media or include_rf must be True."
4691
4777
  )
4778
+ new_data = new_data or DataTensors()
4779
+ required_tensors_names = constants.PAID_CHANNELS + constants.SPEND_DATA
4780
+ filled_data = new_data.validate_and_fill_missing_data(
4781
+ required_tensors_names, self._meridian
4782
+ )
4692
4783
 
4693
4784
  empty_da = xr.DataArray(
4694
4785
  dims=[constants.CHANNEL], coords={constants.CHANNEL: []}
@@ -4709,8 +4800,8 @@ class Analyzer:
4709
4800
  else:
4710
4801
  aggregated_media_spend = self._impute_and_aggregate_spend(
4711
4802
  selected_times,
4712
- self._meridian.media_tensors.media,
4713
- self._meridian.media_tensors.media_spend,
4803
+ filled_data.media,
4804
+ filled_data.media_spend,
4714
4805
  list(self._meridian.input_data.media_channel.values),
4715
4806
  )
4716
4807
 
@@ -4723,18 +4814,16 @@ class Analyzer:
4723
4814
  or self._meridian.rf_tensors.rf_spend is None
4724
4815
  ):
4725
4816
  warnings.warn(
4726
- "Requested spends for paid media channels with R&F data, but but the"
4817
+ "Requested spends for paid media channels with R&F data, but the"
4727
4818
  " channels are not available.",
4728
4819
  )
4729
4820
  aggregated_rf_spend = empty_da
4730
4821
  else:
4731
- rf_execution_values = (
4732
- self._meridian.rf_tensors.reach * self._meridian.rf_tensors.frequency
4733
- )
4822
+ rf_execution_values = filled_data.reach * filled_data.frequency
4734
4823
  aggregated_rf_spend = self._impute_and_aggregate_spend(
4735
4824
  selected_times,
4736
4825
  rf_execution_values,
4737
- self._meridian.rf_tensors.rf_spend,
4826
+ filled_data.rf_spend,
4738
4827
  list(self._meridian.input_data.rf_channel.values),
4739
4828
  )
4740
4829
 
@@ -4744,7 +4833,7 @@ class Analyzer:
4744
4833
 
4745
4834
  def _impute_and_aggregate_spend(
4746
4835
  self,
4747
- selected_times: Sequence[str] | None,
4836
+ selected_times: Sequence[str] | Sequence[bool] | None,
4748
4837
  media_execution_values: tf.Tensor,
4749
4838
  channel_spend: tf.Tensor,
4750
4839
  channel_names: Sequence[str],
@@ -4759,7 +4848,7 @@ class Analyzer:
4759
4848
  argument, its values only affect the output when imputation is required.
4760
4849
 
4761
4850
  Args:
4762
- selected_times: The time period to get the historical spend.
4851
+ selected_times: The time period to get the aggregated spend.
4763
4852
  media_execution_values: The media execution values over all time points.
4764
4853
  channel_spend: The spend over all time points. Its shape can be `(n_geos,
4765
4854
  n_times, n_media_channels)` or `(n_media_channels,)` if the data is
@@ -4775,17 +4864,24 @@ class Analyzer:
4775
4864
  "selected_times": selected_times,
4776
4865
  "aggregate_geos": True,
4777
4866
  "aggregate_times": True,
4867
+ "flexible_time_dim": True,
4778
4868
  }
4779
4869
 
4780
4870
  if channel_spend.ndim == 3:
4781
4871
  aggregated_spend = self.filter_and_aggregate_geos_and_times(
4782
4872
  channel_spend,
4873
+ has_media_dim=True,
4783
4874
  **dim_kwargs,
4784
4875
  ).numpy()
4785
4876
  # channel_spend.ndim can only be 3 or 1.
4786
4877
  else:
4787
4878
  # media spend can have more time points than the model time points
4788
- media_exe_values = media_execution_values[:, -self._meridian.n_times :, :]
4879
+ if media_execution_values.shape[1] == self._meridian.n_media_times:
4880
+ media_exe_values = media_execution_values[
4881
+ :, -self._meridian.n_times :, :
4882
+ ]
4883
+ else:
4884
+ media_exe_values = media_execution_values
4789
4885
  # Calculates CPM over all times and geos if the spend does not have time
4790
4886
  # and geo dimensions.
4791
4887
  target_media_exe_values = self.filter_and_aggregate_geos_and_times(