google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -15,6 +15,7 @@
15
15
  """Methods to compute analysis metrics of the model and the data."""
16
16
 
17
17
  from collections.abc import Mapping, Sequence
18
+ import dataclasses
18
19
  import itertools
19
20
  import numbers
20
21
  from typing import Any, Optional
@@ -53,6 +54,7 @@ def _validate_non_media_baseline_values_numbers(
53
54
 
54
55
 
55
56
  # TODO: Refactor the related unit tests to be under DataTensors.
57
+ @dataclasses.dataclass
56
58
  class DataTensors(backend.ExtensionType):
57
59
  """Container for data variable arguments of Analyzer methods.
58
60
 
@@ -175,12 +177,31 @@ class DataTensors(backend.ExtensionType):
175
177
  else None
176
178
  )
177
179
  self.time = (
178
- backend.to_tensor(time, dtype="string") if time is not None else None
180
+ backend.to_tensor(time, dtype=backend.string)
181
+ if time is not None
182
+ else None
179
183
  )
180
-
181
- def __validate__(self):
182
184
  self._validate_n_dims()
183
185
 
186
+ def __eq__(self, other: Any) -> bool:
187
+ """Provides safe equality comparison for mixed tensor/non-tensor fields."""
188
+ if type(self) is not type(other):
189
+ return NotImplemented
190
+ for field in dataclasses.fields(self):
191
+ a = getattr(self, field.name)
192
+ b = getattr(other, field.name)
193
+ if a is None and b is None:
194
+ continue
195
+ if a is None or b is None:
196
+ return False
197
+ try:
198
+ if not bool(np.all(backend.to_tensor(backend.equal(a, b)))):
199
+ return False
200
+ except (ValueError, TypeError):
201
+ if a != b:
202
+ return False
203
+ return True
204
+
184
205
  def total_spend(self) -> backend.Tensor | None:
185
206
  """Returns the total spend tensor.
186
207
 
@@ -216,7 +237,7 @@ class DataTensors(backend.ExtensionType):
216
237
  of the corresponding tensor in the `meridian` object. If all time
217
238
  dimensions are the same, returns `None`.
218
239
  """
219
- for field in self._tf_extension_type_fields():
240
+ for field in dataclasses.fields(self):
220
241
  new_tensor = getattr(self, field.name)
221
242
  if field.name == constants.RF_IMPRESSIONS:
222
243
  old_tensor = getattr(meridian.rf_tensors, field.name)
@@ -282,7 +303,7 @@ class DataTensors(backend.ExtensionType):
282
303
 
283
304
  def _validate_n_dims(self):
284
305
  """Raises an error if the tensors have the wrong number of dimensions."""
285
- for field in self._tf_extension_type_fields():
306
+ for field in dataclasses.fields(self):
286
307
  tensor = getattr(self, field.name)
287
308
  if tensor is None:
288
309
  continue
@@ -315,7 +336,7 @@ class DataTensors(backend.ExtensionType):
315
336
  Warning: If an attribute exists in the `DataTensors` object that is not in
316
337
  the `required_variables` list, it will be ignored.
317
338
  """
318
- for field in self._tf_extension_type_fields():
339
+ for field in dataclasses.fields(self):
319
340
  tensor = getattr(self, field.name)
320
341
  if tensor is None:
321
342
  continue
@@ -468,7 +489,7 @@ class DataTensors(backend.ExtensionType):
468
489
  ) -> Self:
469
490
  """Fills default values and returns a new DataTensors object."""
470
491
  output = {}
471
- for field in self._tf_extension_type_fields():
492
+ for field in dataclasses.fields(self):
472
493
  var_name = field.name
473
494
  if var_name not in required_fields:
474
495
  continue
@@ -489,7 +510,7 @@ class DataTensors(backend.ExtensionType):
489
510
  old_tensor = meridian.revenue_per_kpi
490
511
  elif var_name == constants.TIME:
491
512
  old_tensor = backend.to_tensor(
492
- meridian.input_data.time.values.tolist(), dtype="string"
513
+ meridian.input_data.time.values.tolist(), dtype=backend.string
493
514
  )
494
515
  else:
495
516
  continue
@@ -500,6 +521,7 @@ class DataTensors(backend.ExtensionType):
500
521
  return DataTensors(**output)
501
522
 
502
523
 
524
+ @dataclasses.dataclass
503
525
  class DistributionTensors(backend.ExtensionType):
504
526
  """Container for parameters distributions arguments of Analyzer methods."""
