google-meridian 1.1.6__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.0.dist-info}/METADATA +8 -2
- google_meridian-1.2.0.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +526 -362
- meridian/analysis/optimizer.py +275 -267
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +37 -49
- meridian/backend/__init__.py +514 -0
- meridian/backend/config.py +59 -0
- meridian/backend/test_utils.py +95 -0
- meridian/constants.py +59 -3
- meridian/data/input_data.py +94 -0
- meridian/data/test_utils.py +144 -12
- meridian/model/adstock_hill.py +279 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +306 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +323 -157
- meridian/model/posterior_sampler.py +81 -76
- meridian/model/prior_distribution.py +538 -168
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +53 -47
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.0.dist-info}/top_level.txt +0 -0
meridian/constants.py
CHANGED
|
@@ -66,6 +66,7 @@ FREQUENCY = 'frequency'
|
|
|
66
66
|
# Data array names for RF, organic and non-media inputs.
RF_IMPRESSIONS = 'rf_impressions'
RF_SPEND = 'rf_spend'
ORGANIC_MEDIA = 'organic_media'
ORGANIC_RF = 'organic_rf'
ORGANIC_REACH = 'organic_reach'
ORGANIC_FREQUENCY = 'organic_frequency'
NON_MEDIA_TREATMENTS = 'non_media_treatments'
|
|
@@ -143,6 +144,7 @@ MEDIA_CHANNEL = 'media_channel'
|
|
|
143
144
|
# Channel coordinate names.
RF_CHANNEL = 'rf_channel'
CHANNEL = 'channel'
RF = 'rf'
# ORGANIC_RF ('organic_rf') is already defined earlier in this module next to
# ORGANIC_MEDIA; it is deliberately not redefined here so the constant has a
# single source of truth.
ORGANIC_MEDIA_CHANNEL = 'organic_media_channel'
ORGANIC_RF_CHANNEL = 'organic_rf_channel'
NON_MEDIA_CHANNEL = 'non_media_channel'
|
|
@@ -212,9 +214,11 @@ NON_PAID_TREATMENT_PRIOR_TYPES = frozenset({
|
|
|
212
214
|
TREATMENT_PRIOR_TYPE_COEFFICIENT,
|
|
213
215
|
TREATMENT_PRIOR_TYPE_CONTRIBUTION,
|
|
214
216
|
})
|
|
215
|
-
# Paid-media treatment prior types grouped together as the "ROI" family:
# ROI, mROI and contribution.
PAID_MEDIA_ROI_PRIOR_TYPES = frozenset((
    TREATMENT_PRIOR_TYPE_ROI,
    TREATMENT_PRIOR_TYPE_MROI,
    TREATMENT_PRIOR_TYPE_CONTRIBUTION,
))

