google-meridian 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1233,10 +1233,10 @@ class MediaEffects:
1233
1233
  bands for the Hill curves.
1234
1234
  """
1235
1235
  if c.MEDIA in list(df_channel_type[c.CHANNEL_TYPE]):
1236
- x_axis_title = 'Media Units per Capita'
1236
+ x_axis_title = summary_text.HILL_X_AXIS_MEDIA_LABEL
1237
1237
  shaded_area_title = summary_text.HILL_SHADED_REGION_MEDIA_LABEL
1238
1238
  else:
1239
- x_axis_title = 'Average Frequency'
1239
+ x_axis_title = summary_text.HILL_X_AXIS_RF_LABEL
1240
1240
  shaded_area_title = summary_text.HILL_SHADED_REGION_RF_LABEL
1241
1241
  domain_list = [
1242
1242
  c.POSTERIOR,
@@ -1380,6 +1380,7 @@ class MediaSummary:
1380
1380
  confidence_level: float = c.DEFAULT_CONFIDENCE_LEVEL,
1381
1381
  selected_times: Sequence[str] | None = None,
1382
1382
  marginal_roi_by_reach: bool = True,
1383
+ non_media_baseline_values: Sequence[str | float] | None = None,
1383
1384
  ):
1384
1385
  """Initializes the media summary metrics based on the model data and params.
1385
1386
 
@@ -1394,12 +1395,20 @@ class MediaSummary:
1394
1395
  next dollar spent only impacts reach, holding frequency constant. If
1395
1396
  this argument is False, we assume the next dollar spent only impacts
1396
1397
  frequency, holding reach constant.
1398
+ non_media_baseline_values: Optional list of shape (n_non_media_channels,).
1399
+ Each element is either a float (which means that the fixed value will be
1400
+ used as baseline for the given channel) or one of the strings "min" or
1401
+ "max" (which mean that the global minimum or maximum value will be used
1402
+ as baseline for the values of the given non_media treatment channel). If
1403
+ None, the minimum value is used as baseline for each non_media treatment
1404
+ channel.
1397
1405
  """
1398
1406
  self._meridian = meridian
1399
1407
  self._analyzer = analyzer.Analyzer(meridian)
1400
1408
  self._confidence_level = confidence_level
1401
1409
  self._selected_times = selected_times
1402
1410
  self._marginal_roi_by_reach = marginal_roi_by_reach
1411
+ self._non_media_baseline_values = non_media_baseline_values
1403
1412
 
1404
1413
  @functools.cached_property
1405
1414
  def paid_summary_metrics(self) -> xr.Dataset:
@@ -1438,6 +1447,7 @@ class MediaSummary:
1438
1447
  use_kpi=self._meridian.input_data.revenue_per_kpi is None,
1439
1448
  confidence_level=self._confidence_level,
1440
1449
  include_non_paid_channels=True,
1450
+ non_media_baseline_values=self._non_media_baseline_values,
1441
1451
  )
1442
1452
 
1443
1453
  def summary_table(
@@ -1560,6 +1570,7 @@ class MediaSummary:
1560
1570
  confidence_level: float | None = None,
1561
1571
  selected_times: Sequence[str] | None = None,
1562
1572
  marginal_roi_by_reach: bool = True,
1573
+ non_media_baseline_values: Sequence[str | float] | None = None,
1563
1574
  ):
1564
1575
  """Runs the computation for the media summary metrics with new parameters.
1565
1576
 
@@ -1574,10 +1585,18 @@ class MediaSummary:
1574
1585
  dollar spent only impacts reach, holding frequency constant. If `False`,
1575
1586
  the assumption is the next dollar spent only impacts frequency, holding
1576
1587
  reach constant.
1588
+ non_media_baseline_values: Optional list of shape (n_non_media_channels,).
1589
+ Each element is either a float (which means that the fixed value will be
1590
+ used as baseline for the given channel) or one of the strings "min" or
1591
+ "max" (which mean that the global minimum or maximum value will be used
1592
+ as baseline for the values of the given non_media treatment channel). If
1593
+ None, the minimum value is used as baseline for each non_media treatment
1594
+ channel.
1577
1595
  """
1578
1596
  self._confidence_level = confidence_level or self._confidence_level
1579
1597
  self._selected_times = selected_times
1580
1598
  self._marginal_roi_by_reach = marginal_roi_by_reach
1599
+ self._non_media_baseline_values = non_media_baseline_values
1581
1600
 
1582
1601
  def plot_contribution_waterfall_chart(self) -> alt.Chart:
1583
1602
  """Plots a waterfall chart of the contribution share per channel.