505
527
 
@@ -583,17 +605,19 @@ def _transformed_new_or_scaled(
583
605
 
584
606
  def _calc_rsquared(expected, actual):
585
607
  """Calculates r-squared between actual and expected outcome."""
586
- return 1 - np.nanmean((expected - actual) ** 2) / np.nanvar(actual)
608
+ return 1 - backend.nanmean((expected - actual) ** 2) / backend.nanvar(actual)
587
609
 
588
610
 
589
611
  def _calc_mape(expected, actual):
590
612
  """Calculates MAPE between actual and expected outcome."""
591
- return np.nanmean(np.abs((actual - expected) / actual))
613
+ return backend.nanmean(backend.absolute((actual - expected) / actual))
592
614
 
593
615
 
594
616
  def _calc_weighted_mape(expected, actual):
595
617
  """Calculates wMAPE between actual and expected outcome (weighted by actual)."""
596
- return np.nansum(np.abs(actual - expected)) / np.nansum(actual)
618
+ return backend.nansum(backend.absolute(actual - expected)) / backend.nansum(
619
+ actual
620
+ )
597
621
 
598
622
 
599
623
  def _warn_if_geo_arg_in_kwargs(**kwargs):
@@ -893,42 +917,37 @@ class Analyzer:
893
917
  )
894
918
  return result
895
919
 
896
- def _check_revenue_data_exists(self, use_kpi: bool = False):
897
- """Checks if the revenue data is available for the analysis.
920
+ def _use_kpi(self, use_kpi: bool = False) -> bool:
921
+ """Checks if KPI analysis should be used.
898
922
 
899
- In the `kpi_type=NON_REVENUE` case, `revenue_per_kpi` is required to perform
900
- the revenue analysis. If `revenue_per_kpi` is not defined, then the revenue
901
- data is not available and the revenue analysis (`use_kpi=False`) is not
902
- possible. Only the KPI analysis (`use_kpi=True`) is possible in this case.
923
+ If `use_kpi` is `True` but `kpi_type=REVENUE`, then `use_kpi` is ignored.
903
924
 
904
- In the `kpi_type=REVENUE` case, KPI is equal to revenue and setting
905
- `use_kpi=True` has no effect. Therefore, a warning is issued if the default
906
- `False` value of `use_kpi` is overridden by the user.
925
+ If `use_kpi` is `False`, then `revenue_per_kpi` is required to perform
926
+ the revenue analysis. Setting `use_kpi` to `False` in this case is ignored.
907
927
 
908
928
  Args:
909
- use_kpi: A boolean flag indicating whether to use KPI instead of revenue.
929
+ use_kpi: A boolean flag indicating whether KPI analysis should be used.
910
930
 
931
+ Returns:
932
+ A boolean flag indicating whether KPI analysis should be used.
911
933
  Raises:
912
- ValueError: If `use_kpi` is `False` and `revenue_per_kpi` is not defined.
913
- UserWarning: If `use_kpi` is `True` in the `kpi_type=REVENUE` case.
934
+ UserWarning: If the KPI type is revenue and use_kpi is True or if
935
+ `use_kpi=False` but `revenue_per_kpi` is not available.
914
936
  """
915
- if self._meridian.input_data.kpi_type == constants.NON_REVENUE:
916
- if not use_kpi and self._meridian.revenue_per_kpi is None:
917
- raise ValueError(
918
- "Revenue analysis is not available when `revenue_per_kpi` is"
919
- " unknown. Set `use_kpi=True` to perform KPI analysis instead."
920
- )
937
+ if use_kpi and self._meridian.input_data.kpi_type == constants.REVENUE:
938
+ warnings.warn(
939
+ "Setting `use_kpi=True` has no effect when `kpi_type=REVENUE`"
940
+ " since in this case, KPI is equal to revenue."
941
+ )
942
+ return False
921
943
 
922
- if self._meridian.input_data.kpi_type == constants.REVENUE:
923
- # In the `kpi_type=REVENUE` case, KPI is equal to revenue and
924
- # `revenue_per_kpi` is set to a tensor of 1s in the initialization of the
925
- # `InputData` object.
926
- assert self._meridian.revenue_per_kpi is not None
927
- if use_kpi:
928
- warnings.warn(
929
- "Setting `use_kpi=True` has no effect when `kpi_type=REVENUE`"
930
- " since in this case, KPI is equal to revenue."
931
- )
944
+ if not use_kpi and self._meridian.input_data.revenue_per_kpi is None:
945
+ warnings.warn(
946
+ "Revenue analysis is not available when `revenue_per_kpi` is"
947
+ " unknown. Defaulting to KPI analysis."
948
+ )
949
+
950
+ return use_kpi or self._meridian.input_data.revenue_per_kpi is None
932
951
 
933
952
  def _get_adstock_dataframe(
934
953
  self,
@@ -1404,8 +1423,14 @@ class Analyzer:
1404
1423
  "`selected_geos` must match the geo dimension names from "
1405
1424
  "meridian.InputData."
1406
1425
  )
1407
- geo_mask = [x in selected_geos for x in mmm.input_data.geo]
1408
- tensor = backend.boolean_mask(tensor, geo_mask, axis=geo_dim)
1426
+ geo_indices = [
1427
+ i for i, x in enumerate(mmm.input_data.geo) if x in selected_geos
1428
+ ]
1429
+ tensor = backend.gather(
1430
+ tensor,
1431
+ backend.to_tensor(geo_indices, dtype=backend.int32),
1432
+ axis=geo_dim,
1433
+ )
1409
1434
 
1410
1435
  if selected_times is not None:
1411
1436
  _validate_selected_times(
@@ -1416,10 +1441,21 @@ class Analyzer:
1416
1441
  comparison_arg_name="`tensor`",
1417
1442
  )
1418
1443
  if _is_str_list(selected_times):
1419
- time_mask = [x in selected_times for x in mmm.input_data.time]
1420
- tensor = backend.boolean_mask(tensor, time_mask, axis=time_dim)
1444
+ time_indices = [
1445
+ i for i, x in enumerate(mmm.input_data.time) if x in selected_times
1446
+ ]
1447
+ tensor = backend.gather(
1448
+ tensor,
1449
+ backend.to_tensor(time_indices, dtype=backend.int32),
1450
+ axis=time_dim,
1451
+ )
1421
1452
  elif _is_bool_list(selected_times):
1422
- tensor = backend.boolean_mask(tensor, selected_times, axis=time_dim)
1453
+ time_indices = [i for i, x in enumerate(selected_times) if x]
1454
+ tensor = backend.gather(
1455
+ tensor,
1456
+ backend.to_tensor(time_indices, dtype=backend.int32),
1457
+ axis=time_dim,
1458
+ )
1423
1459
 
1424
1460
  tensor_dims = "...gt" + "m" * has_media_dim
1425
1461
  output_dims = (
@@ -1475,19 +1511,19 @@ class Analyzer:
1475
1511
  calculated.
1476
1512
  new_data: An optional `DataTensors` container with optional new tensors:
1477
1513
  `media`, `reach`, `frequency`, `organic_media`, `organic_reach`,
1478
- `organic_frequency`, `non_media_treatments`, `controls`. If `None`,
1479
- expected outcome is calculated conditional on the original values of the
1480
- data tensors that the Meridian object was initialized with. If
1481
- `new_data` argument is used, expected outcome is calculated conditional
1482
- on the values of the tensors passed in `new_data` and on the original
1483
- values of the remaining unset tensors. For example,
1514
+ `organic_frequency`, `non_media_treatments`, `revenue_per_kpi`,
1515
+ `controls`. If `None`, expected outcome is calculated conditional on the
1516
+ original values of the data tensors that the Meridian object was
1517
+ initialized with. If `new_data` argument is used, expected outcome is
1518
+ calculated conditional on the values of the tensors passed in `new_data`
1519
+ and on the original values of the remaining unset tensors. For example,
1484
1520
  `expected_outcome(new_data=DataTensors(reach=new_reach,
1485
1521
  frequency=new_frequency))` calculates expected outcome conditional on
1486
1522
  the original `media`, `organic_media`, `organic_reach`,
1487
- `organic_frequency`, `non_media_treatments` and `controls` tensors and
1488
- on the new given values for `reach` and `frequency` tensors. The new
1489
- tensors' dimensions must match the dimensions of the corresponding
1490
- original tensors from `input_data`.
1523
+ `organic_frequency`, `non_media_treatments`, `revenue_per_kpi`, and
1524
+ `controls` tensors and on the new given values for `reach` and
1525
+ `frequency` tensors. The new tensors' dimensions must match the
1526
+ dimensions of the corresponding original tensors from `input_data`.
1491
1527
  selected_geos: Optional list of containing a subset of geos to include. By
1492
1528
  default, all geos are included.
1493
1529
  selected_times: Optional list of containing a subset of dates to include.
@@ -1521,8 +1557,7 @@ class Analyzer:
1521
1557
  or `sample_prior()` (for `use_posterior=False`) has not been called
1522
1558
  prior to calling this method.
1523
1559
  """