# Represents a 1% increase in spend.
MROI_FACTOR = 1.01
|
|
220
224
|
|
|
@@ -315,6 +319,41 @@ RF_PARAMETER_NAMES = (
|
|
|
315
319
|
BETA_RF,
|
|
316
320
|
BETA_GRF,
|
|
317
321
|
)
|
|
322
|
+
# Parameter names for organic media channels.
ORGANIC_MEDIA_PARAMETER_NAMES = (
    CONTRIBUTION_OM, BETA_OM, ETA_OM, ALPHA_OM, EC_OM, SLOPE_OM, BETA_GOM,
)
# Parameter names for organic RF channels.
ORGANIC_RF_PARAMETER_NAMES = (
    CONTRIBUTION_ORF, BETA_ORF, ETA_ORF, ALPHA_ORF, EC_ORF, SLOPE_ORF,
    BETA_GORF,
)
# Parameter names for non-media treatment channels.
NON_MEDIA_PARAMETER_NAMES = (CONTRIBUTION_N, GAMMA_N, XI_N, GAMMA_GN)
# NOTE(review): judging by the name, these are presumably the parameters held
# deterministic in national-level models; confirm against the model code.
ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
    SLOPE_M, SLOPE_OM, XI_N, XI_C, ETA_M, ETA_RF, ETA_OM, ETA_ORF,
)
|
|
356
|
+
|
|
318
357
|
|
|
319
358
|
MEDIA_PARAMETERS = (
|
|
320
359
|
ROI_M,
|
|
@@ -501,10 +540,17 @@ ADSTOCK_HILL_FUNCTIONS = frozenset({
|
|
|
501
540
|
'hill',
|
|
502
541
|
})
|
|
503
542
|
|
|
543
|
+
# Adstock decay functions.
GEOMETRIC_DECAY = 'geometric'
BINOMIAL_DECAY = 'binomial'

ADSTOCK_DECAY_FUNCTIONS = frozenset((GEOMETRIC_DECAY, BINOMIAL_DECAY))
# Channel groups that adstock is applied to: paid media, RF, organic media and
# organic RF.
ADSTOCK_CHANNELS = (MEDIA, RF, ORGANIC_MEDIA, ORGANIC_RF)
|
|
504
549
|
|
|
505
550
|
# Distribution constants.
DISTRIBUTION = 'distribution'
DISTRIBUTION_TYPE = 'distribution_type'
INDEPENDENT_MULTIVARIATE = 'IndependentMultivariate'
PRIOR = 'prior'
POSTERIOR = 'posterior'
|
|
510
556
|
# Prior mean proportion of KPI incremental due to all media.
|
|
@@ -710,3 +756,13 @@ WEEKLY = 'weekly'
|
|
|
710
756
|
QUARTERLY = 'quarterly'
# Supported time granularities.
TIME_GRANULARITIES = frozenset((WEEKLY, QUARTERLY))
QUARTERLY_SUMMARY_THRESHOLD_WEEKS = 52
|
|
759
|
+
|
|
760
|
+
# Automatic Knot Selection (AKS) constants.
KNOTS_SELECTED = 'knots_selected'
SELECTION_COEFS = 'selection_coefs'
MODEL = 'model'
REGRESSION_COEFS = 'regression_coefs'
SELECTED_MATRIX = 'selected_matrix'
# Information-criterion names.
AIC = 'aic'
BIC = 'bic'
EBIC = 'ebic'
|
meridian/data/input_data.py
CHANGED
|
@@ -442,6 +442,54 @@ class InputData:
|
|
|
442
442
|
"""Checks whether the `rf_spend` array has a time dimension."""
|
|
443
443
|
return self.rf_spend is not None and constants.TIME in self.rf_spend.coords
|
|
444
444
|
|
|
445
|
+
@property
def scaled_centered_kpi(self) -> np.ndarray:
    """Calculates scaled and centered KPI values.

    The KPI is first divided by each geo's population (geos with zero
    population get 0). It is then standardized using the mean and standard
    deviation taken over all values (a zero standard deviation yields 0).
    Finally it is centered by subtracting each geo's mean along axis 1.

    Assumes `self.kpi` is indexed by (geo, time) and `self.population` by geo,
    as the broadcasting below requires.

    Returns:
      A floating-point array with the shape of `self.kpi` holding the
      population-scaled, standardized and geo-mean-centered KPI values.
    """
    kpi = self.kpi.values
    population = self.population.values[:, np.newaxis]

    # `np.divide(..., out=...)` cannot cast its floating-point result into an
    # integer buffer, so always allocate a floating buffer. Float KPIs keep
    # their own precision; integer KPIs are promoted to float64.
    float_dtype = np.result_type(kpi.dtype, np.float32)
    population_scaled_kpi = np.divide(
        kpi,
        population,
        out=np.zeros(kpi.shape, dtype=float_dtype),
        where=(population != 0),
    )

    deviations = population_scaled_kpi - np.mean(population_scaled_kpi)
    population_scaled_stdev = np.std(population_scaled_kpi)
    kpi_scaled = np.divide(
        deviations,
        population_scaled_stdev,
        out=np.zeros_like(deviations),
        where=(population_scaled_stdev != 0),
    )
    return kpi_scaled - np.mean(kpi_scaled, axis=1, keepdims=True)
|
|
468
|
+
|
|
469
|
+
def copy(self, deep: bool = True) -> "InputData":
    """Returns a copy of the InputData instance.

    Args:
      deep: If True, every `xarray.DataArray` field is deep-copied and all
        other fields are shared with the original. If False, a shallow copy is
        made and every field is shared with the original.

    Returns:
      A new instance of the same class as `self`.
    """
    if not deep:
        return dataclasses.replace(self)

    deep_copied_arrays = {}
    for field in dataclasses.fields(self):
        # Only `__init__` parameters may be passed to `dataclasses.replace`;
        # `init=False` fields are left to the class's own initialization.
        if not field.init:
            continue
        value = getattr(self, field.name)
        if isinstance(value, xr.DataArray):
            deep_copied_arrays[field.name] = value.copy(deep=True)

    # `dataclasses.replace` builds an instance of `type(self)`, so subclasses
    # are preserved exactly as in the shallow branch above. Non-DataArray
    # fields are carried over as-is (shallow).
    return dataclasses.replace(self, **deep_copied_arrays)
|
|
492
|
+
|
|
445
493
|
def _validate_scenarios(self):
|
|
446
494
|
"""Verifies that calibration and analysis is set correctly."""
|
|
447
495
|
n_geos = len(self.kpi.coords[constants.GEO])
|
|
@@ -848,6 +896,32 @@ class InputData:
|
|
|
848
896
|
raise ValueError("Both RF and media channel values are missing.")
|
|
849
897
|
# pytype: enable=attribute-error
|
|
850
898
|
|
|
899
|
+
def get_all_adstock_hill_channels(self) -> np.ndarray:
    """Returns the channel coordinate values that adstock hill is applied to.

    The result lists media channels first, followed by RF, organic media and
    organic RF channels. Any group whose coordinate is None is skipped, and
    the values are flattened into a single 1-D array.

    Raises:
      ValueError: If all four channel groups are missing.
    """
    channel_groups = (
        self.media_channel,
        self.rf_channel,
        self.organic_media_channel,
        self.organic_rf_channel,
    )
    present_values = [
        group.values for group in channel_groups if group is not None
    ]
    if not present_values:
        raise ValueError(
            "Media, RF, organic media and organic RF channels are all missing."
        )
    return np.concatenate(present_values, axis=None)
|
|
924
|
+
|
|
851
925
|
def get_paid_channels_argument_builder(
|
|
852
926
|
self,
|
|
853
927
|
) -> arg_builder.OrderedListArgumentBuilder:
|
|
@@ -870,6 +944,26 @@ class InputData:
|
|
|
870
944
|
raise ValueError("There are no RF channels in the input data.")
|
|
871
945
|
return arg_builder.OrderedListArgumentBuilder(self.rf_channel.values)
|
|
872
946
|
|
|
947
|
+
def get_organic_media_channels_argument_builder(
    self,
) -> arg_builder.OrderedListArgumentBuilder:
    """Returns an argument builder for *organic* media channels *only*.

    Raises:
      ValueError: If the input data has no organic media channels.
    """
    organic_channels = self.organic_media_channel
    if organic_channels is None:
        raise ValueError("There are no organic media channels in the input data.")
    return arg_builder.OrderedListArgumentBuilder(organic_channels.values)
|
|
956
|
+
|
|
957
|
+
def get_organic_rf_channels_argument_builder(
    self,
) -> arg_builder.OrderedListArgumentBuilder:
    """Returns an argument builder for *organic* RF channels *only*.

    Raises:
      ValueError: If the input data has no organic RF channels.
    """
    organic_rf_channels = self.organic_rf_channel
    if organic_rf_channels is None:
        raise ValueError("There are no organic RF channels in the input data.")
    return arg_builder.OrderedListArgumentBuilder(organic_rf_channels.values)
|
|
966
|
+
|
|
873
967
|
def get_all_channels(self) -> np.ndarray:
|
|
874
968
|
"""Returns all the channel dimensions.
|
|
875
969
|
|
meridian/data/test_utils.py
CHANGED
|
@@ -21,6 +21,7 @@ import immutabledict
|
|
|
21
21
|
from meridian import constants as c
|
|
22
22
|
from meridian.data import input_data
|
|
23
23
|
from meridian.data import load
|
|
24
|
+
from meridian.model import knots
|
|
24
25
|
import numpy as np
|
|
25
26
|
import pandas as pd
|
|
26
27
|
import xarray as xr
|
|
@@ -584,6 +585,47 @@ NATIONAL_COORD_TO_COLUMNS_WO_POPULATION_W_GEO = dataclasses.replace(
|
|
|
584
585
|
geo='geo',
|
|
585
586
|
)
|
|
586
587
|
|
|
588
|
+
# Adstock decay spec test cases keyed by channel group. Each value pairs an
# empty mapping with an explicit channel-name -> decay-function mapping.
ADSTOCK_DECAY_SPEC_CASES = immutabledict.immutabledict({
    'media': (
        {},
        dict(
            ch_0=c.BINOMIAL_DECAY,
            ch_1=c.GEOMETRIC_DECAY,
            ch_2=c.GEOMETRIC_DECAY,
        ),
    ),
    'rf': (
        {},
        dict(
            rf_ch_0=c.BINOMIAL_DECAY,
            rf_ch_1=c.GEOMETRIC_DECAY,
            rf_ch_2=c.GEOMETRIC_DECAY,
            rf_ch_3=c.BINOMIAL_DECAY,
        ),
    ),
    'organic_media': (
        {},
        dict(
            organic_media_0=c.BINOMIAL_DECAY,
            organic_media_1=c.GEOMETRIC_DECAY,
            organic_media_2=c.GEOMETRIC_DECAY,
            organic_media_3=c.BINOMIAL_DECAY,
            organic_media_4=c.GEOMETRIC_DECAY,
        ),
    ),
    'organic_rf': (
        {},
        dict(
            organic_rf_ch_0=c.BINOMIAL_DECAY,
            organic_rf_ch_1=c.GEOMETRIC_DECAY,
            organic_rf_ch_2=c.GEOMETRIC_DECAY,
            organic_rf_ch_3=c.BINOMIAL_DECAY,
            organic_rf_ch_4=c.BINOMIAL_DECAY,
            organic_rf_ch_5=c.GEOMETRIC_DECAY,
        ),
    ),
})
|
|
628
|
+
|
|
587
629
|
|
|
588
630
|
def random_media_da(
|
|
589
631
|
n_geos: int,
|
|
@@ -595,6 +637,7 @@ def random_media_da(
|
|
|
595
637
|
explicit_geo_names: Sequence[str] | None = None,
|
|
596
638
|
explicit_time_index: Sequence[str] | None = None,
|
|
597
639
|
explicit_media_channel_names: Sequence[str] | None = None,
|
|
640
|
+
media_value_scales: list[tuple[float, float]] | None = None,
|
|
598
641
|
array_name: str = 'media',
|
|
599
642
|
channel_variable_name: str = 'media_channel',
|
|
600
643
|
channel_prefix: str = 'ch_',
|
|
@@ -613,6 +656,8 @@ def random_media_da(
|
|
|
613
656
|
explicit_time_index: If given, ignore `date_format` and use this as is
|
|
614
657
|
explicit_media_channel_names: If given, ignore `n_media_channels` and use
|
|
615
658
|
this as is
|
|
659
|
+
media_value_scales: A list of (mean, std) tuples, one for each media
|
|
660
|
+
channel, to control the scale of the generated random values.
|
|
616
661
|
array_name: The name of the array to be created
|
|
617
662
|
channel_variable_name: The name of the channel variable
|
|
618
663
|
channel_prefix: The prefix of the channel names
|
|
@@ -628,11 +673,28 @@ def random_media_da(
|
|
|
628
673
|
if n_times < n_media_times:
|
|
629
674
|
start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
|
|
630
675
|
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
676
|
+
if media_value_scales:
|
|
677
|
+
if len(media_value_scales) != n_media_channels:
|
|
678
|
+
raise ValueError(
|
|
679
|
+
'Length of media_value_scales must match n_media_channels.'
|
|
634
680
|
)
|
|
635
|
-
|
|
681
|
+
channel_data = []
|
|
682
|
+
for mean, std in media_value_scales:
|
|
683
|
+
channel_data.append(
|
|
684
|
+
np.round(
|
|
685
|
+
abs(np.random.normal(mean, std, size=(n_geos, n_media_times)))
|
|
686
|
+
)
|
|
687
|
+
)
|
|
688
|
+
media = np.stack(channel_data, axis=-1)
|
|
689
|
+
else:
|
|
690
|
+
media = np.round(
|
|
691
|
+
abs(
|
|
692
|
+
np.random.normal(
|
|
693
|
+
5, 5, size=(n_geos, n_media_times, n_media_channels)
|
|
694
|
+
)
|
|
695
|
+
)
|
|
696
|
+
)
|
|
697
|
+
|
|
636
698
|
if explicit_geo_names is None:
|
|
637
699
|
geos = sample_geos(n_geos, integer_geos)
|
|
638
700
|
else:
|
|
@@ -698,6 +760,7 @@ def random_media_spend_nd_da(
|
|
|
698
760
|
n_media_channels: int | None = None,
|
|
699
761
|
seed=0,
|
|
700
762
|
integer_geos: bool = False,
|
|
763
|
+
explicit_media_channel_names: Sequence[str] | None = None,
|
|
701
764
|
) -> xr.DataArray:
|
|
702
765
|
"""Generates a sample N-dimensional `media_spend` DataArray.
|
|
703
766
|
|
|
@@ -716,6 +779,8 @@ def random_media_spend_nd_da(
|
|
|
716
779
|
n_media_channels: Number of channels in the created `media_spend` array.
|
|
717
780
|
seed: Random seed used by `np.random.seed()`.
|
|
718
781
|
integer_geos: If True, the geos will be integers.
|
|
782
|
+
explicit_media_channel_names: If given, ignore `n_media_channels` and use
|
|
783
|
+
this as is.
|
|
719
784
|
|
|
720
785
|
Returns:
|
|
721
786
|
A DataArray containing the generated `media_spend` data with the given
|
|
@@ -733,9 +798,12 @@ def random_media_spend_nd_da(
|
|
|
733
798
|
coords['time'] = _sample_times(n_times=n_times)
|
|
734
799
|
if n_media_channels is not None:
|
|
735
800
|
dims.append('media_channel')
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
801
|
+
if explicit_media_channel_names is not None:
|
|
802
|
+
coords['media_channel'] = explicit_media_channel_names
|
|
803
|
+
else:
|
|
804
|
+
coords['media_channel'] = _sample_names(
|
|
805
|
+
prefix='ch_', n_names=n_media_channels
|
|
806
|
+
)
|
|
739
807
|
|
|
740
808
|
if dims == ['geo', 'time', 'media_channel']:
|
|
741
809
|
shape = (n_geos, n_times, n_media_channels)
|
|
@@ -822,6 +890,7 @@ def random_kpi_da(
|
|
|
822
890
|
controls: xr.DataArray | None = None,
|
|
823
891
|
seed: int = 0,
|
|
824
892
|
integer_geos: bool = False,
|
|
893
|
+
kpi_data_pattern: str = '',
|
|
825
894
|
) -> xr.DataArray:
|
|
826
895
|
"""Generates a sample `kpi` DataArray."""
|
|
827
896
|
|
|
@@ -857,6 +926,22 @@ def random_kpi_da(
|
|
|
857
926
|
|
|
858
927
|
error = np.random.normal(0, 2, size=(n_geos, n_times))
|
|
859
928
|
kpi = abs(media_portion + control_portion + error)
|
|
929
|
+
if kpi_data_pattern == 'flat':
|
|
930
|
+
first_col = kpi[:, 0] # all rows will have value same as first col
|
|
931
|
+
kpi = (
|
|
932
|
+
first_col[:, np.newaxis]
|
|
933
|
+
+ np.random.normal(scale=0.02, size=kpi.shape)
|
|
934
|
+
+ 0.04
|
|
935
|
+
)
|
|
936
|
+
elif kpi_data_pattern == 'seasonal':
|
|
937
|
+
for row in kpi:
|
|
938
|
+
row.sort()
|
|
939
|
+
kpi = np.sin(kpi) + 5
|
|
940
|
+
elif kpi_data_pattern == 'peak':
|
|
941
|
+
peak_index = int(len(kpi[0]) / 2)
|
|
942
|
+
kpi[:] = kpi[0, 0]
|
|
943
|
+
for row in kpi:
|
|
944
|
+
row[peak_index] *= 3
|
|
860
945
|
|
|
861
946
|
return xr.DataArray(
|
|
862
947
|
kpi,
|
|
@@ -891,14 +976,18 @@ def constant_revenue_per_kpi(
|
|
|
891
976
|
|
|
892
977
|
|
|
893
978
|
def random_population(
|
|
894
|
-
n_geos: int,
|
|
979
|
+
n_geos: int,
|
|
980
|
+
seed: int = 0,
|
|
981
|
+
integer_geos: bool = False,
|
|
982
|
+
constant_value: float | None = None,
|
|
895
983
|
) -> xr.DataArray:
|
|
896
984
|
"""Generates a sample `population` DataArray."""
|
|
897
985
|
|
|
898
986
|
np.random.seed(seed)
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
987
|
+
if constant_value is not None:
|
|
988
|
+
population = np.full(n_geos, constant_value)
|
|
989
|
+
else:
|
|
990
|
+
population = np.round(10 + abs(np.random.normal(3000, 100, size=n_geos)))
|
|
902
991
|
return xr.DataArray(
|
|
903
992
|
population,
|
|
904
993
|
dims=['geo'],
|
|
@@ -1170,11 +1259,15 @@ def random_dataset(
|
|
|
1170
1259
|
n_organic_media_channels: int | None = None,
|
|
1171
1260
|
n_organic_rf_channels: int | None = None,
|
|
1172
1261
|
n_media_channels: int | None = None,
|
|
1262
|
+
explicit_media_channel_names: Sequence[str] | None = None,
|
|
1263
|
+
media_value_scales: list[tuple[float, float]] | None = None,
|
|
1173
1264
|
n_rf_channels: int | None = None,
|
|
1174
1265
|
revenue_per_kpi_value: float | None = 3.14,
|
|
1266
|
+
constant_population_value: float | None = None,
|
|
1175
1267
|
seed: int = 0,
|
|
1176
1268
|
remove_media_time: bool = False,
|
|
1177
1269
|
integer_geos: bool = False,
|
|
1270
|
+
kpi_data_pattern: str = '',
|
|
1178
1271
|
) -> xr.Dataset:
|
|
1179
1272
|
"""Generates a random dataset."""
|
|
1180
1273
|
if n_media_channels:
|
|
@@ -1185,11 +1278,14 @@ def random_dataset(
|
|
|
1185
1278
|
n_media_channels=n_media_channels,
|
|
1186
1279
|
seed=seed,
|
|
1187
1280
|
integer_geos=integer_geos,
|
|
1281
|
+
explicit_media_channel_names=explicit_media_channel_names,
|
|
1282
|
+
media_value_scales=media_value_scales,
|
|
1188
1283
|
)
|
|
1189
1284
|
media_spend = random_media_spend_nd_da(
|
|
1190
1285
|
n_geos=n_geos,
|
|
1191
1286
|
n_times=n_times,
|
|
1192
1287
|
n_media_channels=n_media_channels,
|
|
1288
|
+
explicit_media_channel_names=explicit_media_channel_names,
|
|
1193
1289
|
seed=seed,
|
|
1194
1290
|
integer_geos=integer_geos,
|
|
1195
1291
|
)
|
|
@@ -1301,9 +1397,13 @@ def random_dataset(
|
|
|
1301
1397
|
n_media_channels=n_media_channels or n_rf_channels or 0,
|
|
1302
1398
|
n_controls=n_controls,
|
|
1303
1399
|
integer_geos=integer_geos,
|
|
1400
|
+
kpi_data_pattern=kpi_data_pattern,
|
|
1304
1401
|
)
|
|
1305
1402
|
population = random_population(
|
|
1306
|
-
n_geos=n_geos,
|
|
1403
|
+
n_geos=n_geos,
|
|
1404
|
+
seed=seed,
|
|
1405
|
+
integer_geos=integer_geos,
|
|
1406
|
+
constant_value=constant_population_value,
|
|
1307
1407
|
)
|
|
1308
1408
|
|
|
1309
1409
|
dataset = xr.combine_by_coords(
|
|
@@ -1644,6 +1744,7 @@ def sample_input_data_revenue(
|
|
|
1644
1744
|
n_organic_media_channels: int | None = None,
|
|
1645
1745
|
n_organic_rf_channels: int | None = None,
|
|
1646
1746
|
seed: int = 0,
|
|
1747
|
+
explicit_media_channel_names: Sequence[str] | None = None,
|
|
1647
1748
|
) -> input_data.InputData:
|
|
1648
1749
|
"""Generates sample InputData for `kpi_type='revenue'`."""
|
|
1649
1750
|
dataset = random_dataset(
|
|
@@ -1658,6 +1759,7 @@ def sample_input_data_revenue(
|
|
|
1658
1759
|
n_organic_rf_channels=n_organic_rf_channels,
|
|
1659
1760
|
revenue_per_kpi_value=1.0,
|
|
1660
1761
|
seed=seed,
|
|
1762
|
+
explicit_media_channel_names=explicit_media_channel_names,
|
|
1661
1763
|
)
|
|
1662
1764
|
return input_data.InputData(
|
|
1663
1765
|
kpi=dataset.kpi,
|
|
@@ -1773,3 +1875,33 @@ def sample_input_data_non_revenue_no_revenue_per_kpi(
|
|
|
1773
1875
|
if n_organic_rf_channels
|
|
1774
1876
|
else None,
|
|
1775
1877
|
)
|
|
1878
|
+
|
|
1879
|
+
|
|
1880
|
+
def sample_input_data_for_aks_with_expected_knot_info() -> (
    tuple[input_data.InputData, knots.KnotInfo]
):
    """Builds sample InputData together with the KnotInfo expected for it.

    The data is a 20-geo, 117-period, non-revenue random dataset with two
    controls and five media channels. The expected KnotInfo lists six fixed
    knot locations, with weights from `knots.l1_distance_weights`.

    Returns:
      A `(input_data, expected_knot_info)` tuple.
    """
    n_times = 117
    dataset = random_dataset(
        n_geos=20,
        n_times=n_times,
        n_media_times=n_times,
        n_controls=2,
        n_media_channels=5,
    )
    data = sample_input_data_from_dataset(dataset, 'non_revenue')

    expected_locations = np.array([38, 39, 41, 48, 50, 55])
    expected_knot_info = knots.KnotInfo(
        n_knots=len(expected_locations),
        knot_locations=expected_locations,
        weights=knots.l1_distance_weights(n_times, expected_locations),
    )
    return data, expected_knot_info
|