google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,6 +21,7 @@ import immutabledict
  from meridian import constants as c
  from meridian.data import input_data
  from meridian.data import load
+ from meridian.model import knots
  import numpy as np
  import pandas as pd
  import xarray as xr
@@ -584,6 +585,47 @@ NATIONAL_COORD_TO_COLUMNS_WO_POPULATION_W_GEO = dataclasses.replace(
      geo='geo',
  )

+ ADSTOCK_DECAY_SPEC_CASES = immutabledict.immutabledict({
+     'media': (
+         {},
+         {
+             'ch_0': c.BINOMIAL_DECAY,
+             'ch_1': c.GEOMETRIC_DECAY,
+             'ch_2': c.GEOMETRIC_DECAY,
+         },
+     ),
+     'rf': (
+         {},
+         {
+             'rf_ch_0': c.BINOMIAL_DECAY,
+             'rf_ch_1': c.GEOMETRIC_DECAY,
+             'rf_ch_2': c.GEOMETRIC_DECAY,
+             'rf_ch_3': c.BINOMIAL_DECAY,
+         },
+     ),
+     'organic_media': (
+         {},
+         {
+             'organic_media_0': c.BINOMIAL_DECAY,
+             'organic_media_1': c.GEOMETRIC_DECAY,
+             'organic_media_2': c.GEOMETRIC_DECAY,
+             'organic_media_3': c.BINOMIAL_DECAY,
+             'organic_media_4': c.GEOMETRIC_DECAY,
+         },
+     ),
+     'organic_rf': (
+         {},
+         {
+             'organic_rf_ch_0': c.BINOMIAL_DECAY,
+             'organic_rf_ch_1': c.GEOMETRIC_DECAY,
+             'organic_rf_ch_2': c.GEOMETRIC_DECAY,
+             'organic_rf_ch_3': c.BINOMIAL_DECAY,
+             'organic_rf_ch_4': c.BINOMIAL_DECAY,
+             'organic_rf_ch_5': c.GEOMETRIC_DECAY,
+         },
+     ),
+ })
+

  def random_media_da(
      n_geos: int,
@@ -595,6 +637,7 @@ def random_media_da(
      explicit_geo_names: Sequence[str] | None = None,
      explicit_time_index: Sequence[str] | None = None,
      explicit_media_channel_names: Sequence[str] | None = None,
+     media_value_scales: list[tuple[float, float]] | None = None,
      array_name: str = 'media',
      channel_variable_name: str = 'media_channel',
      channel_prefix: str = 'ch_',
@@ -613,6 +656,8 @@ def random_media_da(
      explicit_time_index: If given, ignore `date_format` and use this as is
      explicit_media_channel_names: If given, ignore `n_media_channels` and use
        this as is
+     media_value_scales: A list of (mean, std) tuples, one for each media
+       channel, to control the scale of the generated random values.
      array_name: The name of the array to be created
      channel_variable_name: The name of the channel variable
      channel_prefix: The prefix of the channel names
@@ -628,11 +673,28 @@ def random_media_da(
    if n_times < n_media_times:
      start_date -= datetime.timedelta(weeks=(n_media_times - n_times))

-   media = np.round(
-       abs(
-           np.random.normal(5, 5, size=(n_geos, n_media_times, n_media_channels))
+   if media_value_scales:
+     if len(media_value_scales) != n_media_channels:
+       raise ValueError(
+           'Length of media_value_scales must match n_media_channels.'
        )
-   )
+     channel_data = []
+     for mean, std in media_value_scales:
+       channel_data.append(
+           np.round(
+               abs(np.random.normal(mean, std, size=(n_geos, n_media_times)))
+           )
+       )
+     media = np.stack(channel_data, axis=-1)
+   else:
+     media = np.round(
+         abs(
+             np.random.normal(
+                 5, 5, size=(n_geos, n_media_times, n_media_channels)
+             )
+         )
+     )
+
    if explicit_geo_names is None:
      geos = sample_geos(n_geos, integer_geos)
    else:
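The new `media_value_scales` argument above replaces the shared N(5, 5) draw with a per-channel (mean, std) pair. A minimal usage sketch, illustrative only: the keyword arguments shown are taken from the function body and signature in this diff, and any arguments not listed are assumed to keep their defaults.

    # Two media channels generated at very different scales; the list length
    # must match n_media_channels or the helper raises ValueError.
    media = random_media_da(
        n_geos=5,
        n_times=52,
        n_media_times=52,
        n_media_channels=2,
        media_value_scales=[(5.0, 5.0), (500.0, 50.0)],
    )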
@@ -698,6 +760,7 @@ def random_media_spend_nd_da(
      n_media_channels: int | None = None,
      seed=0,
      integer_geos: bool = False,
+     explicit_media_channel_names: Sequence[str] | None = None,
  ) -> xr.DataArray:
    """Generates a sample N-dimensional `media_spend` DataArray.

@@ -716,6 +779,8 @@ def random_media_spend_nd_da(
      n_media_channels: Number of channels in the created `media_spend` array.
      seed: Random seed used by `np.random.seed()`.
      integer_geos: If True, the geos will be integers.
+     explicit_media_channel_names: If given, ignore `n_media_channels` and use
+       this as is.

    Returns:
      A DataArray containing the generated `media_spend` data with the given
@@ -733,9 +798,12 @@ def random_media_spend_nd_da(
      coords['time'] = _sample_times(n_times=n_times)
    if n_media_channels is not None:
      dims.append('media_channel')
-     coords['media_channel'] = _sample_names(
-         prefix='ch_', n_names=n_media_channels
-     )
+     if explicit_media_channel_names is not None:
+       coords['media_channel'] = explicit_media_channel_names
+     else:
+       coords['media_channel'] = _sample_names(
+           prefix='ch_', n_names=n_media_channels
+       )

    if dims == ['geo', 'time', 'media_channel']:
      shape = (n_geos, n_times, n_media_channels)
@@ -822,6 +890,7 @@ def random_kpi_da(
      controls: xr.DataArray | None = None,
      seed: int = 0,
      integer_geos: bool = False,
+     kpi_data_pattern: str = '',
  ) -> xr.DataArray:
    """Generates a sample `kpi` DataArray."""

@@ -857,6 +926,22 @@ def random_kpi_da(

    error = np.random.normal(0, 2, size=(n_geos, n_times))
    kpi = abs(media_portion + control_portion + error)
+   if kpi_data_pattern == 'flat':
+     first_col = kpi[:, 0]  # all rows will have value same as first col
+     kpi = (
+         first_col[:, np.newaxis]
+         + np.random.normal(scale=0.02, size=kpi.shape)
+         + 0.04
+     )
+   elif kpi_data_pattern == 'seasonal':
+     for row in kpi:
+       row.sort()
+     kpi = np.sin(kpi) + 5
+   elif kpi_data_pattern == 'peak':
+     peak_index = int(len(kpi[0]) / 2)
+     kpi[:] = kpi[0, 0]
+     for row in kpi:
+       row[peak_index] *= 3

    return xr.DataArray(
        kpi,
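The `kpi_data_pattern` switch above reshapes the KPI after the media/control/error mix: 'flat' pulls every period toward the first column's value plus small noise, 'seasonal' sorts each geo's series and passes it through a sine, and 'peak' flattens the series and triples its midpoint. A hedged sketch of requesting one of the shapes; the `media` and `controls` arguments are assumed to be DataArrays built with the helpers in this file, and the full required signature is not shown in this diff.

    # Illustrative only: request a KPI series with a single mid-series spike.
    kpi_peak = random_kpi_da(
        media=media,
        controls=controls,
        kpi_data_pattern='peak',
    )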
@@ -891,14 +976,18 @@ def constant_revenue_per_kpi(


  def random_population(
-     n_geos: int, seed: int = 0, integer_geos: bool = False
+     n_geos: int,
+     seed: int = 0,
+     integer_geos: bool = False,
+     constant_value: float | None = None,
  ) -> xr.DataArray:
    """Generates a sample `population` DataArray."""

    np.random.seed(seed)
-
-   population = np.round(10 + abs(np.random.normal(3000, 100, size=n_geos)))
-
+   if constant_value is not None:
+     population = np.full(n_geos, constant_value)
+   else:
+     population = np.round(10 + abs(np.random.normal(3000, 100, size=n_geos)))
    return xr.DataArray(
        population,
        dims=['geo'],
@@ -1170,11 +1259,15 @@ def random_dataset(
      n_organic_media_channels: int | None = None,
      n_organic_rf_channels: int | None = None,
      n_media_channels: int | None = None,
+     explicit_media_channel_names: Sequence[str] | None = None,
+     media_value_scales: list[tuple[float, float]] | None = None,
      n_rf_channels: int | None = None,
      revenue_per_kpi_value: float | None = 3.14,
+     constant_population_value: float | None = None,
      seed: int = 0,
      remove_media_time: bool = False,
      integer_geos: bool = False,
+     kpi_data_pattern: str = '',
  ) -> xr.Dataset:
    """Generates a random dataset."""
    if n_media_channels:
@@ -1185,11 +1278,14 @@ def random_dataset(
          n_media_channels=n_media_channels,
          seed=seed,
          integer_geos=integer_geos,
+         explicit_media_channel_names=explicit_media_channel_names,
+         media_value_scales=media_value_scales,
      )
      media_spend = random_media_spend_nd_da(
          n_geos=n_geos,
          n_times=n_times,
          n_media_channels=n_media_channels,
+         explicit_media_channel_names=explicit_media_channel_names,
          seed=seed,
          integer_geos=integer_geos,
      )
@@ -1301,9 +1397,13 @@ def random_dataset(
        n_media_channels=n_media_channels or n_rf_channels or 0,
        n_controls=n_controls,
        integer_geos=integer_geos,
+       kpi_data_pattern=kpi_data_pattern,
    )
    population = random_population(
-       n_geos=n_geos, seed=seed, integer_geos=integer_geos
+       n_geos=n_geos,
+       seed=seed,
+       integer_geos=integer_geos,
+       constant_value=constant_population_value,
    )

    dataset = xr.combine_by_coords(
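Taken together, `random_dataset` now threads these options through to the per-array helpers: explicit channel names reach both the media and media-spend arrays, per-channel scales reach the media generator, the population can be pinned to a constant, and the KPI pattern is forwarded to `random_kpi_da`. An illustrative call, assuming the `n_*` arguments visible elsewhere in this diff and default values for everything not listed:

    # Illustrative only: named channels, per-channel media scales,
    # constant population, and a seasonal KPI pattern.
    dataset = random_dataset(
        n_geos=10,
        n_times=104,
        n_media_times=104,
        n_controls=2,
        n_media_channels=2,
        explicit_media_channel_names=['tv', 'search'],
        media_value_scales=[(5.0, 5.0), (200.0, 30.0)],
        constant_population_value=1000.0,
        kpi_data_pattern='seasonal',
    )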
@@ -1644,6 +1744,7 @@ def sample_input_data_revenue(
      n_organic_media_channels: int | None = None,
      n_organic_rf_channels: int | None = None,
      seed: int = 0,
+     explicit_media_channel_names: Sequence[str] | None = None,
  ) -> input_data.InputData:
    """Generates sample InputData for `kpi_type='revenue'`."""
    dataset = random_dataset(
@@ -1658,6 +1759,7 @@ def sample_input_data_revenue(
        n_organic_rf_channels=n_organic_rf_channels,
        revenue_per_kpi_value=1.0,
        seed=seed,
+       explicit_media_channel_names=explicit_media_channel_names,
    )
    return input_data.InputData(
        kpi=dataset.kpi,
@@ -1773,3 +1875,35 @@ def sample_input_data_non_revenue_no_revenue_per_kpi(
        if n_organic_rf_channels
        else None,
    )
+
+
+ def sample_input_data_for_aks_with_expected_knot_info() -> (
+     tuple[input_data.InputData, knots.KnotInfo]
+ ):
+   """Generates sample InputData and corresponding expected KnotInfo for testing.
+
+   Returns:
+     A tuple containing:
+       - InputData object with sample data.
+       - KnotInfo object with expected knot information.
+   """
+   data = sample_input_data_from_dataset(
+       random_dataset(
+           n_geos=20,
+           n_times=117,
+           n_media_times=117,
+           n_controls=2,
+           n_media_channels=5,
+       ),
+       'non_revenue',
+   )
+   expected_knot_info = knots.KnotInfo(
+       n_knots=13,
+       knot_locations=np.array(
+           [11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90]
+       ),
+       weights=knots.l1_distance_weights(
+           117, np.array([11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90])
+       ),
+   )
+   return data, expected_knot_info
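The helper above pairs a fixed 20-geo, 117-week dataset with the knot selection expected for it, so a test can compare an automatic-knot-selection result against a known-good `knots.KnotInfo`. A hedged sketch of consuming the fixture; the assertions use only the fields constructed above, and the call that actually runs knot selection is omitted because its API is not shown in this diff.

    # Illustrative only: unpack the fixture and inspect the expected knots.
    data, expected = sample_input_data_for_aks_with_expected_knot_info()
    assert expected.n_knots == 13
    assert list(expected.knot_locations[:3]) == [11, 14, 38]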
@@ -72,6 +72,7 @@ import json
  from typing import Any, Callable

  import arviz as az
+ from meridian import backend
  from meridian.analysis import visualizer
  import mlflow
  from mlflow.utils.autologging_utils import autologging_integration, safe_patch
@@ -81,7 +82,6 @@ from meridian.model import prior_sampler
  from meridian.model import spec
  from meridian.version import __version__
  import numpy as np
- import tensorflow_probability as tfp


  FLAVOR_NAME = "meridian"
@@ -123,7 +123,7 @@ def _log_priors(model_spec: spec.ModelSpec) -> None:
      field_value = getattr(priors, field.name)

      # Stringify Distributions and numpy arrays.
-     if isinstance(field_value, tfp.distributions.Distribution):
+     if isinstance(field_value, backend.tfd.Distribution):
        field_value = str(field_value)
      elif isinstance(field_value, np.ndarray):
        field_value = json.dumps(field_value.tolist())