google-meridian 1.0.5__py3-none-any.whl → 1.0.6__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
@@ -239,7 +239,7 @@ class ModelDiagnostics:
239
239
  groupby = posterior_df.columns.tolist()
240
240
  groupby.remove(parameter)
241
241
  plot = (
242
- alt.Chart(prior_posterior_df)
242
+ alt.Chart(prior_posterior_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
243
243
  .transform_density(
244
244
  parameter, groupby=groupby, as_=[parameter, 'density']
245
245
  )
@@ -332,7 +332,7 @@ class ModelDiagnostics:
332
332
  rhat = rhat.dropna(subset=[c.RHAT])
333
333
 
334
334
  boxplot = (
335
- alt.Chart(rhat)
335
+ alt.Chart(rhat, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
336
336
  .mark_boxplot(median={'color': c.BLUE_300}, outliers={'filled': True})
337
337
  .encode(
338
338
  x=alt.X(c.PARAMETER, axis=alt.Axis(labelAngle=-45)),
@@ -461,7 +461,7 @@ class ModelFit:
461
461
  else:
462
462
  y_axis_label = summary_text.KPI_LABEL
463
463
  plot = (
464
- alt.Chart(model_fit_df)
464
+ alt.Chart(model_fit_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
465
465
  .mark_line()
466
466
  .encode(
467
467
  x=alt.X(
@@ -762,7 +762,7 @@ class ReachAndFrequency:
762
762
  range=[c.BLUE_600, c.RED_600],
763
763
  )
764
764
 
765
- base = alt.Chart().transform_calculate(
765
+ base = alt.Chart(width=c.VEGALITE_FACET_DEFAULT_WIDTH).transform_calculate(
766
766
  optimal_freq=f"'{summary_text.OPTIMAL_FREQ_LABEL}'",
767
767
  expected_roi=f"'{summary_text.EXPECTED_ROI_LABEL}'",
768
768
  )
@@ -1012,7 +1012,7 @@ class MediaEffects:
1012
1012
  else:
1013
1013
  y_axis_label = summary_text.INC_KPI_LABEL
1014
1014
  base = (
1015
- alt.Chart(response_curves_df)
1015
+ alt.Chart(response_curves_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
1016
1016
  .transform_calculate(
1017
1017
  spend_level=(
1018
1018
  'datum.spend_multiplier >= 1.0 ? "Above current spend" : "Below'
@@ -1099,7 +1099,7 @@ class MediaEffects:
1099
1099
  An Altair plot showing the Adstock decay prior and posterior per media.
1100
1100
  """
1101
1101
  dataframe = self.adstock_decay_dataframe(confidence_level=confidence_level)
1102
- base = alt.Chart(dataframe)
1102
+ base = alt.Chart(dataframe, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
1103
1103
 
1104
1104
  scaled_confidence_level = int(confidence_level * 100)
1105
1105
 
@@ -1254,7 +1254,7 @@ class MediaEffects:
1254
1254
  ]
1255
1255
  range_list = [c.BLUE_700, c.GREY_600]
1256
1256
 
1257
- base = alt.Chart(df_channel_type)
1257
+ base = alt.Chart(df_channel_type, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
1258
1258
  color_scale = alt.Scale(
1259
1259
  domain=domain_list,
1260
1260
  range=range_list,
@@ -1274,7 +1274,7 @@ class MediaEffects:
1274
1274
  y2=f'{c.CI_HI}:Q',
1275
1275
  color=alt.Color(f'{c.DISTRIBUTION}:N', scale=color_scale),
1276
1276
  )
1277
- histogram = base.mark_bar(color=c.GREY_600, opacity=0.4).encode(
1277
+ histogram = base.mark_rect(color=c.GREY_600, opacity=0.4).encode(
1278
1278
  x=f'{c.START_INTERVAL_HISTOGRAM}:Q',
1279
1279
  x2=f'{c.END_INTERVAL_HISTOGRAM}:Q',
1280
1280
  y=alt.Y(f'{c.SCALED_COUNT_HISTOGRAM}:Q'),
@@ -1700,7 +1700,7 @@ class MediaSummary:
1700
1700
 
1701
1701
  domain = [c.BASELINE, c.ALL_CHANNELS]
1702
1702
  colors = [c.YELLOW_600, c.BLUE_700]
1703
- base = alt.Chart(outcome_df).encode(
1703
+ base = alt.Chart(outcome_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH).encode(
1704
1704
  alt.Theta(f'{c.PCT_OF_CONTRIBUTION}:Q', stack=True),
1705
1705
  alt.Color(
1706
1706
  f'{c.CHANNEL}:N',
@@ -1985,7 +1985,7 @@ class MediaSummary:
1985
1985
  axes_scale = alt.Scale(domain=(0, max_roi), nice=True)
1986
1986
 
1987
1987
  plot = (
1988
- alt.Chart(plot_df)
1988
+ alt.Chart(plot_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
1989
1989
  .mark_circle(tooltip=True, size=c.POINT_SIZE)
1990
1990
  .encode(
1991
1991
  x=alt.X(c.ROI, title='ROI', scale=axes_scale),
meridian/constants.py CHANGED
@@ -588,3 +588,6 @@ END_DATE = 'end_date'
588
588
  CARD_INSIGHTS = 'insights'
589
589
  CARD_CHARTS = 'charts'
590
590
  CARD_STATS = 'stats'
591
+
592
+ # VegaLite common params.
593
+ VEGALITE_FACET_DEFAULT_WIDTH = 400
@@ -18,6 +18,7 @@ The `InputData` class is used to store all the input data to the model.
18
18
  """
19
19
 
20
20
  from collections import abc
21
+ from collections.abc import Sequence
21
22
  import dataclasses
22
23
  import datetime as dt
23
24
  import functools
@@ -59,7 +60,7 @@ def _check_dim_collection(
59
60
  )
60
61
 
61
62
 
62
- def _check_dim_match(dim, arrays):
63
+ def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray]):
63
64
  """Verifies that the dimensions of the appropriate arrays match."""
64
65
  lengths = [len(array.coords[dim]) for array in arrays if array is not None]
65
66
  names = [array.name for array in arrays if array is not None]
@@ -69,6 +70,19 @@ def _check_dim_match(dim, arrays):
69
70
  )
70
71
 
71
72
 
73
+ def _check_coords_match(dim: str, arrays: Sequence[xr.DataArray]):
74
+ """Verifies that the coordinates of the appropriate arrays match."""
75
+ arrays = [arr for arr in arrays if arr is not None and dim in arr.coords]
76
+ if not arrays:
77
+ return
78
+ first_coords = arrays[0].coords[dim].values
79
+ for arr in arrays[1:]:
80
+ if not np.array_equal(arr.coords[dim].values, first_coords):
81
+ raise ValueError(
82
+ f"`{dim}` coordinates of array `{arr.name}` don't match."
83
+ )
84
+
85
+
72
86
  @dataclasses.dataclass
73
87
  class InputData:
74
88
  """A data container for advertising data in a format supported by Meridian.
@@ -242,6 +256,7 @@ class InputData:
242
256
  self._validate_media_channels()
243
257
  self._validate_time_formats()
244
258
  self._validate_times()
259
+ self._validate_geos()
245
260
 
246
261
  def _convert_geos_to_strings(self):
247
262
  """Converts geo coordinates to strings in all relevant DataArrays."""
@@ -542,11 +557,13 @@ class InputData:
542
557
  try:
543
558
  _ = self.time_coordinates.interval_days
544
559
  except ValueError as exc:
545
- raise ValueError("Time coordinates must be evenly spaced.") from exc
560
+ raise ValueError("Time coordinates must be regularly spaced.") from exc
546
561
  try:
547
562
  _ = self.media_time_coordinates.interval_days
548
563
  except ValueError as exc:
549
- raise ValueError("Media time coordinates must be evenly spaced.") from exc
564
+ raise ValueError(
565
+ "Media time coordinates must be regularly spaced."
566
+ ) from exc
550
567
 
551
568
  def _validate_time(self, array: xr.DataArray | None):
552
569
  """Validates the `time` dimension of the given `DataArray`.
@@ -617,6 +634,35 @@ class InputData:
617
634
  f" {constants.DATE_FORMAT}"
618
635
  ) from exc
619
636
 
637
+ def _check_unique_names(self, dim: str, array: xr.DataArray | None):
638
+ """Checks if a DataArray contains unique names on the specified dimension."""
639
+ if array is not None and dim in array.coords:
640
+ names = array.coords[dim].values.tolist()
641
+ if len(names) != len(set(names)):
642
+ raise ValueError(
643
+ f"`{dim}` names must be unique within the array `{array.name}`."
644
+ )
645
+
646
+ def _validate_geos(self):
647
+ """Validates geo coordinates across relevant DataArrays."""
648
+ arrays_with_geos = [
649
+ self.kpi,
650
+ self.revenue_per_kpi,
651
+ self.media,
652
+ self.controls,
653
+ self.population,
654
+ self.reach,
655
+ self.frequency,
656
+ self.organic_media,
657
+ self.organic_reach,
658
+ self.organic_frequency,
659
+ self.non_media_treatments,
660
+ ]
661
+ for array in arrays_with_geos:
662
+ self._check_unique_names(constants.GEO, array)
663
+
664
+ _check_coords_match(constants.GEO, arrays_with_geos)
665
+
620
666
  def as_dataset(self) -> xr.Dataset:
621
667
  """Returns data as a single `xarray.Dataset` object."""
622
668
  data = [
@@ -37,7 +37,7 @@ def _sample_names(prefix: str, n_names: int | None) -> list[str] | None:
37
37
  return [prefix + str(n) for n in range(n_names)] if n_names else None
38
38
 
39
39
 
40
- def _sample_geos(
40
+ def sample_geos(
41
41
  n_geos: int | None, integer_geos: bool = False
42
42
  ) -> list[str] | list[int] | None:
43
43
  """Generates a list of sample geos."""
@@ -519,6 +519,7 @@ def random_media_da(
519
519
  n_media_channels: int,
520
520
  seed: int = 0,
521
521
  date_format: str = c.DATE_FORMAT,
522
+ explicit_geo_names: Sequence[str] | None = None,
522
523
  explicit_time_index: Sequence[str] | None = None,
523
524
  explicit_media_channel_names: Sequence[str] | None = None,
524
525
  array_name: str = 'media',
@@ -535,6 +536,7 @@ def random_media_da(
535
536
  n_media_channels: Number of media channels
536
537
  seed: Random seed used by `np.random.seed()`
537
538
  date_format: The date format to use for time coordinate labels
539
+ explicit_geo_names: If given, ignore `n_geos` and use this as is.
538
540
  explicit_time_index: If given, ignore `date_format` and use this as is
539
541
  explicit_media_channel_names: If given, ignore `n_media_channels` and use
540
542
  this as is
@@ -558,6 +560,11 @@ def random_media_da(
558
560
  np.random.normal(5, 5, size=(n_geos, n_media_times, n_media_channels))
559
561
  )
560
562
  )
563
+ if explicit_geo_names is None:
564
+ geos = sample_geos(n_geos, integer_geos)
565
+ else:
566
+ geos = explicit_geo_names
567
+
561
568
  if explicit_time_index is None:
562
569
  media_time = _sample_times(
563
570
  n_times=n_media_times,
@@ -576,7 +583,7 @@ def random_media_da(
576
583
  media,
577
584
  dims=['geo', 'media_time', channel_variable_name],
578
585
  coords={
579
- 'geo': _sample_geos(n_geos, integer_geos),
586
+ 'geo': geos,
580
587
  'media_time': media_time,
581
588
  channel_variable_name: media_channels,
582
589
  },
@@ -647,7 +654,7 @@ def random_media_spend_nd_da(
647
654
  coords = {}
648
655
  if n_geos is not None:
649
656
  dims.append('geo')
650
- coords['geo'] = _sample_geos(n_geos, integer_geos)
657
+ coords['geo'] = sample_geos(n_geos, integer_geos)
651
658
  if n_times is not None:
652
659
  dims.append('time')
653
660
  coords['time'] = _sample_times(n_times=n_times)
@@ -719,7 +726,7 @@ def random_controls_da(
719
726
  controls,
720
727
  dims=['geo', 'time', 'control_variable'],
721
728
  coords={
722
- 'geo': _sample_geos(n_geos, integer_geos),
729
+ 'geo': sample_geos(n_geos, integer_geos),
723
730
  'time': (
724
731
  _sample_times(n_times=n_times, date_format=date_format)
725
732
  if explicit_time_index is None
@@ -775,7 +782,7 @@ def random_kpi_da(
775
782
  kpi,
776
783
  dims=['geo', 'time'],
777
784
  coords={
778
- 'geo': _sample_geos(n_geos, integer_geos),
785
+ 'geo': sample_geos(n_geos, integer_geos),
779
786
  'time': _sample_times(n_times=n_times),
780
787
  },
781
788
  name=c.KPI,
@@ -796,7 +803,7 @@ def constant_revenue_per_kpi(
796
803
  revenue_per_kpi,
797
804
  dims=['geo', 'time'],
798
805
  coords={
799
- 'geo': _sample_geos(n_geos, integer_geos),
806
+ 'geo': sample_geos(n_geos, integer_geos),
800
807
  'time': _sample_times(n_times=n_times),
801
808
  },
802
809
  name='revenue_per_kpi',
@@ -815,7 +822,7 @@ def random_population(
815
822
  return xr.DataArray(
816
823
  population,
817
824
  dims=['geo'],
818
- coords={'geo': _sample_geos(n_geos, integer_geos)},
825
+ coords={'geo': sample_geos(n_geos, integer_geos)},
819
826
  name='population',
820
827
  )
821
828
 
@@ -857,7 +864,7 @@ def random_reach_da(
857
864
  reach,
858
865
  dims=['geo', 'media_time', channel_variable_name],
859
866
  coords={
860
- 'geo': _sample_geos(n_geos, integer_geos),
867
+ 'geo': sample_geos(n_geos, integer_geos),
861
868
  'media_time': _sample_times(
862
869
  n_times=n_media_times, start_date=start_date
863
870
  ),
@@ -925,7 +932,7 @@ def random_frequency_da(
925
932
  frequency,
926
933
  dims=['geo', 'media_time', channel_variable_name],
927
934
  coords={
928
- 'geo': _sample_geos(n_geos, integer_geos),
935
+ 'geo': sample_geos(n_geos, integer_geos),
929
936
  'media_time': _sample_times(
930
937
  n_times=n_media_times, start_date=start_date
931
938
  ),
@@ -992,7 +999,7 @@ def random_rf_spend_nd_da(
992
999
  coords = {}
993
1000
  if n_geos is not None:
994
1001
  dims.append('geo')
995
- coords['geo'] = _sample_geos(n_geos, integer_geos)
1002
+ coords['geo'] = sample_geos(n_geos, integer_geos)
996
1003
  if n_times is not None:
997
1004
  dims.append('time')
998
1005
  coords['time'] = _sample_times(n_times=n_times)
@@ -1060,7 +1067,7 @@ def random_non_media_treatments_da(
1060
1067
  non_media_treatments,
1061
1068
  dims=['geo', 'time', 'non_media_channel'],
1062
1069
  coords={
1063
- 'geo': _sample_geos(n_geos, integer_geos),
1070
+ 'geo': sample_geos(n_geos, integer_geos),
1064
1071
  'time': (
1065
1072
  _sample_times(n_times=n_times, date_format=date_format)
1066
1073
  if explicit_time_index is None
@@ -19,7 +19,6 @@ import dataclasses
19
19
  import datetime
20
20
  import functools
21
21
  from typing import TypeAlias
22
- import warnings
23
22
 
24
23
  from meridian import constants
25
24
  import numpy as np
@@ -145,6 +144,10 @@ class TimeCoordinates:
145
144
  return cls(datetime_index=_to_pandas_datetime_index(dates))
146
145
 
147
146
  def __post_init__(self):
147
+ if len(self.datetime_index) <= 1:
148
+ raise ValueError(
149
+ "There must be more than one date index in the time coordinates."
150
+ )
148
151
  if not self.datetime_index.is_monotonic_increasing:
149
152
  raise ValueError(
150
153
  "Time coordinates must be strictly monotonically increasing."
@@ -162,28 +165,46 @@ class TimeCoordinates:
162
165
 
163
166
  @functools.cached_property
164
167
  def interval_days(self) -> int:
165
- """Returns the interval between two neighboring dates in `all_dates`.
168
+ """Returns the *mean* interval between two neighboring dates in `all_dates`.
166
169
 
167
170
  Raises:
168
- ValueError if the date index is not regularly spaced.
171
+ ValueError if the date index is not "regularly spaced".
169
172
  """
170
- # Calculate the difference between consecutive dates, in days.
171
- diff = self.datetime_index.to_series().diff().dt.days.dropna()
173
+ if not self._is_regular_time_index():
174
+ raise ValueError("Time coordinates are not regularly spaced!")
172
175
 
173
- if diff.nunique() == 0:
174
- # This edge case happens when there is only one date in the index.
175
- # This is unlikely to happen in practice, but we handle it just in case.
176
- warnings.warn(
177
- "The time coordinates only have one date. Returning an interval of 0."
178
- )
179
- return 0
176
+ # Calculate the difference between consecutive dates, in days.
177
+ diffs = self._interval_days
178
+ # Return the rounded mean interval.
179
+ return int(np.round(np.mean(diffs)))
180
180
 
181
- # Check for regularity.
182
- if diff.nunique() != 1:
183
- raise ValueError("`datetime_index` coordinates are not evenly spaced!")
181
+ @property
182
+ def _timedelta_index(self) -> pd.TimedeltaIndex:
183
+ """Returns the timedeltas between consecutive dates in `datetime_index`."""
184
+ return self.datetime_index.diff().dropna()
184
185
 
185
- # Finally, return the mode interval.
186
- return diff.mode()[0]
186
+ @property
187
+ def _interval_days(self) -> Sequence[int]:
188
+ """Converts `_timedelta_index` to a sequence of days for easier compute."""
189
+ return self._timedelta_index.days.to_numpy()
190
+
191
+ def _is_regular_time_index(self) -> bool:
192
+ """Returns True if the time index is "regularly spaced"."""
193
+ if np.all(self._interval_days == self._interval_days[0]):
194
+ # All intervals are regular. Base case.
195
+ return True
196
+ # Special cases:
197
+ # * Monthly cadences
198
+ if np.all(np.isin(self._interval_days, [28, 29, 30, 31])):
199
+ return True
200
+ # * Quarterly cadences
201
+ if np.all(np.isin(self._interval_days, [90, 91, 92])):
202
+ return True
203
+ # * Yearly cadences
204
+ if np.all(np.isin(self._interval_days, [365, 366])):
205
+ return True
206
+
207
+ return False
187
208
 
188
209
  def get_selected_dates(
189
210
  self,