google-meridian 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
  2. google_meridian-1.3.0.dist-info/RECORD +62 -0
  3. meridian/analysis/__init__.py +2 -0
  4. meridian/analysis/analyzer.py +280 -142
  5. meridian/analysis/formatter.py +2 -2
  6. meridian/analysis/optimizer.py +353 -169
  7. meridian/analysis/review/__init__.py +20 -0
  8. meridian/analysis/review/checks.py +721 -0
  9. meridian/analysis/review/configs.py +110 -0
  10. meridian/analysis/review/constants.py +40 -0
  11. meridian/analysis/review/results.py +544 -0
  12. meridian/analysis/review/reviewer.py +186 -0
  13. meridian/analysis/summarizer.py +14 -12
  14. meridian/analysis/templates/chips.html.jinja +12 -0
  15. meridian/analysis/test_utils.py +27 -5
  16. meridian/analysis/visualizer.py +45 -50
  17. meridian/backend/__init__.py +698 -55
  18. meridian/backend/config.py +75 -16
  19. meridian/backend/test_utils.py +127 -1
  20. meridian/constants.py +52 -11
  21. meridian/data/input_data.py +7 -2
  22. meridian/data/test_utils.py +5 -3
  23. meridian/mlflow/autolog.py +2 -2
  24. meridian/model/__init__.py +1 -0
  25. meridian/model/adstock_hill.py +10 -9
  26. meridian/model/eda/__init__.py +3 -0
  27. meridian/model/eda/constants.py +21 -0
  28. meridian/model/eda/eda_engine.py +1580 -84
  29. meridian/model/eda/eda_outcome.py +200 -0
  30. meridian/model/eda/eda_spec.py +84 -0
  31. meridian/model/eda/meridian_eda.py +220 -0
  32. meridian/model/knots.py +56 -50
  33. meridian/model/media.py +10 -8
  34. meridian/model/model.py +79 -16
  35. meridian/model/model_test_data.py +53 -9
  36. meridian/model/posterior_sampler.py +398 -391
  37. meridian/model/prior_distribution.py +114 -39
  38. meridian/model/prior_sampler.py +146 -90
  39. meridian/model/spec.py +7 -8
  40. meridian/model/transformers.py +16 -8
  41. meridian/version.py +1 -1
  42. google_meridian-1.2.0.dist-info/RECORD +0 -52
  43. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
  44. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
  45. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@
15
15
  """Methods to compute analysis metrics of the model and the data."""
16
16
 
17
17
  from collections.abc import Mapping, Sequence
18
+ import dataclasses
18
19
  import itertools
19
20
  import numbers
20
21
  from typing import Any, Optional
@@ -53,6 +54,7 @@ def _validate_non_media_baseline_values_numbers(
53
54
 
54
55
 
55
56
  # TODO: Refactor the related unit tests to be under DataTensors.
57
+ @dataclasses.dataclass
56
58
  class DataTensors(backend.ExtensionType):
57
59
  """Container for data variable arguments of Analyzer methods.
58
60
 
@@ -175,12 +177,31 @@ class DataTensors(backend.ExtensionType):
175
177
  else None
176
178
  )
177
179
  self.time = (
178
- backend.to_tensor(time, dtype="string") if time is not None else None
180
+ backend.to_tensor(time, dtype=backend.string)
181
+ if time is not None
182
+ else None
179
183
  )
180
-
181
- def __validate__(self):
182
184
  self._validate_n_dims()
183
185
 
186
+ def __eq__(self, other: Any) -> bool:
187
+ """Provides safe equality comparison for mixed tensor/non-tensor fields."""
188
+ if type(self) is not type(other):
189
+ return NotImplemented
190
+ for field in dataclasses.fields(self):
191
+ a = getattr(self, field.name)
192
+ b = getattr(other, field.name)
193
+ if a is None and b is None:
194
+ continue
195
+ if a is None or b is None:
196
+ return False
197
+ try:
198
+ if not bool(np.all(backend.to_tensor(backend.equal(a, b)))):
199
+ return False
200
+ except (ValueError, TypeError):
201
+ if a != b:
202
+ return False
203
+ return True
204
+
184
205
  def total_spend(self) -> backend.Tensor | None:
185
206
  """Returns the total spend tensor.
186
207
 
@@ -216,7 +237,7 @@ class DataTensors(backend.ExtensionType):
216
237
  of the corresponding tensor in the `meridian` object. If all time
217
238
  dimensions are the same, returns `None`.
218
239
  """
219
- for field in self._tf_extension_type_fields():
240
+ for field in dataclasses.fields(self):
220
241
  new_tensor = getattr(self, field.name)
221
242
  if field.name == constants.RF_IMPRESSIONS:
222
243
  old_tensor = getattr(meridian.rf_tensors, field.name)
@@ -282,7 +303,7 @@ class DataTensors(backend.ExtensionType):
282
303
 
283
304
  def _validate_n_dims(self):
284
305
  """Raises an error if the tensors have the wrong number of dimensions."""
285
- for field in self._tf_extension_type_fields():
306
+ for field in dataclasses.fields(self):
286
307
  tensor = getattr(self, field.name)
287
308
  if tensor is None:
288
309
  continue
@@ -315,7 +336,7 @@ class DataTensors(backend.ExtensionType):
315
336
  Warning: If an attribute exists in the `DataTensors` object that is not in
316
337
  the `required_variables` list, it will be ignored.
317
338
  """
318
- for field in self._tf_extension_type_fields():
339
+ for field in dataclasses.fields(self):
319
340
  tensor = getattr(self, field.name)
320
341
  if tensor is None:
321
342
  continue
@@ -468,7 +489,7 @@ class DataTensors(backend.ExtensionType):
468
489
  ) -> Self:
469
490
  """Fills default values and returns a new DataTensors object."""
470
491
  output = {}
471
- for field in self._tf_extension_type_fields():
492
+ for field in dataclasses.fields(self):
472
493
  var_name = field.name
473
494
  if var_name not in required_fields:
474
495
  continue
@@ -489,7 +510,7 @@ class DataTensors(backend.ExtensionType):
489
510
  old_tensor = meridian.revenue_per_kpi
490
511
  elif var_name == constants.TIME:
491
512
  old_tensor = backend.to_tensor(
492
- meridian.input_data.time.values.tolist(), dtype="string"
513
+ meridian.input_data.time.values.tolist(), dtype=backend.string
493
514
  )
494
515
  else:
495
516
  continue
@@ -500,6 +521,7 @@ class DataTensors(backend.ExtensionType):
500
521
  return DataTensors(**output)
501
522
 
502
523
 
524
+ @dataclasses.dataclass
503
525
  class DistributionTensors(backend.ExtensionType):
504
526
  """Container for parameters distributions arguments of Analyzer methods."""
505
527
 
@@ -583,17 +605,19 @@ def _transformed_new_or_scaled(
583
605
 
584
606
  def _calc_rsquared(expected, actual):
585
607
  """Calculates r-squared between actual and expected outcome."""
586
- return 1 - np.nanmean((expected - actual) ** 2) / np.nanvar(actual)
608
+ return 1 - backend.nanmean((expected - actual) ** 2) / backend.nanvar(actual)
587
609
 
588
610
 
589
611
  def _calc_mape(expected, actual):
590
612
  """Calculates MAPE between actual and expected outcome."""
591
- return np.nanmean(np.abs((actual - expected) / actual))
613
+ return backend.nanmean(backend.absolute((actual - expected) / actual))
592
614
 
593
615
 
594
616
  def _calc_weighted_mape(expected, actual):
595
617
  """Calculates wMAPE between actual and expected outcome (weighted by actual)."""
596
- return np.nansum(np.abs(actual - expected)) / np.nansum(actual)
618
+ return backend.nansum(backend.absolute(actual - expected)) / backend.nansum(
619
+ actual
620
+ )
597
621
 
598
622
 
599
623
  def _warn_if_geo_arg_in_kwargs(**kwargs):
@@ -675,43 +699,66 @@ def _validate_flexible_selected_times(
675
699
  selected_times: Sequence[str] | Sequence[bool] | None,
676
700
  media_selected_times: Sequence[str] | Sequence[bool] | None,
677
701
  new_n_media_times: int,
702
+ new_time: Sequence[str] | None = None,
678
703
  ):
679
704
  """Raises an error if selected times or media selected times is invalid.
680
705
 
681
- This checks that the `selected_times` and `media_selected_times` arguments
682
- are lists of booleans with the same number of elements as `new_n_media_times`.
683
- This is only relevant if the time dimension of any of the variables in
684
- `new_data` used in the analysis is modified.
706
+ This checks that (1) the `selected_times` and `media_selected_times` arguments
707
+ are lists of booleans with the same number of elements as `new_n_media_times`,
708
+ or (2) the `selected_times` and `media_selected_times` arguments are lists of
709
+ strings and the `new_time` list is provided and `selected_times` and
710
+ `media_selected_times` are subsets of `new_time`. This is only relevant if the
711
+ time dimension of any of the variables in `new_data` used in the analysis is
712
+ modified.
685
713
 
686
714
  Args:
687
715
  selected_times: Optional list of times to validate.
688
716
  media_selected_times: Optional list of media times to validate.
689
717
  new_n_media_times: The number of time periods in the new data.
718
+ new_time: The optional time dimension of the new data.
690
719
  """
691
720
  if selected_times and (
692
- not _is_bool_list(selected_times)
693
- or len(selected_times) != new_n_media_times
721
+ not (
722
+ _is_bool_list(selected_times)
723
+ and len(selected_times) == new_n_media_times
724
+ )
725
+ and not (
726
+ _is_str_list(selected_times)
727
+ and new_time is not None
728
+ and set(selected_times) <= set(new_time)
729
+ )
694
730
  ):
695
731
  raise ValueError(
696
732
  "If `media`, `reach`, `frequency`, `organic_media`,"
697
733
  " `organic_reach`, `organic_frequency`, `non_media_treatments`, or"
698
734
  " `revenue_per_kpi` is provided with a different number of time"
699
- " periods than in `InputData`, then `selected_times` must be a list"
735
+ " periods than in `InputData`, then (1) `selected_times` must be a list"
700
736
  " of booleans with length equal to the number of time periods in"
701
- " the new data."
737
+ " the new data, or (2) `selected_times` must be a list of strings and"
738
+ " `new_time` must be provided and `selected_times` must be a subset of"
739
+ " `new_time`."
702
740
  )
703
741
 
704
742
  if media_selected_times and (
705
- not _is_bool_list(media_selected_times)
706
- or len(media_selected_times) != new_n_media_times
743
+ not (
744
+ _is_bool_list(media_selected_times)
745
+ and len(media_selected_times) == new_n_media_times
746
+ )
747
+ and not (
748
+ _is_str_list(media_selected_times)
749
+ and new_time is not None
750
+ and set(media_selected_times) <= set(new_time)
751
+ )
707
752
  ):
708
753
  raise ValueError(
709
754
  "If `media`, `reach`, `frequency`, `organic_media`,"
710
755
  " `organic_reach`, `organic_frequency`, `non_media_treatments`, or"
711
756
  " `revenue_per_kpi` is provided with a different number of time"
712
- " periods than in `InputData`, then `media_selected_times` must be"
757
+ " periods than in `InputData`, then (1) `media_selected_times` must be"
713
758
  " a list of booleans with length equal to the number of time"
714
- " periods in the new data."
759
+ " periods in the new data, or (2) `media_selected_times` must be a list"
760
+ " of strings and `new_time` must be provided and"
761
+ " `media_selected_times` must be a subset of `new_time`."
715
762
  )
716
763
 
717
764
 
@@ -870,42 +917,37 @@ class Analyzer:
870
917
  )
871
918
  return result
872
919
 
873
- def _check_revenue_data_exists(self, use_kpi: bool = False):
874
- """Checks if the revenue data is available for the analysis.
920
+ def _use_kpi(self, use_kpi: bool = False) -> bool:
921
+ """Checks if KPI analysis should be used.
875
922
 
876
- In the `kpi_type=NON_REVENUE` case, `revenue_per_kpi` is required to perform
877
- the revenue analysis. If `revenue_per_kpi` is not defined, then the revenue
878
- data is not available and the revenue analysis (`use_kpi=False`) is not
879
- possible. Only the KPI analysis (`use_kpi=True`) is possible in this case.
923
+ If `use_kpi` is `True` but `kpi_type=REVENUE`, then `use_kpi` is ignored.
880
924
 
