google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/constants.py CHANGED
@@ -66,6 +66,7 @@ FREQUENCY = 'frequency'
 RF_IMPRESSIONS = 'rf_impressions'
 RF_SPEND = 'rf_spend'
 ORGANIC_MEDIA = 'organic_media'
+ORGANIC_RF = 'organic_rf'
 ORGANIC_REACH = 'organic_reach'
 ORGANIC_FREQUENCY = 'organic_frequency'
 NON_MEDIA_TREATMENTS = 'non_media_treatments'
@@ -143,6 +144,7 @@ MEDIA_CHANNEL = 'media_channel'
 RF_CHANNEL = 'rf_channel'
 CHANNEL = 'channel'
 RF = 'rf'
+ORGANIC_RF = 'organic_rf'
 ORGANIC_MEDIA_CHANNEL = 'organic_media_channel'
 ORGANIC_RF_CHANNEL = 'organic_rf_channel'
 NON_MEDIA_CHANNEL = 'non_media_channel'
@@ -212,9 +214,11 @@ NON_PAID_TREATMENT_PRIOR_TYPES = frozenset({
     TREATMENT_PRIOR_TYPE_COEFFICIENT,
     TREATMENT_PRIOR_TYPE_CONTRIBUTION,
 })
-PAID_MEDIA_ROI_PRIOR_TYPES = frozenset(
-    {TREATMENT_PRIOR_TYPE_ROI, TREATMENT_PRIOR_TYPE_MROI}
-)
+PAID_MEDIA_ROI_PRIOR_TYPES = frozenset({
+    TREATMENT_PRIOR_TYPE_ROI,
+    TREATMENT_PRIOR_TYPE_MROI,
+    TREATMENT_PRIOR_TYPE_CONTRIBUTION,
+})
 # Represents a 1% increase in spend.
 MROI_FACTOR = 1.01
 
@@ -315,6 +319,41 @@ RF_PARAMETER_NAMES = (
     BETA_RF,
     BETA_GRF,
 )
+ORGANIC_MEDIA_PARAMETER_NAMES = (
+    CONTRIBUTION_OM,
+    BETA_OM,
+    ETA_OM,
+    ALPHA_OM,
+    EC_OM,
+    SLOPE_OM,
+    BETA_GOM,
+)
+ORGANIC_RF_PARAMETER_NAMES = (
+    CONTRIBUTION_ORF,
+    BETA_ORF,
+    ETA_ORF,
+    ALPHA_ORF,
+    EC_ORF,
+    SLOPE_ORF,
+    BETA_GORF,
+)
+NON_MEDIA_PARAMETER_NAMES = (
+    CONTRIBUTION_N,
+    GAMMA_N,
+    XI_N,
+    GAMMA_GN,
+)
+ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
+    SLOPE_M,
+    SLOPE_OM,
+    XI_N,
+    XI_C,
+    ETA_M,
+    ETA_RF,
+    ETA_OM,
+    ETA_ORF,
+)
+
 
 MEDIA_PARAMETERS = (
     ROI_M,
@@ -501,10 +540,17 @@ ADSTOCK_HILL_FUNCTIONS = frozenset({
     'hill',
 })
 
+# Adstock decay functions.
+GEOMETRIC_DECAY = 'geometric'
+BINOMIAL_DECAY = 'binomial'
+
+ADSTOCK_DECAY_FUNCTIONS = frozenset({GEOMETRIC_DECAY, BINOMIAL_DECAY})
+ADSTOCK_CHANNELS = (MEDIA, RF, ORGANIC_MEDIA, ORGANIC_RF)
 
 # Distribution constants.
 DISTRIBUTION = 'distribution'
 DISTRIBUTION_TYPE = 'distribution_type'
+INDEPENDENT_MULTIVARIATE = 'IndependentMultivariate'
 PRIOR = 'prior'
 POSTERIOR = 'posterior'
 # Prior mean proportion of KPI incremental due to all media.
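
Note: the new GEOMETRIC_DECAY / BINOMIAL_DECAY constants name the two adstock decay curves selectable in 1.2.0. As a rough illustration of why two names exist, the sketch below picks a weight curve from a decay flag; the geometric weights alpha**lag are the standard adstock form, while the "binomial" curve here is only a placeholder assumption, not necessarily Meridian's exact parameterization.

    import numpy as np

    def adstock_weights(alpha: float, max_lag: int, decay: str = 'geometric') -> np.ndarray:
      """Illustrative adstock weights for a given decay flag (not Meridian code)."""
      lags = np.arange(max_lag + 1)
      if decay == 'geometric':
        weights = alpha ** lags  # standard geometric decay
      elif decay == 'binomial':
        # Placeholder polynomial decay (assumed form): drops toward zero at max_lag.
        weights = (1.0 - lags / (max_lag + 1)) ** (1.0 / alpha - 1.0)
      else:
        raise ValueError(f'unknown decay function: {decay}')
      return weights / weights.sum()  # normalize so the weights sum to 1

    print(adstock_weights(0.7, 4, 'geometric'))
    print(adstock_weights(0.7, 4, 'binomial'))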
@@ -710,3 +756,13 @@ WEEKLY = 'weekly'
 QUARTERLY = 'quarterly'
 TIME_GRANULARITIES = frozenset({WEEKLY, QUARTERLY})
 QUARTERLY_SUMMARY_THRESHOLD_WEEKS = 52
+
+# Automatic Knot Selection constants
+KNOTS_SELECTED = 'knots_selected'
+SELECTION_COEFS = 'selection_coefs'
+MODEL = 'model'
+REGRESSION_COEFS = 'regression_coefs'
+SELECTED_MATRIX = 'selected_matrix'
+AIC = 'aic'
+BIC = 'bic'
+EBIC = 'ebic'
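
Note: the new "Automatic Knot Selection" constants name the pieces of a knot-selection routine scored with information criteria. The sketch below is an illustration of that idea, not Meridian's implementation: it scores a candidate knot design matrix with the AIC and BIC named above (EBIC extends BIC with an additional penalty tied to the size of the candidate knot set).

    import numpy as np

    def aic_bic(y: np.ndarray, x: np.ndarray) -> tuple[float, float]:
      """AIC and BIC for an OLS fit of y on the columns of x (Gaussian errors)."""
      n, k = x.shape  # k counts the fitted regression coefficients
      coefs, *_ = np.linalg.lstsq(x, y, rcond=None)
      rss = float(np.sum((y - x @ coefs) ** 2))
      log_lik = -0.5 * n * (np.log(2.0 * np.pi * rss / n) + 1.0)
      return 2.0 * k - 2.0 * log_lik, k * np.log(n) - 2.0 * log_lik

    rng = np.random.default_rng(0)
    x = rng.normal(size=(117, 6))  # e.g. 6 candidate knot columns over 117 weeks
    y = x @ rng.normal(size=6) + rng.normal(size=117)
    print(aic_bic(y, x))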
meridian/data/input_data.py CHANGED
@@ -442,6 +442,54 @@ class InputData:
     """Checks whether the `rf_spend` array has a time dimension."""
     return self.rf_spend is not None and constants.TIME in self.rf_spend.coords
 
+  @property
+  def scaled_centered_kpi(self) -> np.ndarray:
+    """Calculates scaled and centered KPI values.
+
+    Returns:
+      An array of KPI values that have been population-scaled and
+      mean-centered by geo.
+    """
+    kpi = self.kpi.values
+    population = self.population.values[:, np.newaxis]
+
+    population_scaled_kpi = np.divide(
+        kpi, population, out=np.zeros_like(kpi), where=(population != 0)
+    )
+    population_scaled_mean = np.mean(population_scaled_kpi)
+    population_scaled_stdev = np.std(population_scaled_kpi)
+    kpi_scaled = np.divide(
+        population_scaled_kpi - population_scaled_mean,
+        population_scaled_stdev,
+        out=np.zeros_like(population_scaled_kpi - population_scaled_mean),
+        where=(population_scaled_stdev != 0),
+    )
+    return kpi_scaled - np.mean(kpi_scaled, axis=1, keepdims=True)
+
+  def copy(self, deep: bool = True) -> "InputData":
+    """Returns a copy of the InputData instance.
+
+    Args:
+      deep: If True, a deep copy is made, meaning all xarray.DataArray objects
+        are also deepcopied. If False, a shallow copy is made.
+
+    Returns:
+      A new InputData instance.
+    """
+    if not deep:
+      return dataclasses.replace(self)
+
+    copied_fields = {}
+    for field in dataclasses.fields(self):
+      value = getattr(self, field.name)
+      if isinstance(value, xr.DataArray):
+        copied_fields[field.name] = value.copy(deep=True)
+      else:
+        # For other types, dataclasses.replace does a shallow copy.
+        copied_fields[field.name] = value
+
+    return InputData(**copied_fields)
+
   def _validate_scenarios(self):
     """Verifies that calibration and analysis is set correctly."""
     n_geos = len(self.kpi.coords[constants.GEO])
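
Note: the new `scaled_centered_kpi` property applies three steps that are easy to follow with a tiny worked example (made-up numbers): divide each geo's KPI by its population, z-score the result against the global mean and standard deviation, then subtract each geo's own mean so every row is centered. The version in the diff additionally guards against zero population and zero variance via `np.divide(..., where=...)`.

    import numpy as np

    kpi = np.array([[10.0, 20.0, 30.0],
                    [2.0, 4.0, 6.0]])  # 2 geos x 3 time periods
    population = np.array([100.0, 10.0])[:, np.newaxis]

    per_capita = kpi / population  # population-scale each geo
    z = (per_capita - per_capita.mean()) / per_capita.std()  # global z-score
    centered = z - z.mean(axis=1, keepdims=True)  # mean-center within each geo
    print(centered)  # each row now sums to ~0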
@@ -848,6 +896,32 @@ class InputData:
       raise ValueError("Both RF and media channel values are missing.")
     # pytype: enable=attribute-error
 
+  def get_all_adstock_hill_channels(self) -> np.ndarray:
+    """Returns all channel dimensions that adstock hill is applied to.
+
+    RF, organic media and organic RF channels are concatenated to the end of the
+    media channels if they are present.
+    """
+    adstock_hill_channels = []
+
+    if self.media_channel is not None:
+      adstock_hill_channels.append(self.media_channel.values)
+
+    if self.rf_channel is not None:
+      adstock_hill_channels.append(self.rf_channel.values)
+
+    if self.organic_media_channel is not None:
+      adstock_hill_channels.append(self.organic_media_channel.values)
+
+    if self.organic_rf_channel is not None:
+      adstock_hill_channels.append(self.organic_rf_channel.values)
+
+    if not adstock_hill_channels:
+      raise ValueError("Media, RF, organic media and organic RF channels are "
+                       "all missing.")
+
+    return np.concatenate(adstock_hill_channels, axis=None)
+
   def get_paid_channels_argument_builder(
       self,
@@ -870,6 +944,26 @@ class InputData:
       raise ValueError("There are no RF channels in the input data.")
     return arg_builder.OrderedListArgumentBuilder(self.rf_channel.values)
 
+  def get_organic_media_channels_argument_builder(
+      self
+  ) -> arg_builder.OrderedListArgumentBuilder:
+    """Returns an argument builder for *organic* media channels *only*."""
+    if self.organic_media_channel is None:
+      raise ValueError("There are no organic media channels in the input data.")
+    return arg_builder.OrderedListArgumentBuilder(
+        self.organic_media_channel.values
+    )
+
+  def get_organic_rf_channels_argument_builder(
+      self
+  ) -> arg_builder.OrderedListArgumentBuilder:
+    """Returns an argument builder for *organic* RF channels *only*."""
+    if self.organic_rf_channel is None:
+      raise ValueError("There are no organic RF channels in the input data.")
+    return arg_builder.OrderedListArgumentBuilder(
+        self.organic_rf_channel.values
+    )
+
   def get_all_channels(self) -> np.ndarray:
     """Returns all the channel dimensions.
 
meridian/data/test_utils.py CHANGED
@@ -21,6 +21,7 @@ import immutabledict
 from meridian import constants as c
 from meridian.data import input_data
 from meridian.data import load
+from meridian.model import knots
 import numpy as np
 import pandas as pd
 import xarray as xr
@@ -584,6 +585,47 @@ NATIONAL_COORD_TO_COLUMNS_WO_POPULATION_W_GEO = dataclasses.replace(
     geo='geo',
 )
 
+ADSTOCK_DECAY_SPEC_CASES = immutabledict.immutabledict({
+    'media': (
+        {},
+        {
+            'ch_0': c.BINOMIAL_DECAY,
+            'ch_1': c.GEOMETRIC_DECAY,
+            'ch_2': c.GEOMETRIC_DECAY,
+        },
+    ),
+    'rf': (
+        {},
+        {
+            'rf_ch_0': c.BINOMIAL_DECAY,
+            'rf_ch_1': c.GEOMETRIC_DECAY,
+            'rf_ch_2': c.GEOMETRIC_DECAY,
+            'rf_ch_3': c.BINOMIAL_DECAY,
+        },
+    ),
+    'organic_media': (
+        {},
+        {
+            'organic_media_0': c.BINOMIAL_DECAY,
+            'organic_media_1': c.GEOMETRIC_DECAY,
+            'organic_media_2': c.GEOMETRIC_DECAY,
+            'organic_media_3': c.BINOMIAL_DECAY,
+            'organic_media_4': c.GEOMETRIC_DECAY,
+        },
+    ),
+    'organic_rf': (
+        {},
+        {
+            'organic_rf_ch_0': c.BINOMIAL_DECAY,
+            'organic_rf_ch_1': c.GEOMETRIC_DECAY,
+            'organic_rf_ch_2': c.GEOMETRIC_DECAY,
+            'organic_rf_ch_3': c.BINOMIAL_DECAY,
+            'organic_rf_ch_4': c.BINOMIAL_DECAY,
+            'organic_rf_ch_5': c.GEOMETRIC_DECAY,
+        },
+    ),
+})
+
 
 def random_media_da(
     n_geos: int,
@@ -595,6 +637,7 @@ def random_media_da(
     explicit_geo_names: Sequence[str] | None = None,
     explicit_time_index: Sequence[str] | None = None,
     explicit_media_channel_names: Sequence[str] | None = None,
+    media_value_scales: list[tuple[float, float]] | None = None,
     array_name: str = 'media',
     channel_variable_name: str = 'media_channel',
    channel_prefix: str = 'ch_',
@@ -613,6 +656,8 @@ def random_media_da(
     explicit_time_index: If given, ignore `date_format` and use this as is
     explicit_media_channel_names: If given, ignore `n_media_channels` and use
       this as is
+    media_value_scales: A list of (mean, std) tuples, one for each media
+      channel, to control the scale of the generated random values.
     array_name: The name of the array to be created
     channel_variable_name: The name of the channel variable
     channel_prefix: The prefix of the channel names
@@ -628,11 +673,28 @@ def random_media_da(
   if n_times < n_media_times:
     start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
 
-  media = np.round(
-      abs(
-          np.random.normal(5, 5, size=(n_geos, n_media_times, n_media_channels))
+  if media_value_scales:
+    if len(media_value_scales) != n_media_channels:
+      raise ValueError(
+          'Length of media_value_scales must match n_media_channels.'
       )
-  )
+    channel_data = []
+    for mean, std in media_value_scales:
+      channel_data.append(
+          np.round(
+              abs(np.random.normal(mean, std, size=(n_geos, n_media_times)))
+          )
+      )
+    media = np.stack(channel_data, axis=-1)
+  else:
+    media = np.round(
+        abs(
+            np.random.normal(
+                5, 5, size=(n_geos, n_media_times, n_media_channels)
+            )
+        )
+    )
+
   if explicit_geo_names is None:
     geos = sample_geos(n_geos, integer_geos)
   else:
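
Note: a usage sketch for the new `media_value_scales` argument (one (mean, std) pair per channel), assuming this fixture module is importable as `meridian.data.test_utils` and that the leading parameters are named as the hunks above suggest; the values are illustrative.

    from meridian.data import test_utils

    # Two channels with very different magnitudes.
    media = test_utils.random_media_da(
        n_geos=5,
        n_times=52,
        n_media_times=52,
        n_media_channels=2,
        media_value_scales=[(5.0, 5.0), (1000.0, 200.0)],
    )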
@@ -698,6 +760,7 @@ def random_media_spend_nd_da(
     n_media_channels: int | None = None,
     seed=0,
     integer_geos: bool = False,
+    explicit_media_channel_names: Sequence[str] | None = None,
 ) -> xr.DataArray:
   """Generates a sample N-dimensional `media_spend` DataArray.
 
@@ -716,6 +779,8 @@
     n_media_channels: Number of channels in the created `media_spend` array.
     seed: Random seed used by `np.random.seed()`.
     integer_geos: If True, the geos will be integers.
+    explicit_media_channel_names: If given, ignore `n_media_channels` and use
+      this as is.
 
   Returns:
     A DataArray containing the generated `media_spend` data with the given
@@ -733,9 +798,12 @@
   coords['time'] = _sample_times(n_times=n_times)
   if n_media_channels is not None:
     dims.append('media_channel')
-    coords['media_channel'] = _sample_names(
-        prefix='ch_', n_names=n_media_channels
-    )
+    if explicit_media_channel_names is not None:
+      coords['media_channel'] = explicit_media_channel_names
+    else:
+      coords['media_channel'] = _sample_names(
+          prefix='ch_', n_names=n_media_channels
+      )
 
   if dims == ['geo', 'time', 'media_channel']:
     shape = (n_geos, n_times, n_media_channels)
@@ -822,6 +890,7 @@ def random_kpi_da(
     controls: xr.DataArray | None = None,
     seed: int = 0,
     integer_geos: bool = False,
+    kpi_data_pattern: str = '',
 ) -> xr.DataArray:
   """Generates a sample `kpi` DataArray."""
 
@@ -857,6 +926,22 @@
 
   error = np.random.normal(0, 2, size=(n_geos, n_times))
   kpi = abs(media_portion + control_portion + error)
+  if kpi_data_pattern == 'flat':
+    first_col = kpi[:, 0]  # all rows will have value same as first col
+    kpi = (
+        first_col[:, np.newaxis]
+        + np.random.normal(scale=0.02, size=kpi.shape)
+        + 0.04
+    )
+  elif kpi_data_pattern == 'seasonal':
+    for row in kpi:
+      row.sort()
+    kpi = np.sin(kpi) + 5
+  elif kpi_data_pattern == 'peak':
+    peak_index = int(len(kpi[0]) / 2)
+    kpi[:] = kpi[0, 0]
+    for row in kpi:
+      row[peak_index] *= 3
 
   return xr.DataArray(
       kpi,
@@ -891,14 +976,18 @@
 
 
 def random_population(
-    n_geos: int, seed: int = 0, integer_geos: bool = False
+    n_geos: int,
+    seed: int = 0,
+    integer_geos: bool = False,
+    constant_value: float | None = None,
 ) -> xr.DataArray:
   """Generates a sample `population` DataArray."""
 
   np.random.seed(seed)
-
-  population = np.round(10 + abs(np.random.normal(3000, 100, size=n_geos)))
-
+  if constant_value is not None:
+    population = np.full(n_geos, constant_value)
+  else:
+    population = np.round(10 + abs(np.random.normal(3000, 100, size=n_geos)))
   return xr.DataArray(
       population,
       dims=['geo'],
@@ -1170,11 +1259,15 @@ def random_dataset(
     n_organic_media_channels: int | None = None,
     n_organic_rf_channels: int | None = None,
     n_media_channels: int | None = None,
+    explicit_media_channel_names: Sequence[str] | None = None,
+    media_value_scales: list[tuple[float, float]] | None = None,
     n_rf_channels: int | None = None,
     revenue_per_kpi_value: float | None = 3.14,
+    constant_population_value: float | None = None,
     seed: int = 0,
     remove_media_time: bool = False,
     integer_geos: bool = False,
+    kpi_data_pattern: str = '',
 ) -> xr.Dataset:
   """Generates a random dataset."""
   if n_media_channels:
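
Note: the new `random_dataset` arguments can be combined; a usage sketch with illustrative values, under the same module-path assumption as above (the parameter names come from the signature hunks in this file).

    from meridian.data import test_utils

    dataset = test_utils.random_dataset(
        n_geos=10,
        n_times=104,
        n_media_times=104,
        n_controls=2,
        n_media_channels=2,
        explicit_media_channel_names=['tv', 'search'],
        media_value_scales=[(5.0, 5.0), (1000.0, 200.0)],
        constant_population_value=1000.0,
        kpi_data_pattern='seasonal',
    )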
@@ -1185,11 +1278,14 @@
         n_media_channels=n_media_channels,
         seed=seed,
         integer_geos=integer_geos,
+        explicit_media_channel_names=explicit_media_channel_names,
+        media_value_scales=media_value_scales,
     )
     media_spend = random_media_spend_nd_da(
         n_geos=n_geos,
         n_times=n_times,
         n_media_channels=n_media_channels,
+        explicit_media_channel_names=explicit_media_channel_names,
         seed=seed,
         integer_geos=integer_geos,
     )
@@ -1301,9 +1397,13 @@
       n_media_channels=n_media_channels or n_rf_channels or 0,
       n_controls=n_controls,
       integer_geos=integer_geos,
+      kpi_data_pattern=kpi_data_pattern,
   )
   population = random_population(
-      n_geos=n_geos, seed=seed, integer_geos=integer_geos
+      n_geos=n_geos,
+      seed=seed,
+      integer_geos=integer_geos,
+      constant_value=constant_population_value,
   )
 
   dataset = xr.combine_by_coords(
@@ -1644,6 +1744,7 @@
     n_organic_media_channels: int | None = None,
     n_organic_rf_channels: int | None = None,
     seed: int = 0,
+    explicit_media_channel_names: Sequence[str] | None = None,
 ) -> input_data.InputData:
   """Generates sample InputData for `kpi_type='revenue'`."""
   dataset = random_dataset(
@@ -1658,6 +1759,7 @@
       n_organic_rf_channels=n_organic_rf_channels,
       revenue_per_kpi_value=1.0,
       seed=seed,
+      explicit_media_channel_names=explicit_media_channel_names,
   )
   return input_data.InputData(
       kpi=dataset.kpi,
@@ -1773,3 +1875,33 @@
       if n_organic_rf_channels
       else None,
   )
+
+
+def sample_input_data_for_aks_with_expected_knot_info() -> (
+    tuple[input_data.InputData, knots.KnotInfo]
+):
+  """Generates sample InputData and corresponding expected KnotInfo for testing.
+
+  Returns:
+    A tuple containing:
+      - InputData object with sample data.
+      - KnotInfo object with expected knot information.
+  """
+  data = sample_input_data_from_dataset(
+      random_dataset(
+          n_geos=20,
+          n_times=117,
+          n_media_times=117,
+          n_controls=2,
+          n_media_channels=5,
+      ),
+      'non_revenue',
+  )
+  expected_knot_info = knots.KnotInfo(
+      n_knots=6,
+      knot_locations=np.array([38, 39, 41, 48, 50, 55]),
+      weights=knots.l1_distance_weights(
+          117, np.array([38, 39, 41, 48, 50, 55])
+      ),
+  )
+  return data, expected_knot_info
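
Note: a usage sketch for the new fixture, under the same module-path assumption as above and assuming `knots.KnotInfo` exposes its constructor fields as attributes; it pairs a deterministic dataset with the knot locations the automatic selection is expected to produce.

    from meridian.data import test_utils

    data, expected = test_utils.sample_input_data_for_aks_with_expected_knot_info()
    print(expected.n_knots)         # 6
    print(expected.knot_locations)  # [38 39 41 48 50 55]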