1524
-
1525
- self._check_revenue_data_exists(use_kpi)
1560
+ use_kpi = self._use_kpi(use_kpi)
1526
1561
  self._check_kpi_transformation(inverse_transform_outcome, use_kpi)
1527
1562
  if self._meridian.is_national:
1528
1563
  _warn_if_geo_arg_in_kwargs(
@@ -1538,7 +1573,9 @@ class Analyzer:
1538
1573
  if new_data is None:
1539
1574
  new_data = DataTensors()
1540
1575
 
1541
- required_fields = constants.NON_REVENUE_DATA
1576
+ required_fields = (
1577
+ constants.PAID_DATA + constants.NON_PAID_DATA + (constants.CONTROLS,)
1578
+ )
1542
1579
  filled_tensors = new_data.validate_and_fill_missing_data(
1543
1580
  required_tensors_names=required_fields,
1544
1581
  meridian=self._meridian,
@@ -1592,7 +1629,7 @@ class Analyzer:
1592
1629
  if inverse_transform_outcome:
1593
1630
  outcome_means = self._meridian.kpi_transformer.inverse(outcome_means)
1594
1631
  if not use_kpi:
1595
- outcome_means *= self._meridian.revenue_per_kpi
1632
+ outcome_means *= filled_tensors.revenue_per_kpi
1596
1633
 
1597
1634
  return self.filter_and_aggregate_geos_and_times(
1598
1635
  outcome_means,
@@ -1721,7 +1758,7 @@ class Analyzer:
1721
1758
  Returns:
1722
1759
  Tensor of incremental outcome returned in terms of revenue or KPI.
1723
1760
  """
1724
- self._check_revenue_data_exists(use_kpi)
1761
+ use_kpi = self._use_kpi(use_kpi)
1725
1762
  if revenue_per_kpi is None:
1726
1763
  revenue_per_kpi = self._meridian.revenue_per_kpi
1727
1764
  t1 = self._meridian.kpi_transformer.inverse(
@@ -1734,7 +1771,17 @@ class Analyzer:
1734
1771
  return kpi
1735
1772
  return backend.einsum("gt,...gtm->...gtm", revenue_per_kpi, kpi)
1736
1773
 
1737
- @backend.function(jit_compile=True)
1774
+ @backend.function(
1775
+ jit_compile=True,
1776
+ static_argnames=[
1777
+ "inverse_transform_outcome",
1778
+ "use_kpi",
1779
+ "selected_geos",
1780
+ "selected_times",
1781
+ "aggregate_geos",
1782
+ "aggregate_times",
1783
+ ],
1784
+ )
1738
1785
  def _incremental_outcome_impl(
1739
1786
  self,
1740
1787
  data_tensors: DataTensors,
@@ -1804,7 +1851,7 @@ class Analyzer:
1804
1851
  Returns:
1805
1852
  Tensor containing the incremental outcome distribution.
1806
1853
  """
1807
- self._check_revenue_data_exists(use_kpi)
1854
+ use_kpi = self._use_kpi(use_kpi)
1808
1855
  if (
1809
1856
  data_tensors.non_media_treatments is not None
1810
1857
  and non_media_treatments_baseline_normalized is None
@@ -2005,7 +2052,7 @@ class Analyzer:
2005
2052
  with matching time dimensions.
2006
2053
  """
2007
2054
  mmm = self._meridian
2008
- self._check_revenue_data_exists(use_kpi)
2055
+ use_kpi = self._use_kpi(use_kpi)
2009
2056
  self._check_kpi_transformation(inverse_transform_outcome, use_kpi)
2010
2057
  if self._meridian.is_national:
2011
2058
  _warn_if_geo_arg_in_kwargs(
@@ -2146,8 +2193,12 @@ class Analyzer:
2146
2193
  )
2147
2194
  incremental_outcome_temps = [None] * len(batch_starting_indices)
2148
2195
  dim_kwargs = {
2149
- "selected_geos": selected_geos,
2150
- "selected_times": selected_times,
2196
+ "selected_geos": (
2197
+ tuple(selected_geos) if selected_geos is not None else None
2198
+ ),
2199
+ "selected_times": (
2200
+ tuple(selected_times) if selected_times is not None else None
2201
+ ),
2151
2202
  "aggregate_geos": aggregate_geos,
2152
2203
  "aggregate_times": aggregate_times,
2153
2204
  }
@@ -2322,7 +2373,7 @@ class Analyzer:
2322
2373
  "selected_times": selected_times,
2323
2374
  "aggregate_geos": aggregate_geos,
2324
2375
  }
2325
- self._check_revenue_data_exists(use_kpi)
2376
+ use_kpi = self._use_kpi(use_kpi)
2326
2377
  self._validate_geo_and_time_granularity(**dim_kwargs)
2327
2378
  required_values = constants.PERFORMANCE_DATA
2328
2379
  if not new_data:
@@ -2431,6 +2482,7 @@ class Analyzer:
2431
2482
  (n_media_channels + n_rf_channels))`. The `n_geos` dimension is dropped if
2432
2483
  `aggregate_geos=True`.
2433
2484
  """
2485
+ use_kpi = self._use_kpi(use_kpi)
2434
2486
  dim_kwargs = {
2435
2487
  "selected_geos": selected_geos,
2436
2488
  "selected_times": selected_times,
@@ -2444,7 +2496,6 @@ class Analyzer:
2444
2496
  "include_non_paid_channels": False,
2445
2497
  "aggregate_times": True,
2446
2498
  }
2447
- self._check_revenue_data_exists(use_kpi)
2448
2499
  self._validate_geo_and_time_granularity(**dim_kwargs)
2449
2500
  required_values = constants.PERFORMANCE_DATA
2450
2501
  if not new_data:
@@ -2632,6 +2683,7 @@ class Analyzer:
2632
2683
  self,
2633
2684
  aggregate_geos: bool = False,
2634
2685
  aggregate_times: bool = False,
2686
+ use_kpi: bool = False,
2635
2687
  split_by_holdout_id: bool = False,
2636
2688
  non_media_baseline_values: Sequence[float] | None = None,
2637
2689
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
@@ -2643,6 +2695,8 @@ class Analyzer:
2643
2695
  summed over all of the regions.
2644
2696
  aggregate_times: Boolean. If `True`, the expected, baseline, and actual
2645
2697
  are summed over all of the time periods.
2698
+ use_kpi: If `True`, calculate the incremental KPI. Otherwise, calculate
2699
+ the incremental revenue using the revenue per KPI (if available).
2646
2700
  split_by_holdout_id: Boolean. If `True` and `holdout_id` exists, the data
2647
2701
  is split into `'Train'`, `'Test'`, and `'All Data'` subsections.
2648
2702
  non_media_baseline_values: Optional list of shape
@@ -2659,8 +2713,8 @@ class Analyzer:
2659
2713
  A dataset with the expected, baseline, and actual outcome metrics.
2660
2714
  """
2661
2715
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
2716
+ use_kpi = self._use_kpi(use_kpi)
2662
2717
  mmm = self._meridian
2663
- use_kpi = self._meridian.input_data.revenue_per_kpi is None
2664
2718
  can_split_by_holdout = self._can_split_by_holdout_id(split_by_holdout_id)
2665
2719
  expected_outcome = self.expected_outcome(
2666
2720
  aggregate_geos=False, aggregate_times=False, use_kpi=use_kpi
@@ -2828,7 +2882,7 @@ class Analyzer:
2828
2882
  self,
2829
2883
  use_posterior: bool,
2830
2884
  new_data: DataTensors | None = None,
2831
- use_kpi: bool | None = None,
2885
+ use_kpi: bool = False,
2832
2886
  include_non_paid_channels: bool = True,
2833
2887
  non_media_baseline_values: Sequence[float] | None = None,
2834
2888
  **kwargs,
@@ -2875,7 +2929,7 @@ class Analyzer:
2875
2929
  the end containing the total incremental outcome of all channels.
2876
2930
  """
2877
2931
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
2878
- use_kpi = use_kpi or self._meridian.input_data.revenue_per_kpi is None
2932
+ use_kpi = self._use_kpi(use_kpi)
2879
2933
  incremental_outcome_m = self.incremental_outcome(
2880
2934
  use_posterior=use_posterior,
2881
2935
  new_data=new_data,
@@ -3004,6 +3058,7 @@ class Analyzer:
3004
3058
  interpretation by time period.
3005
3059
  """
3006
3060
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
3061
+ use_kpi = self._use_kpi(use_kpi)
3007
3062
  dim_kwargs = {
3008
3063
  "selected_geos": selected_geos,
3009
3064
  "selected_times": selected_times,
@@ -3146,16 +3201,19 @@ class Analyzer:
3146
3201
  ).where(lambda ds: ds.channel != constants.ALL_CHANNELS)
3147
3202
 
3148
3203
  if new_data.get_modified_times(self._meridian) is None:
3204
+ expected_outcome_fields = list(
3205
+ constants.PAID_DATA + constants.NON_PAID_DATA + (constants.CONTROLS,)
3206
+ )
3149
3207
  expected_outcome_prior = self.expected_outcome(
3150
3208
  use_posterior=False,
3151
- new_data=new_data.filter_fields(constants.NON_REVENUE_DATA),
3209
+ new_data=new_data.filter_fields(expected_outcome_fields),
3152
3210
  use_kpi=use_kpi,
3153
3211
  **dim_kwargs,
3154
3212
  **batched_kwargs,
3155
3213
  )
3156
3214
  expected_outcome_posterior = self.expected_outcome(
3157
3215
  use_posterior=True,
3158
- new_data=new_data.filter_fields(constants.NON_REVENUE_DATA),
3216
+ new_data=new_data.filter_fields(expected_outcome_fields),
3159
3217
  use_kpi=use_kpi,
3160
3218
  **dim_kwargs,
3161
3219
  **batched_kwargs,
@@ -3399,6 +3457,7 @@ class Analyzer:
3399
3457
  aggregate_geos: bool = True,
3400
3458
  aggregate_times: bool = True,
3401
3459
  non_media_baseline_values: Sequence[float] | None = None,
3460
+ use_kpi: bool = False,
3402
3461
  confidence_level: float = constants.DEFAULT_CONFIDENCE_LEVEL,
3403
3462
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
3404
3463
  ) -> xr.Dataset:
@@ -3420,6 +3479,8 @@ class Analyzer:
3420
3479
  `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
3421
3480
  `model_spec.non_media_baseline_values` is used, which defaults to the
3422
3481
  minimum value for each non_media treatment channel.
3482
+ use_kpi: Boolean. If `True`, the baseline summary metrics are calculated
3483
+ using KPI. If `False`, the metrics are calculated using revenue.
3423
3484
  confidence_level: Confidence level for media summary metrics credible
3424
3485
  intervals, represented as a value between zero and one.
3425
3486
  batch_size: Integer representing the maximum draws per chain in each
@@ -3435,7 +3496,7 @@ class Analyzer:
3435
3496
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
3436
3497
  # TODO: Change "pct_of_contribution" to a more accurate term.
3437
3498
 
3438
- use_kpi = self._meridian.input_data.revenue_per_kpi is None
3499
+ use_kpi = self._use_kpi(use_kpi)
3439
3500
  dim_kwargs = {
3440
3501
  "selected_geos": selected_geos,
3441
3502
  "selected_times": selected_times,
@@ -3618,6 +3679,7 @@ class Analyzer:
3618
3679
  ValueError: If there are no channels with reach and frequency data.
3619
3680
  """
3620
3681
  dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR
3682
+ use_kpi = self._use_kpi(use_kpi)
3621
3683
  new_data = new_data or DataTensors()
3622
3684
  if self._meridian.n_rf_channels == 0:
3623
3685
  raise ValueError(
@@ -3696,9 +3758,11 @@ class Analyzer:
3696
3758
  )
3697
3759
 
3698
3760
  optimal_frequency = [freq_grid[i] for i in optimal_freq_idx]
3699
- optimal_frequency_tensor = backend.to_tensor(
3700
- backend.ones_like(filled_data.rf_impressions) * optimal_frequency,
3701
- backend.float32,
3761
+ optimal_frequency_values = backend.to_tensor(
3762
+ optimal_frequency, dtype=backend.float32
3763
+ )
3764
+ optimal_frequency_tensor = (
3765
+ backend.ones_like(filled_data.rf_impressions) * optimal_frequency_values
3702
3766
  )
3703
3767
  optimal_reach = filled_data.rf_impressions / optimal_frequency_tensor
3704
3768
 
@@ -3783,10 +3847,7 @@ class Analyzer:
3783
3847
  attrs={
3784
3848
  constants.CONFIDENCE_LEVEL: confidence_level,
3785
3849
  constants.USE_POSTERIOR: use_posterior,
3786
- constants.IS_REVENUE_KPI: (
3787
- self._meridian.input_data.kpi_type == constants.REVENUE
3788
- or not use_kpi
3789
- ),
3850
+ constants.IS_REVENUE_KPI: not use_kpi,
3790
3851
  },
3791
3852
  )
3792
3853
 
@@ -3794,6 +3855,7 @@ class Analyzer:
3794
3855
  self,
3795
3856
  selected_geos: Sequence[str] | None = None,
3796
3857
  selected_times: Sequence[str] | None = None,
3858
+ use_kpi: bool = False,
3797
3859
  batch_size: int = constants.DEFAULT_BATCH_SIZE,
3798
3860
  ) -> xr.Dataset:
3799
3861
  """Calculates `R-Squared`, `MAPE`, and `wMAPE` goodness of fit metrics.
@@ -3824,6 +3886,8 @@ class Analyzer:
3824
3886
  default, all geos are included.
3825
3887
  selected_times: Optional list containing a subset of dates to include. By
3826
3888
  default, all time periods are included.
3889
+ use_kpi: Whether to use KPI or revenue scale for the predictive accuracy
3890
+ metrics.
3827
3891
  batch_size: Integer representing the maximum draws per chain in each
3828
3892
  batch. By default, `batch_size` is `100`. The calculation is run in
3829
3893
  batches to avoid memory exhaustion. If a memory error occurs, try
@@ -3837,7 +3901,7 @@ class Analyzer:
3837
3901
  is split into `'Train'`, `'Test'`, and `'All Data'` subsections, and the
3838
3902
  three metrics are computed for each.
3839
3903
  """
3840
- use_kpi = self._meridian.input_data.revenue_per_kpi is None
3904
+ use_kpi = self._use_kpi(use_kpi)
3841
3905
  if self._meridian.is_national:
3842
3906
  _warn_if_geo_arg_in_kwargs(
3843
3907
  selected_geos=selected_geos,
@@ -3858,10 +3922,10 @@ class Analyzer:
3858
3922
  ],
3859
3923
  constants.GEO_GRANULARITY: [constants.GEO, constants.NATIONAL],
3860
3924
  }
3861
- if self._meridian.revenue_per_kpi is not None:
3862
- input_tensor = self._meridian.kpi * self._meridian.revenue_per_kpi
3863
- else:
3925
+ if use_kpi:
3864
3926
  input_tensor = self._meridian.kpi
3927
+ else:
3928
+ input_tensor = self._meridian.kpi * self._meridian.revenue_per_kpi
3865
3929
  actual = np.asarray(
3866
3930
  self.filter_and_aggregate_geos_and_times(
3867
3931
  tensor=input_tensor,
@@ -3990,10 +4054,11 @@ class Analyzer:
3990
4054
  "sample_posterior() must be called prior to calling this method."
3991
4055
  )
3992
4056
 
3993
- def _transpose_first_two_dims(x: backend.Tensor) -> backend.Tensor:
3994
- n_dim = len(x.shape)
4057
+ def _transpose_first_two_dims(x: Any) -> backend.Tensor:
4058
+ x_tensor = backend.to_tensor(x)
4059
+ n_dim = len(x_tensor.shape)
3995
4060
  perm = [1, 0] + list(range(2, n_dim))
3996
- return backend.transpose(x, perm)
4061
+ return backend.transpose(x_tensor, perm)
3997
4062
 
3998
4063
  rhat = backend.mcmc.potential_scale_reduction({
3999
4064
  k: _transpose_first_two_dims(v)
@@ -4026,8 +4091,6 @@ class Analyzer:
4026
4091
  Returns:
4027
4092
  A DataFrame with the following columns:
4028
4093
 
4029
- * `n_params`: The number of respective parameters in the model.
4030
- * `avg_rhat`: The average R-hat value for the respective parameter.
4031
4094
  * `n_params`: The number of respective parameters in the model.
4032
4095
  * `avg_rhat`: The average R-hat value for the respective parameter.
4033
4096
  * `max_rhat`: The maximum R-hat value for the respective parameter.
@@ -4944,11 +5007,12 @@ class Analyzer:
4944
5007
  def get_aggregated_spend(
4945
5008
  self,
4946
5009
  new_data: DataTensors | None = None,
5010
+ selected_geos: Sequence[str] | None = None,
4947
5011
  selected_times: Sequence[str] | Sequence[bool] | None = None,
4948
5012
  include_media: bool = True,
4949
5013
  include_rf: bool = True,
4950
5014
  ) -> xr.DataArray:
4951
- """Gets the aggregated spend based on the selected time.
5015
+ """Gets the aggregated spend based on the selected geos and time.
4952
5016
 
4953
5017
  Args:
4954
5018
  new_data: An optional `DataTensors` object containing the new `media`,
@@ -4959,6 +5023,9 @@ class Analyzer:
4959
5023
  of all the remaining tensors. If any of the tensors in `new_data` is
4960
5024
  provided with a different number of time periods than in `InputData`,
4961
5025
  then all tensors must be provided with the same number of time periods.
5026
+ selected_geos: Optional list containing a subset of geos to include. By
5027
+ default, all geos are included. The selected geos should match those in
5028
+ `InputData.geo`.
4962
5029
  selected_times: Optional list containing either a subset of dates to
4963
5030
  include or booleans with length equal to the number of time periods in
4964
5031
  KPI data. By default, all time periods are included.
@@ -5003,10 +5070,11 @@ class Analyzer:
5003
5070
  aggregated_media_spend = empty_da
5004
5071
  else:
5005
5072
  aggregated_media_spend = self._impute_and_aggregate_spend(
5006
- selected_times,
5007
- filled_data.media,
5008
- filled_data.media_spend,
5009
- list(self._meridian.input_data.media_channel.values),
5073
+ selected_geos=selected_geos,
5074
+ selected_times=selected_times,
5075
+ media_execution_values=filled_data.media,
5076
+ channel_spend=filled_data.media_spend,
5077
+ channel_names=list(self._meridian.input_data.media_channel.values),
5010
5078
  )
5011
5079
 
5012
5080
  if not include_rf:
@@ -5025,10 +5093,11 @@ class Analyzer:
5025
5093
  else:
5026
5094
  rf_execution_values = filled_data.reach * filled_data.frequency
5027
5095
  aggregated_rf_spend = self._impute_and_aggregate_spend(
5028
- selected_times,
5029
- rf_execution_values,
5030
- filled_data.rf_spend,
5031
- list(self._meridian.input_data.rf_channel.values),
5096
+ selected_geos=selected_geos,
5097
+ selected_times=selected_times,
5098
+ media_execution_values=rf_execution_values,
5099
+ channel_spend=filled_data.rf_spend,
5100
+ channel_names=list(self._meridian.input_data.rf_channel.values),
5032
5101
  )
5033
5102
 
5034
5103
  return xr.concat(
@@ -5037,21 +5106,26 @@ class Analyzer:
5037
5106
 
5038
5107
  def _impute_and_aggregate_spend(
5039
5108
  self,
5109
+ selected_geos: Sequence[str] | None,
5040
5110
  selected_times: Sequence[str] | Sequence[bool] | None,
5041
5111
  media_execution_values: backend.Tensor,
5042
5112
  channel_spend: backend.Tensor,
5043
5113
  channel_names: Sequence[str],
5044
5114
  ) -> xr.DataArray:
5045
- """Imputes and aggregates the spend over the selected time period.
5115
+ """Imputes and aggregates the spend within selected dimensions.
5046
5116
 
5047
- This function is used to aggregate the spend over the selected time period.
5048
- Imputation is required when `channel_spend` has only one dimension and the
5049
- aggregation is applied to only a subset of times, as specified by
5050
- `selected_times`. The `media_execution_values` argument only serves the
5051
- purpose of imputation. Although `media_execution_values` is a required
5052
- argument, its values only affect the output when imputation is required.
5117
+ This function is used to aggregate the spend within selected geos over the
5118
+ selected time period. Imputation is required when `channel_spend` has only
5119
+ one dimension and the aggregation is applied to only a subset of geos or
5120
+ times, as specified by `selected_geos` and `selected_times`. The
5121
+ `media_execution_values` argument only serves the purpose of imputation.
5122
+ Although `media_execution_values` is a required argument, its values only
5123
+ affect the output when imputation is required.
5053
5124
 
5054
5125
  Args:
5126
+ selected_geos: Optional list containing a subset of geos to include. By
5127
+ default, all geos are included. The selected geos should match those in
5128
+ `InputData.geo`.
5055
5129
  selected_times: Optional list containing either a subset of dates to
5056
5130
  include or booleans with length equal to the number of time periods in
5057
5131
  KPI data. By default, all time periods are included.
@@ -5066,7 +5140,7 @@ class Analyzer:
5066
5140
  variable `spend`.
5067
5141
  """
5068
5142
  dim_kwargs = {
5069
- "selected_geos": None,
5143
+ "selected_geos": selected_geos,
5070
5144
  "selected_times": selected_times,
5071
5145
  "aggregate_geos": True,
5072
5146
  "aggregate_times": True,