881
- In the `kpi_type=REVENUE` case, KPI is equal to revenue and setting
882
- `use_kpi=True` has no effect. Therefore, a warning is issued if the default
883
- `False` value of `use_kpi` is overridden by the user.
925
+ If `use_kpi` is `False`, then `revenue_per_kpi` is required to perform
926
+ the revenue analysis. If `revenue_per_kpi` is unavailable, `use_kpi=False` is ignored and KPI analysis is used instead.
884
927
 
885
928
  Args:
886
- use_kpi: A boolean flag indicating whether to use KPI instead of revenue.
929
+ use_kpi: A boolean flag indicating whether KPI analysis should be used.
887
930
 
931
+ Returns:
932
+ A boolean flag indicating whether KPI analysis should be used.
888
933
  Raises:
889
- ValueError: If `use_kpi` is `False` and `revenue_per_kpi` is not defined.
890
- UserWarning: If `use_kpi` is `True` in the `kpi_type=REVENUE` case.
934
+ UserWarning: If the KPI type is revenue and use_kpi is True or if
935
+ `use_kpi=False` but `revenue_per_kpi` is not available.
891
936
  """
892
- if self._meridian.input_data.kpi_type == constants.NON_REVENUE:
893
- if not use_kpi and self._meridian.revenue_per_kpi is None:
894
- raise ValueError(
895
- "Revenue analysis is not available when `revenue_per_kpi` is"
896
- " unknown. Set `use_kpi=True` to perform KPI analysis instead."
897
- )
937
+ if use_kpi and self._meridian.input_data.kpi_type == constants.REVENUE:
938
+ warnings.warn(
939
+ "Setting `use_kpi=True` has no effect when `kpi_type=REVENUE`"
940
+ " since in this case, KPI is equal to revenue."
941
+ )
942
+ return False
898
943
 
899
- if self._meridian.input_data.kpi_type == constants.REVENUE:
900
- # In the `kpi_type=REVENUE` case, KPI is equal to revenue and
901
- # `revenue_per_kpi` is set to a tensor of 1s in the initialization of the
902
- # `InputData` object.
903
- assert self._meridian.revenue_per_kpi is not None
904
- if use_kpi:
905
- warnings.warn(
906
- "Setting `use_kpi=True` has no effect when `kpi_type=REVENUE`"
907
- " since in this case, KPI is equal to revenue."
908
- )
944
+ if not use_kpi and self._meridian.input_data.revenue_per_kpi is None:
945
+ warnings.warn(
946
+ "Revenue analysis is not available when `revenue_per_kpi` is"
947
+ " unknown. Defaulting to KPI analysis."
948
+ )
949
+
950
+ return use_kpi or self._meridian.input_data.revenue_per_kpi is None
909
951
 
910
952
  def _get_adstock_dataframe(
911
953
  self,
@@ -1381,8 +1423,14 @@ class Analyzer:
1381
1423
  "`selected_geos` must match the geo dimension names from "
1382
1424
  "meridian.InputData."
1383
1425
  )
1384
- geo_mask = [x in selected_geos for x in mmm.input_data.geo]
1385
- tensor = backend.boolean_mask(tensor, geo_mask, axis=geo_dim)
1426
+ geo_indices = [
1427
+ i for i, x in enumerate(mmm.input_data.geo) if x in selected_geos
1428
+ ]
1429
+ tensor = backend.gather(
1430
+ tensor,
1431
+ backend.to_tensor(geo_indices, dtype=backend.int32),
1432
+ axis=geo_dim,
1433
+ )
1386
1434
 
1387
1435
  if selected_times is not None:
1388
1436
  _validate_selected_times(
@@ -1393,10 +1441,21 @@ class Analyzer:
1393
1441
  comparison_arg_name="`tensor`",
1394
1442
  )
1395
1443
  if _is_str_list(selected_times):
1396
- time_mask = [x in selected_times for x in mmm.input_data.time]
1397
- tensor = backend.boolean_mask(tensor, time_mask, axis=time_dim)
1444
+ time_indices = [
1445
+ i for i, x in enumerate(mmm.input_data.time) if x in selected_times
1446
+ ]
1447
+ tensor = backend.gather(
1448
+ tensor,
1449
+ backend.to_tensor(time_indices, dtype=backend.int32),
1450
+ axis=time_dim,
1451
+ )
1398
1452
  elif _is_bool_list(selected_times):
1399
- tensor = backend.boolean_mask(tensor, selected_times, axis=time_dim)
1453
+ time_indices = [i for i, x in enumerate(selected_times) if x]
1454
+ tensor = backend.gather(
1455
+ tensor,
1456
+ backend.to_tensor(time_indices, dtype=backend.int32),
1457
+ axis=time_dim,
1458
+ )
1400
1459
 
1401
1460
  tensor_dims = "...gt" + "m" * has_media_dim
1402
1461
  output_dims = (
@@ -1452,19 +1511,19 @@ class Analyzer:
1452
1511
  calculated.
1453
1512
  new_data: An optional `DataTensors` container with optional new tensors:
1454
1513
  `media`, `reach`, `frequency`, `organic_media`, `organic_reach`,
1455
- `organic_frequency`, `non_media_treatments`, `controls`. If `None`,
1456
- expected outcome is calculated conditional on the original values of the
1457
- data tensors that the Meridian object was initialized with. If
1458
- `new_data` argument is used, expected outcome is calculated conditional
1459
- on the values of the tensors passed in `new_data` and on the original
1460
- values of the remaining unset tensors. For example,
1514
+ `organic_frequency`, `non_media_treatments`, `revenue_per_kpi`,
1515
+ `controls`. If `None`, expected outcome is calculated conditional on the
1516
+ original values of the data tensors that the Meridian object was
1517
+ initialized with. If `new_data` argument is used, expected outcome is
1518
+ calculated conditional on the values of the tensors passed in `new_data`
1519
+ and on the original values of the remaining unset tensors. For example,
1461
1520
  `expected_outcome(new_data=DataTensors(reach=new_reach,
1462
1521
  frequency=new_frequency))` calculates expected outcome conditional on
1463
1522
  the original `media`, `organic_media`, `organic_reach`,
1464
- `organic_frequency`, `non_media_treatments` and `controls` tensors and
1465
- on the new given values for `reach` and `frequency` tensors. The new
1466
- tensors' dimensions must match the dimensions of the corresponding
1467
- original tensors from `input_data`.
1523
+ `organic_frequency`, `non_media_treatments`, `revenue_per_kpi`, and
1524
+ `controls` tensors and on the new given values for `reach` and
1525
+ `frequency` tensors. The new tensors' dimensions must match the
1526
+ dimensions of the corresponding original tensors from `input_data`.
1468
1527
  selected_geos: Optional list containing a subset of geos to include. By
1469
1528
  default, all geos are included.
1470
1529
  selected_times: Optional list containing a subset of dates to include.
@@ -1498,8 +1557,7 @@ class Analyzer:
1498
1557
  or `sample_prior()` (for `use_posterior=False`) has not been called
1499
1558
  prior to calling this method.
1500
1559
  """