meridian/data/__init__.py CHANGED
@@ -14,6 +14,7 @@
14
14
 
15
15
  """Data handling API for Meridian."""
16
16
 
17
+ from meridian.data import arg_builder
17
18
  from meridian.data import input_data
18
19
  from meridian.data import load
19
20
  from meridian.data import time_coordinates
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Data argument builders for various API surfaces in Meridian."""
16
+
17
+ from collections.abc import Sequence
18
+ from typing import Generic, TypeVar
19
+
20
+
21
__all__ = [
    'OrderedListArgumentBuilder',
]


T = TypeVar('T')


class OrderedListArgumentBuilder(Generic[T]):
  """A simple builder for an argument that expects an ordered list of values.

  For example, some Meridian functions require a list/array of values that:

  - Must have the same length as some channel coordinates in the input data.
  - Must be in the same order as some channel coordinates in the input data.

  This argument builder can be bound to one such channel coordinates list, and
  provides a user-friendly and human-readable way to seed it with user given
  values.

  For example, to create an array (list) of values that must be indexed on the
  *paid media* channels in the input data:

  ```python
  paid_media_channels_arg_builder = (
      OrderedListArgumentBuilder[float](input_data.get_all_paid_channels())
  )
  # Note: rather than creating this builder directly, use methods like
  # `InputData.get_paid_channels_argument_builder()` where the container
  # determines which coordinates to bind the builder with.

  # Use `.with_default_value()` to set a default value for coordinates that are
  # not given in `__call__`.
  paid_media_channels_arg_builder = (
      paid_media_channels_arg_builder.with_default_value(0.3)
  )

  # Assuming we have paid media channels ['display', 'search', 'social'].
  some_arg = paid_media_channels_arg_builder(
      display=0.1,
      social=0.25,
  )
  # some_arg == [0.1, 0.3, 0.25]
  ```

  See: `InputData.get_paid_channels_argument_builder()`.
  """

  def __init__(self, ordered_coords: Sequence[str]):
    """Binds this builder to an ordered sequence of coordinate names.

    Args:
      ordered_coords: Coordinate (e.g. channel) names. Their order determines
        the order of the list produced by `__call__`. A copy is stored, so
        later mutation of the given sequence does not affect this builder.
    """
    self._ordered_coords = list(ordered_coords)
    # Applied when a coordinate value is not given. `None` means "no default":
    # every bound coordinate must then be given explicitly in `__call__`.
    self._default_value = None

  def with_default_value(
      self, default_value: T
  ) -> 'OrderedListArgumentBuilder[T]':
    """Sets the default value for coordinates that are not given in `__call__`.

    Args:
      default_value: The default value to use for coordinates that are not given
        in `__call__`. If unset (or set to `None`), then `__call__` will raise
        an error if any bound coordinate's value is not given.

    Returns:
      This builder itself (mutated in place) for fluent chaining.
    """
    self._default_value = default_value
    return self

  def __call__(self, **kwargs: T) -> list[T]:
    """Builds an ordered argument values list, given the bound coordinates list.

    Args:
      **kwargs: Values keyed by coordinate (e.g. channel) name. Every key must
        be one of the `ordered_coords` bound to this builder. Keys may be
        omitted only when a non-`None` default value has been set.

    Returns:
      A list of values, in the same order as the `ordered_coords` bound to this
      builder. Coordinates not given in `kwargs` take the default value.

    Raises:
      ValueError: If a keyword argument names a coordinate that is not bound to
        this builder, or if no default value is set and a bound coordinate is
        missing from `kwargs`.
    """
    bound_coords = set(self._ordered_coords)
    # Reject unknown names regardless of the default value: otherwise a typo
    # (e.g. `dispaly=0.1`) would be silently dropped and replaced by the
    # default.
    unknown_coords = [name for name in kwargs if name not in bound_coords]
    if unknown_coords:
      raise ValueError(
          'Unknown coordinates in the given keyword arguments: '
          f'{unknown_coords}. Expected only: {self._ordered_coords}'
      )
    # With no default, every bound coordinate must be given explicitly.
    if self._default_value is None and set(kwargs) != bound_coords:
      raise ValueError(
          'All coordinates must be present in the given keyword arguments: '
          f'Given: {kwargs.keys()} vs Expected: {self._ordered_coords}'
      )
    return [kwargs.get(c, self._default_value) for c in self._ordered_coords]
@@ -24,6 +24,7 @@ import functools
24
24
  import warnings
25
25
 
26
26
  from meridian import constants
27
+ from meridian.data import arg_builder
27
28
  from meridian.data import time_coordinates as tc
28
29
  import numpy as np
29
30
  import xarray as xr
@@ -678,6 +679,28 @@ class InputData:
678
679
  raise ValueError("Both RF and media channel values are missing.")
679
680
  # pytype: enable=attribute-error
680
681
 
682
def get_paid_channels_argument_builder(
    self,
) -> arg_builder.OrderedListArgumentBuilder:
  """Creates an argument builder bound to every *paid* channel.

  The builder is bound to the coordinates returned by
  `get_all_paid_channels()`, in that order. Any error raised by
  `get_all_paid_channels()` propagates unchanged.

  Returns:
    A new `arg_builder.OrderedListArgumentBuilder` instance on each call.
  """
  paid_channels = self.get_all_paid_channels()
  return arg_builder.OrderedListArgumentBuilder(paid_channels)
687
+
688
def get_paid_media_channels_argument_builder(
    self,
) -> arg_builder.OrderedListArgumentBuilder:
  """Creates an argument builder bound to the *paid media* channels only.

  RF channels are not included; only the values of `self.media_channel` are
  used, in their stored order.

  Returns:
    A new `arg_builder.OrderedListArgumentBuilder` instance on each call.

  Raises:
    ValueError: If the input data has no media channels.
  """
  media_channels = self.media_channel
  if media_channels is None:
    raise ValueError("There are no media channels in the input data.")
  return arg_builder.OrderedListArgumentBuilder(media_channels.values)
695
+
696
def get_paid_rf_channels_argument_builder(
    self,
) -> arg_builder.OrderedListArgumentBuilder:
  """Creates an argument builder bound to the *paid* RF channels only.

  Non-RF media channels are not included; only the values of
  `self.rf_channel` are used, in their stored order.

  Returns:
    A new `arg_builder.OrderedListArgumentBuilder` instance on each call.

  Raises:
    ValueError: If the input data has no RF channels.
  """
  rf_channels = self.rf_channel
  if rf_channels is None:
    raise ValueError("There are no RF channels in the input data.")
  return arg_builder.OrderedListArgumentBuilder(rf_channels.values)
703
+
681
704
  def get_all_channels(self) -> np.ndarray:
682
705
  """Returns all the channel dimensions.