1501
-
1502
- self._check_revenue_data_exists(use_kpi)
1560
+ use_kpi = self._use_kpi(use_kpi)
1503
1561
  self._check_kpi_transformation(inverse_transform_outcome, use_kpi)
1504
1562
  if self._meridian.is_national:
1505
1563
  _warn_if_geo_arg_in_kwargs(
@@ -1515,7 +1573,9 @@ class Analyzer:
1515
1573
  if new_data is None:
1516
1574
  new_data = DataTensors()
1517
1575
 
1518
- required_fields = constants.NON_REVENUE_DATA
1576
+ required_fields = (
1577
+ constants.PAID_DATA + constants.NON_PAID_DATA + (constants.CONTROLS,)
1578
+ )
1519
1579
  filled_tensors = new_data.validate_and_fill_missing_data(
1520
1580
  required_tensors_names=required_fields,
1521
1581
  meridian=self._meridian,
@@ -1569,7 +1629,7 @@ class Analyzer:
1569
1629
  if inverse_transform_outcome:
1570
1630
  outcome_means = self._meridian.kpi_transformer.inverse(outcome_means)
1571
1631
  if not use_kpi:
1572
- outcome_means *= self._meridian.revenue_per_kpi
1632
+ outcome_means *= filled_tensors.revenue_per_kpi
1573
1633
 
1574
1634
  return self.filter_and_aggregate_geos_and_times(
1575
1635
  outcome_means,
@@ -1698,7 +1758,7 @@ class Analyzer:
1698
1758
  Returns:
1699
1759
  Tensor of incremental outcome returned in terms of revenue or KPI.
1700
1760
  """
1701
- self._check_revenue_data_exists(use_kpi)
1761
+ use_kpi = self._use_kpi(use_kpi)
1702
1762
  if revenue_per_kpi is None:
1703
1763
  revenue_per_kpi = self._meridian.revenue_per_kpi
1704
1764
  t1 = self._meridian.kpi_transformer.inverse(
@@ -1711,7 +1771,17 @@ class Analyzer:
1711
1771
  return kpi
1712
1772
  return backend.einsum("gt,...gtm->...gtm", revenue_per_kpi, kpi)
1713
1773
 
1714
- @backend.function(jit_compile=True)
1774
+ @backend.function(
1775
+ jit_compile=True,
1776
+ static_argnames=[
1777
+ "inverse_transform_outcome",
1778
+ "use_kpi",
1779
+ "selected_geos",
1780
+ "selected_times",
1781
+ "aggregate_geos",
1782
+ "aggregate_times",
1783
+ ],
1784
+ )
1715
1785
  def _incremental_outcome_impl(
1716
1786
  self,
1717
1787
  data_tensors: DataTensors,
@@ -1781,7 +1851,7 @@ class Analyzer:
1781
1851
  Returns:
1782
1852
  Tensor containing the incremental outcome distribution.
1783
1853
  """
1784
- self._check_revenue_data_exists(use_kpi)
1854
+ use_kpi = self._use_kpi(use_kpi)
1785
1855
  if (
1786
1856
  data_tensors.non_media_treatments is not None
1787
1857
  and non_media_treatments_baseline_normalized is None
@@ -1982,7 +2052,7 @@ class Analyzer:
1982
2052
  with matching time dimensions.
1983
2053
  """
1984
2054
  mmm = self._meridian
1985
- self._check_revenue_data_exists(use_kpi)
2055
+ use_kpi = self._use_kpi(use_kpi)
1986
2056
  self._check_kpi_transformation(inverse_transform_outcome, use_kpi)
1987
2057
  if self._meridian.is_national:
1988
2058
  _warn_if_geo_arg_in_kwargs(
@@ -2123,8 +2193,12 @@ class Analyzer:
2123
2193
  )
2124
2194
  incremental_outcome_temps = [None] * len(batch_starting_indices)
2125
2195
  dim_kwargs = {
2126
- "selected_geos": selected_geos,
2127
- "selected_times": selected_times,
2196
+ "selected_geos": (
2197
+ tuple(selected_geos) if selected_geos is not None else None
2198
+ ),
2199
+ "selected_times": (
2200
+ tuple(selected_times) if selected_times is not None else None
2201
+ ),
2128
2202
  "aggregate_geos": aggregate_geos,
2129
2203
  "aggregate_times": aggregate_times,
2130
2204
  }
@@ -2299,7 +2373,7 @@ class Analyzer:
2299
2373
  "selected_times": selected_times,
2300
2374
  "aggregate_geos": aggregate_geos,
2301
2375
  }
2302
- self._check_revenue_data_exists(use_kpi)
2376
+ use_kpi = self._use_kpi(use_kpi)
2303
2377
  self._validate_geo_and_time_granularity(**dim_kwargs)
2304
2378
  required_values = constants.PERFORMANCE_DATA
2305
2379
  if not new_data:
@@ -2408,6 +2482,7 @@ class Analyzer:
2408
2482
  (n_media_channels + n_rf_channels))`. The `n_geos` dimension is dropped if
2409
2483
  `aggregate_geos=True`.
2410
2484
  """
2485
+ use_kpi = self._use_kpi(use_kpi)
2411
2486
  dim_kwargs = {
2412
2487
  "selected_geos": selected_geos,
2413
2488
  "selected_times": selected_times,
@@ -2421,7 +2496,6 @@ class Analyzer:
2421
2496
  "include_non_paid_channels": False,
2422
2497
  "aggregate_times": True,
2423
2498
  }
2424
- self._check_revenue_data_exists(use_kpi)
2425
2499
  self._validate_geo_and_time_granularity(**dim_kwargs)
2426
2500
  required_values = constants.PERFORMANCE_DATA
2427
2501
  if not new_data:
@@ -2609,6 +2683,7 @@ class Analyzer:
2609
2683
  self,
2610
2684
  aggregate_geos: bool = False,
2611
2685
  aggregate_times: bool = False,
2686
+ use_kpi: bool = False,
2612
2687
  split_by_holdout_id: bool = False,
2613
2688
  non_media_baseline_values: Sequence[float] | None = None,
2614
2689
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
@@ -2620,6 +2695,8 @@ class Analyzer:
2620
2695
  summed over all of the regions.
2621
2696
  aggregate_times: Boolean. If `True`, the expected, baseline, and actual
2622
2697
  are summed over all of the time periods.
2698
+ use_kpi: If `True`, calculate the incremental KPI. Otherwise, calculate
2699
+ the incremental revenue using the revenue per KPI (if available).
2623
2700
  split_by_holdout_id: Boolean. If `True` and `holdout_id` exists, the data
2624
2701
  is split into `'Train'`, `'Test'`, and `'All Data'` subsections.
2625
2702
  non_media_baseline_values: Optional list of shape
@@ -2636,8 +2713,8 @@ class Analyzer:
2636
2713
  A dataset with the expected, baseline, and actual outcome metrics.
2637
2714
  """
2638
2715
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
2716
+ use_kpi = self._use_kpi(use_kpi)
2639
2717
  mmm = self._meridian
2640
- use_kpi = self._meridian.input_data.revenue_per_kpi is None
2641
2718
  can_split_by_holdout = self._can_split_by_holdout_id(split_by_holdout_id)
2642
2719
  expected_outcome = self.expected_outcome(
2643
2720
  aggregate_geos=False, aggregate_times=False, use_kpi=use_kpi
@@ -2805,7 +2882,7 @@ class Analyzer:
2805
2882
  self,
2806
2883
  use_posterior: bool,
2807
2884
  new_data: DataTensors | None = None,
2808
- use_kpi: bool | None = None,
2885
+ use_kpi: bool = False,
2809
2886
  include_non_paid_channels: bool = True,
2810
2887
  non_media_baseline_values: Sequence[float] | None = None,
2811
2888
  **kwargs,
@@ -2852,7 +2929,7 @@ class Analyzer:
2852
2929
  the end containing the total incremental outcome of all channels.
2853
2930
  """
2854
2931
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
2855
- use_kpi = use_kpi or self._meridian.input_data.revenue_per_kpi is None
2932
+ use_kpi = self._use_kpi(use_kpi)
2856
2933
  incremental_outcome_m = self.incremental_outcome(
2857
2934
  use_posterior=use_posterior,
2858
2935
  new_data=new_data,
@@ -2981,6 +3058,7 @@ class Analyzer:
2981
3058
  interpretation by time period.
2982
3059
  """
2983
3060
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
3061
+ use_kpi = self._use_kpi(use_kpi)
2984
3062
  dim_kwargs = {
2985
3063
  "selected_geos": selected_geos,
2986
3064
  "selected_times": selected_times,
@@ -3123,16 +3201,19 @@ class Analyzer:
3123
3201
  ).where(lambda ds: ds.channel != constants.ALL_CHANNELS)
3124
3202
 
3125
3203
  if new_data.get_modified_times(self._meridian) is None:
3204
+ expected_outcome_fields = list(
3205
+ constants.PAID_DATA + constants.NON_PAID_DATA + (constants.CONTROLS,)
3206
+ )
3126
3207
  expected_outcome_prior = self.expected_outcome(
3127
3208
  use_posterior=False,
3128
- new_data=new_data.filter_fields(constants.NON_REVENUE_DATA),
3209
+ new_data=new_data.filter_fields(expected_outcome_fields),
3129
3210
  use_kpi=use_kpi,
3130
3211
  **dim_kwargs,
3131
3212
  **batched_kwargs,
3132
3213
  )
3133
3214
  expected_outcome_posterior = self.expected_outcome(
3134
3215
  use_posterior=True,
3135
- new_data=new_data.filter_fields(constants.NON_REVENUE_DATA),
3216
+ new_data=new_data.filter_fields(expected_outcome_fields),
3136
3217
  use_kpi=use_kpi,
3137
3218
  **dim_kwargs,
3138
3219
  **batched_kwargs,
@@ -3376,6 +3457,7 @@ class Analyzer:
3376
3457
  aggregate_geos: bool = True,
3377
3458
  aggregate_times: bool = True,
3378
3459
  non_media_baseline_values: Sequence[float] | None = None,
3460
+ use_kpi: bool = False,
3379
3461
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
3380
3462
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
3381
3463
  ) -> xr.Dataset:
@@ -3397,6 +3479,8 @@ class Analyzer:
3397
3479
  `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
3398
3480
  `model_spec.non_media_baseline_values` is used, which defaults to the
3399
3481
  minimum value for each non_media treatment channel.
3482
+ use_kpi: Boolean. If `True`, the baseline summary metrics are calculated
3483
+ using KPI. If `False`, the metrics are calculated using revenue.
3400
3484
  confidence_level: Confidence level for media summary metrics credible
3401
3485
  intervals, represented as a value between zero and one.
3402
3486
  batch_size: Integer representing the maximum draws per chain in each
@@ -3412,7 +3496,7 @@ class Analyzer:
3412
3496
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
3413
3497
  # TODO: Change "pct_of_contribution" to a more accurate term.
3414
3498
 
3415
- use_kpi = self._meridian.input_data.revenue_per_kpi is None
3499
+ use_kpi = self._use_kpi(use_kpi)
3416
3500
  dim_kwargs = {
3417
3501
  "selected_geos": selected_geos,
3418
3502
  "selected_times": selected_times,
@@ -3595,6 +3679,7 @@ class Analyzer:
3595
3679
  ValueError: If there are no channels with reach and frequency data.
3596
3680
  """
3597
3681
  dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR
3682
+ use_kpi = self._use_kpi(use_kpi)
3598
3683
  new_data = new_data or DataTensors()
3599
3684
  if self._meridian.n_rf_channels == 0:
3600
3685
  raise ValueError(
@@ -3673,9 +3758,11 @@ class Analyzer:
3673
3758
  )
3674
3759
 
3675
3760
  optimal_frequency = [freq_grid[i] for i in optimal_freq_idx]
3676
- optimal_frequency_tensor = backend.to_tensor(
3677
- backend.ones_like(filled_data.rf_impressions) * optimal_frequency,
3678
- backend.float32,
3761
+ optimal_frequency_values = backend.to_tensor(
3762
+ optimal_frequency, dtype=backend.float32
3763
+ )
3764
+ optimal_frequency_tensor = (
3765
+ backend.ones_like(filled_data.rf_impressions) * optimal_frequency_values
3679
3766
  )
3680
3767
  optimal_reach = filled_data.rf_impressions / optimal_frequency_tensor
3681
3768
 
@@ -3760,10 +3847,7 @@ class Analyzer:
3760
3847
  attrs={
3761
3848
  constants.CONFIDENCE_LEVEL: confidence_level,
3762
3849
  constants.USE_POSTERIOR: use_posterior,
3763
- constants.IS_REVENUE_KPI: (
3764
- self._meridian.input_data.kpi_type == constants.REVENUE
3765
- or not use_kpi
3766
- ),
3850
+ constants.IS_REVENUE_KPI: not use_kpi,
3767
3851
  },
3768
3852
  )
3769
3853
 
@@ -3771,6 +3855,7 @@ class Analyzer:
3771
3855
  self,
3772
3856
  selected_geos: Sequence[str] | None = None,
3773
3857
  selected_times: Sequence[str] | None = None,
3858
+ use_kpi: bool = False,
3774
3859
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
3775
3860
  ) -> xr.Dataset:
3776
3861
  """Calculates `R-Squared`, `MAPE`, and `wMAPE` goodness of fit metrics.
@@ -3801,6 +3886,8 @@ class Analyzer:
3801
3886
  default, all geos are included.
3802
3887
  selected_times: Optional list containing a subset of dates to include. By
3803
3888
  default, all time periods are included.
3889
+ use_kpi: Whether to use KPI or revenue scale for the predictive accuracy
3890
+ metrics.
3804
3891
  batch_size: Integer representing the maximum draws per chain in each
3805
3892
  batch. By default, `batch_size` is `100`. The calculation is run in
3806
3893
  batches to avoid memory exhaustion. If a memory error occurs, try
@@ -3814,7 +3901,7 @@ class Analyzer:
3814
3901
  is split into `'Train'`, `'Test'`, and `'All Data'` subsections, and the
3815
3902
  three metrics are computed for each.
3816
3903
  """
3817
- use_kpi = self._meridian.input_data.revenue_per_kpi is None
3904
+ use_kpi = self._use_kpi(use_kpi)
3818
3905
  if self._meridian.is_national:
3819
3906
  _warn_if_geo_arg_in_kwargs(
3820
3907
  selected_geos=selected_geos,
@@ -3835,10 +3922,10 @@ class Analyzer:
3835
3922
  ],
3836
3923
  constants.GEO_GRANULARITY: [constants.GEO, constants.NATIONAL],
3837
3924
  }
3838
- if self._meridian.revenue_per_kpi is not None:
3839
- input_tensor = self._meridian.kpi * self._meridian.revenue_per_kpi
3840
- else:
3925
+ if use_kpi:
3841
3926
  input_tensor = self._meridian.kpi
3927
+ else:
3928
+ input_tensor = self._meridian.kpi * self._meridian.revenue_per_kpi
3842
3929
  actual = np.asarray(
3843
3930
  self.filter_and_aggregate_geos_and_times(
3844
3931
  tensor=input_tensor,
@@ -3967,10 +4054,11 @@ class Analyzer:
3967
4054
  "sample_posterior() must be called prior to calling this method."
3968
4055
  )
3969
4056
 
3970
- def _transpose_first_two_dims(x: backend.Tensor) -> backend.Tensor:
3971
- n_dim = len(x.shape)
4057
+ def _transpose_first_two_dims(x: Any) -> backend.Tensor:
4058
+ x_tensor = backend.to_tensor(x)
4059
+ n_dim = len(x_tensor.shape)
3972
4060
  perm = [1, 0] + list(range(2, n_dim))
3973
- return backend.transpose(x, perm)
4061
+ return backend.transpose(x_tensor, perm)
3974
4062
 
3975
4063
  rhat = backend.mcmc.potential_scale_reduction({
3976
4064
  k: _transpose_first_two_dims(v)
@@ -4003,8 +4091,6 @@ class Analyzer:
4003
4091
  Returns:
4004
4092
  A DataFrame with the following columns:
4005
4093
 
4006
- * `n_params`: The number of respective parameters in the model.
4007
- * `avg_rhat`: The average R-hat value for the respective parameter.
4008
4094
  * `n_params`: The number of respective parameters in the model.
4009
4095
  * `avg_rhat`: The average R-hat value for the respective parameter.
4010
4096
  * `max_rhat`: The maximum R-hat value for the respective parameter.
@@ -4056,6 +4142,7 @@ class Analyzer:
4056
4142
 
4057
4143
  def response_curves(
4058
4144
  self,
4145
+ new_data: DataTensors | None = None,
4059
4146
  spend_multipliers: list[float] | None = None,
4060
4147
  use_posterior: bool = True,
4061
4148
  selected_geos: Sequence[str] | None = None,
@@ -4081,6 +4168,15 @@ class Analyzer:
4081
4168
  `selected_times` are also scaled by the multiplier.)
4082
4169
 
4083
4170
  Args:
4171
+ new_data: Optional `DataTensors` object with optional new tensors:
4172
+ `media`, `reach`, `frequency`, `media_spend`, `rf_spend`,
4173
+ `revenue_per_kpi`, `time`. If provided, the response curves are
4174
+ calculated using the values of the tensors passed in `new_data` and the
4175
+ original values of all the remaining tensors. If `None`, the response
4176
+ curves are calculated using the original values of all the tensors. If
4177
+ any of the tensors in `new_data` is provided with a different number of
4178
+ time periods than in `InputData`, then all tensors must be provided with
4179
+ the same number of time periods and the `time` tensor must be provided.
4084
4180
  spend_multipliers: List of multipliers. Each channel's total spend is
4085
4181
  multiplied by these factors to obtain the values at which the curve is
4086
4182
  calculated for that channel.
@@ -4088,8 +4184,11 @@ class Analyzer:
4088
4184
  generated. If `False`, prior response curves are generated.
4089
4185
  selected_geos: Optional list containing a subset of geos to include. By
4090
4186
  default, all geos are included.
4091
- selected_times: Optional list containing a subset of dates to include. By
4092
- default, all time periods are included.
4187
+ selected_times: Optional list containing a subset of dates to include. If
4188
+ `new_data` is provided with modified time periods, then `selected_times`
4189
+ must be a subset of `new_data.time`. Otherwise, `selected_times` must
4190
+ be a subset of `self._meridian.input_data.time`. By default, all time
4191
+ periods are included.
4093
4192
  by_reach: Boolean. For channels with reach and frequency. If `True`, plots
4094
4193
  the response curve by reach. If `False`, plots the response curve by
4095
4194
  frequency.
@@ -4118,11 +4217,49 @@ class Analyzer:
4118
4217
  "aggregate_geos": True,
4119
4218
  "aggregate_times": True,
4120
4219
  }
4220
+ if new_data is None:
4221
+ new_data = DataTensors()
4222
+ # TODO: b/442920356 - Support flexible time without providing exact dates.
4223
+ required_tensors_names = constants.PERFORMANCE_DATA + (constants.TIME,)
4224
+ filled_data = new_data.validate_and_fill_missing_data(
4225
+ required_tensors_names=required_tensors_names,
4226
+ meridian=self._meridian,
4227
+ allow_modified_times=True,
4228
+ )
4229
+ new_n_media_times = filled_data.get_modified_times(self._meridian)
4230
+
4231
+ if new_n_media_times is None:
4232
+ _validate_selected_times(
4233
+ selected_times=selected_times,
4234
+ input_times=self._meridian.input_data.time,
4235
+ n_times=self._meridian.n_times,
4236
+ arg_name="selected_times",
4237
+ comparison_arg_name="the input data",
4238
+ )
4239
+ else:
4240
+ new_time = np.asarray(filled_data.time).astype(str).tolist()
4241
+ _validate_flexible_selected_times(
4242
+ selected_times=selected_times,
4243
+ media_selected_times=None,
4244
+ new_n_media_times=new_n_media_times,
4245
+ new_time=new_time,
4246
+ )
4247
+ # TODO: b/407847021 - Switch to Sequence[str] once it is supported.
4248
+ if selected_times is not None:
4249
+ selected_times = [x in selected_times for x in new_time]
4250
+ dim_kwargs["selected_times"] = selected_times
4251
+
4121
4252
  if self._meridian.n_rf_channels > 0 and use_optimal_frequency:
4122
- frequency = backend.ones_like(
4123
- self._meridian.rf_tensors.frequency
4124
- ) * backend.to_tensor(
4253
+ opt_freq_data = DataTensors(
4254
+ media=filled_data.media,
4255
+ rf_impressions=filled_data.reach * filled_data.frequency,
4256
+ media_spend=filled_data.media_spend,
4257
+ rf_spend=filled_data.rf_spend,
4258
+ revenue_per_kpi=filled_data.revenue_per_kpi,
4259
+ )
4260
+ frequency = backend.ones_like(filled_data.frequency) * backend.to_tensor(
4125
4261
  self.optimal_freq(
4262
+ new_data=opt_freq_data,
4126
4263
  selected_geos=selected_geos,
4127
4264
  selected_times=selected_times,
4128
4265
  use_kpi=use_kpi,
@@ -4130,12 +4267,12 @@ class Analyzer:
4130
4267
  dtype=backend.float32,
4131
4268
  )
4132
4269
  reach = backend.divide_no_nan(
4133
- self._meridian.rf_tensors.reach * self._meridian.rf_tensors.frequency,
4270
+ filled_data.reach * filled_data.frequency,
4134
4271
  frequency,
4135
4272
  )
4136
4273
  else:
4137
- frequency = self._meridian.rf_tensors.frequency
4138
- reach = self._meridian.rf_tensors.reach
4274
+ frequency = filled_data.frequency
4275
+ reach = filled_data.reach
4139
4276
  if spend_multipliers is None:
4140
4277
  spend_multipliers = list(np.arange(0, 2.2, 0.2))
4141
4278
  incremental_outcome = np.zeros((
@@ -4149,18 +4286,19 @@ class Analyzer:
4149
4286
  (len(self._meridian.input_data.get_all_paid_channels()), 3)
4150
4287
  ) # Last dimension = 3 for the mean, ci_lo and ci_hi.
4151
4288
  continue
4152
- new_data = _scale_tensors_by_multiplier(
4289
+ scaled_data = _scale_tensors_by_multiplier(
4153
4290
  data=DataTensors(
4154
- media=self._meridian.media_tensors.media,
4291
+ media=filled_data.media,
4155
4292
  reach=reach,
4156
4293
  frequency=frequency,
4294
+ revenue_per_kpi=filled_data.revenue_per_kpi,
4157
4295
  ),
4158
4296
  multiplier=multiplier,
4159
4297
  by_reach=by_reach,
4160
4298
  )
4161
4299
  inc_outcome_temp = self.incremental_outcome(
4162
4300
  use_posterior=use_posterior,
4163
- new_data=new_data.filter_fields(constants.PAID_DATA),
4301
+ new_data=scaled_data.filter_fields(constants.PAID_DATA),
4164
4302
  inverse_transform_outcome=True,
4165
4303
  batch_size=batch_size,
4166
4304
  use_kpi=use_kpi,
@@ -4171,22 +4309,11 @@ class Analyzer:
4171
4309
  inc_outcome_temp, confidence_level
4172
4310
  )
4173
4311
 
4174
- if self._meridian.n_media_channels > 0 and self._meridian.n_rf_channels > 0:
4175
- spend = backend.concatenate(
4176
- [
4177
- self._meridian.media_tensors.media_spend,
4178
- self._meridian.rf_tensors.rf_spend,
4179
- ],
4180
- axis=-1,
4181
- )
4182
- elif self._meridian.n_media_channels > 0:
4183
- spend = self._meridian.media_tensors.media_spend
4184
- else:
4185
- spend = self._meridian.rf_tensors.rf_spend
4186
-
4187
- if backend.rank(spend) == 3:
4312
+ spend = filled_data.total_spend()
4313
+ if spend is not None and spend.ndim == 3:
4188
4314
  spend = self.filter_and_aggregate_geos_and_times(
4189
4315
  tensor=spend,
4316
+ flexible_time_dim=True,
4190
4317
  **dim_kwargs,
4191
4318
  )
4192
4319
  spend_einsum = backend.einsum("k,m->km", np.array(spend_multipliers), spend)
@@ -4880,11 +5007,12 @@ class Analyzer:
4880
5007
  def get_aggregated_spend(
4881
5008
  self,
4882
5009
  new_data: DataTensors | None = None,
5010
+ selected_geos: Sequence[str] | None = None,
4883
5011
  selected_times: Sequence[str] | Sequence[bool] | None = None,
4884
5012
  include_media: bool = True,
4885
5013
  include_rf: bool = True,
4886
5014
  ) -> xr.DataArray:
4887
- """Gets the aggregated spend based on the selected time.
5015
+ """Gets the aggregated spend based on the selected geos and time.
4888
5016
 
4889
5017
  Args:
4890
5018
  new_data: An optional `DataTensors` object containing the new `media`,
@@ -4895,6 +5023,9 @@ class Analyzer:
4895
5023
  of all the remaining tensors. If any of the tensors in `new_data` is
4896
5024
  provided with a different number of time periods than in `InputData`,
4897
5025
  then all tensors must be provided with the same number of time periods.
5026
+ selected_geos: Optional list containing a subset of geos to include. By
5027
+ default, all geos are included. The selected geos should match those in
5028
+ `InputData.geo`.
4898
5029
  selected_times: Optional list containing either a subset of dates to
4899
5030
  include or booleans with length equal to the number of time periods in
4900
5031
  KPI data. By default, all time periods are included.
@@ -4939,10 +5070,11 @@ class Analyzer:
4939
5070
  aggregated_media_spend = empty_da
4940
5071
  else:
4941
5072
  aggregated_media_spend = self._impute_and_aggregate_spend(
4942
- selected_times,
4943
- filled_data.media,
4944
- filled_data.media_spend,
4945
- list(self._meridian.input_data.media_channel.values),
5073
+ selected_geos=selected_geos,
5074
+ selected_times=selected_times,
5075
+ media_execution_values=filled_data.media,
5076
+ channel_spend=filled_data.media_spend,
5077
+ channel_names=list(self._meridian.input_data.media_channel.values),
4946
5078
  )
4947
5079
 
4948
5080
  if not include_rf:
@@ -4961,10 +5093,11 @@ class Analyzer:
4961
5093
  else:
4962
5094
  rf_execution_values = filled_data.reach * filled_data.frequency
4963
5095
  aggregated_rf_spend = self._impute_and_aggregate_spend(
4964
- selected_times,
4965
- rf_execution_values,
4966
- filled_data.rf_spend,
4967
- list(self._meridian.input_data.rf_channel.values),
5096
+ selected_geos=selected_geos,
5097
+ selected_times=selected_times,
5098
+ media_execution_values=rf_execution_values,
5099
+ channel_spend=filled_data.rf_spend,
5100
+ channel_names=list(self._meridian.input_data.rf_channel.values),
4968
5101
  )
4969
5102
 
4970
5103
  return xr.concat(
@@ -4973,21 +5106,26 @@ class Analyzer:
4973
5106
 
4974
5107
  def _impute_and_aggregate_spend(
4975
5108
  self,
5109
+ selected_geos: Sequence[str] | None,
4976
5110
  selected_times: Sequence[str] | Sequence[bool] | None,
4977
5111
  media_execution_values: backend.Tensor,
4978
5112
  channel_spend: backend.Tensor,
4979
5113
  channel_names: Sequence[str],
4980
5114
  ) -> xr.DataArray:
4981
- """Imputes and aggregates the spend over the selected time period.
5115
+ """Imputes and aggregates the spend within selected dimensions.
4982
5116
 
4983
- This function is used to aggregate the spend over the selected time period.
4984
- Imputation is required when `channel_spend` has only one dimension and the
4985
- aggregation is applied to only a subset of times, as specified by
4986
- `selected_times`. The `media_execution_values` argument only serves the
4987
- purpose of imputation. Although `media_execution_values` is a required
4988
- argument, its values only affect the output when imputation is required.
5117
+ This function is used to aggregate the spend within selected geos over the
5118
+ selected time period. Imputation is required when `channel_spend` has only
5119
+ one dimension and the aggregation is applied to only a subset of geos or
5120
+ times, as specified by `selected_geos` and `selected_times`. The
5121
+ `media_execution_values` argument only serves the purpose of imputation.
5122
+ Although `media_execution_values` is a required argument, its values only
5123
+ affect the output when imputation is required.
4989
5124
 
4990
5125
  Args:
5126
+ selected_geos: Optional list containing a subset of geos to include. By
5127
+ default, all geos are included. The selected geos should match those in
5128
+ `InputData.geo`.
4991
5129
  selected_times: Optional list containing either a subset of dates to
4992
5130
  include or booleans with length equal to the number of time periods in
4993
5131
  KPI data. By default, all time periods are included.
@@ -5002,7 +5140,7 @@ class Analyzer:
5002
5140
  variable `spend`.
5003
5141
  """
5004
5142
  dim_kwargs = {
5005
- "selected_geos": None,
5143
+ "selected_geos": selected_geos,
5006
5144
  "selected_times": selected_times,
5007
5145
  "aggregate_geos": True,
5008
5146
  "aggregate_times": True,