683
706
 
@@ -1480,7 +1480,9 @@ def sample_coord_to_columns(
1480
1480
  )
1481
1481
 
1482
1482
 
1483
- def sample_input_data_from_dataset(dataset: xr.Dataset, kpi_type: str):
1483
+ def sample_input_data_from_dataset(
1484
+ dataset: xr.Dataset, kpi_type: str
1485
+ ) -> input_data.InputData:
1484
1486
  """Generates a sample `InputData` from a full xarray Dataset."""
1485
1487
  return input_data.InputData(
1486
1488
  kpi=dataset.kpi,
@@ -1507,7 +1509,7 @@ def sample_input_data_revenue(
1507
1509
  n_organic_media_channels: int | None = None,
1508
1510
  n_organic_rf_channels: int | None = None,
1509
1511
  seed: int = 0,
1510
- ):
1512
+ ) -> input_data.InputData:
1511
1513
  """Generates sample InputData for `kpi_type='revenue'`."""
1512
1514
  dataset = random_dataset(
1513
1515
  n_geos=n_geos,
@@ -1555,7 +1557,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1555
1557
  n_organic_media_channels: int | None = None,
1556
1558
  n_organic_rf_channels: int | None = None,
1557
1559
  seed: int = 0,
1558
- ):
1560
+ ) -> input_data.InputData:
1559
1561
  """Generates sample InputData for `non_revenue` KPI w/ revenue_per_kpi."""
1560
1562
  dataset = random_dataset(
1561
1563
  n_geos=n_geos,
@@ -1602,7 +1604,7 @@ def sample_input_data_non_revenue_no_revenue_per_kpi(
1602
1604
  n_organic_media_channels: int | None = None,
1603
1605
  n_organic_rf_channels: int | None = None,
1604
1606
  seed: int = 0,
1605
- ):
1607
+ ) -> input_data.InputData:
1606
1608
  """Generates sample InputData for `non_revenue` KPI w/o revenue_per_kpi."""
1607
1609
  dataset = random_dataset(
1608
1610
  n_geos=n_geos,
@@ -18,6 +18,8 @@ from meridian.model import adstock_hill
18
18
  from meridian.model import knots
19
19
  from meridian.model import media
20
20
  from meridian.model import model
21
+ from meridian.model import posterior_sampler
21
22
  from meridian.model import prior_distribution
23
+ from meridian.model import prior_sampler
22
24
  from meridian.model import spec
23
25
  from meridian.model import transformers