google-meridian 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/METADATA +8 -4
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/RECORD +49 -17
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/summarizer.py +7 -2
- meridian/analysis/test_utils.py +934 -485
- meridian/analysis/visualizer.py +10 -6
- meridian/constants.py +1 -0
- meridian/data/test_utils.py +82 -10
- meridian/model/__init__.py +2 -0
- meridian/model/context.py +925 -0
- meridian/model/eda/constants.py +1 -0
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +93 -792
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +354 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1136 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +412 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/test_data.py +380 -0
- schema/utils/__init__.py +1 -0
- schema/utils/date_range_bucketing.py +117 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
meridian/analysis/visualizer.py
CHANGED
|
@@ -243,6 +243,12 @@ class ModelDiagnostics:
|
|
|
243
243
|
|
|
244
244
|
groupby = posterior_df.columns.tolist()
|
|
245
245
|
groupby.remove(parameter)
|
|
246
|
+
|
|
247
|
+
parameter_99_max = prior_posterior_df[parameter].quantile(0.99)
|
|
248
|
+
# Remove outliers that make the chart hard to read.
|
|
249
|
+
prior_posterior_df[parameter] = prior_posterior_df[parameter].clip(
|
|
250
|
+
upper=parameter_99_max * c.OUTLIER_CLIP_FACTOR
|
|
251
|
+
)
|
|
246
252
|
plot = (
|
|
247
253
|
alt.Chart(prior_posterior_df, width=c.VEGALITE_FACET_DEFAULT_WIDTH)
|
|
248
254
|
.transform_density(
|
|
@@ -269,7 +275,7 @@ class ModelDiagnostics:
|
|
|
269
275
|
title=formatter.custom_title_params(
|
|
270
276
|
summary_text.PRIOR_POSTERIOR_DIST_CHART_TITLE
|
|
271
277
|
)
|
|
272
|
-
).configure_axis(**formatter.TEXT_CONFIG)
|
|
278
|
+
).configure_axis(**formatter.TEXT_CONFIG).interactive()
|
|
273
279
|
|
|
274
280
|
def plot_rhat_boxplot(self) -> alt.Chart:
|
|
275
281
|
"""Plots the R-hat box plot.
|
|
@@ -1450,17 +1456,15 @@ class MediaSummary:
|
|
|
1450
1456
|
|
|
1451
1457
|
Args:
|
|
1452
1458
|
aggregate_times: If `True`, aggregates the metrics across all time
|
|
1453
|
-
periods.
|
|
1459
|
+
periods. If `False`, returns time-varying metrics.
|
|
1454
1460
|
|
|
1455
1461
|
Returns:
|
|
1456
1462
|
An `xarray.Dataset` containing the following:
|
|
1457
1463
|
- **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
|
|
1458
|
-
|
|
1459
|
-
`distribution` (`prior`, `posterior`)
|
|
1464
|
+
`ci_hi`), `distribution` (`prior`, `posterior`)
|
|
1460
1465
|
- **Data variables:** `impressions`, `pct_of_impressions`, `spend`,
|
|
1461
1466
|
`pct_of_spend`, `CPM`, `incremental_outcome`, `pct_of_contribution`,
|
|
1462
|
-
`roi`,
|
|
1463
|
-
`effectiveness`, `mroi`.
|
|
1467
|
+
`roi`, `effectiveness`, `mroi`.
|
|
1464
1468
|
"""
|
|
1465
1469
|
return self._analyzer.summary_metrics(
|
|
1466
1470
|
selected_times=self._selected_times,
|
meridian/constants.py
CHANGED
meridian/data/test_utils.py
CHANGED
|
@@ -642,6 +642,7 @@ def random_media_da(
|
|
|
642
642
|
channel_variable_name: str = 'media_channel',
|
|
643
643
|
channel_prefix: str = 'ch_',
|
|
644
644
|
integer_geos: bool = False,
|
|
645
|
+
nonzero_shift: float = 0.0,
|
|
645
646
|
) -> xr.DataArray:
|
|
646
647
|
"""Generates a sample `media` DataArray.
|
|
647
648
|
|
|
@@ -662,6 +663,7 @@ def random_media_da(
|
|
|
662
663
|
channel_variable_name: The name of the channel variable
|
|
663
664
|
channel_prefix: The prefix of the channel names
|
|
664
665
|
integer_geos: If True, the geos will be integers.
|
|
666
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
665
667
|
|
|
666
668
|
Returns:
|
|
667
669
|
A DataArray containing random data.
|
|
@@ -695,6 +697,8 @@ def random_media_da(
|
|
|
695
697
|
)
|
|
696
698
|
)
|
|
697
699
|
|
|
700
|
+
media = media + nonzero_shift
|
|
701
|
+
|
|
698
702
|
if explicit_geo_names is None:
|
|
699
703
|
geos = sample_geos(n_geos, integer_geos)
|
|
700
704
|
else:
|
|
@@ -736,6 +740,7 @@ def random_organic_media_da(
|
|
|
736
740
|
explicit_time_index: Sequence[str] | None = None,
|
|
737
741
|
explicit_media_channel_names: Sequence[str] | None = None,
|
|
738
742
|
integer_geos: bool = False,
|
|
743
|
+
nonzero_shift: float = 0.0,
|
|
739
744
|
) -> xr.DataArray:
|
|
740
745
|
"""Generates a sample `organic_media` DataArray."""
|
|
741
746
|
return random_media_da(
|
|
@@ -751,6 +756,7 @@ def random_organic_media_da(
|
|
|
751
756
|
channel_variable_name='organic_media_channel',
|
|
752
757
|
channel_prefix='organic_media_',
|
|
753
758
|
integer_geos=integer_geos,
|
|
759
|
+
nonzero_shift=nonzero_shift,
|
|
754
760
|
)
|
|
755
761
|
|
|
756
762
|
|
|
@@ -761,6 +767,7 @@ def random_media_spend_nd_da(
|
|
|
761
767
|
seed=0,
|
|
762
768
|
integer_geos: bool = False,
|
|
763
769
|
explicit_media_channel_names: Sequence[str] | None = None,
|
|
770
|
+
nonzero_shift: float = 0.0,
|
|
764
771
|
) -> xr.DataArray:
|
|
765
772
|
"""Generates a sample N-dimensional `media_spend` DataArray.
|
|
766
773
|
|
|
@@ -781,6 +788,7 @@ def random_media_spend_nd_da(
|
|
|
781
788
|
integer_geos: If True, the geos will be integers.
|
|
782
789
|
explicit_media_channel_names: If given, ignore `n_media_channels` and use
|
|
783
790
|
this as is.
|
|
791
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
784
792
|
|
|
785
793
|
Returns:
|
|
786
794
|
A DataArray containing the generated `media_spend` data with the given
|
|
@@ -818,7 +826,7 @@ def random_media_spend_nd_da(
|
|
|
818
826
|
f'Shape {dims} not supported by the random_media_spend_nd_da function.'
|
|
819
827
|
)
|
|
820
828
|
|
|
821
|
-
media_spend = abs(np.random.normal(1, 1, size=shape))
|
|
829
|
+
media_spend = abs(np.random.normal(1, 1, size=shape)) + nonzero_shift
|
|
822
830
|
|
|
823
831
|
return xr.DataArray(
|
|
824
832
|
media_spend,
|
|
@@ -1007,8 +1015,27 @@ def random_reach_da(
|
|
|
1007
1015
|
channel_variable_name: str = 'rf_channel',
|
|
1008
1016
|
channel_prefix: str = 'rf_ch_',
|
|
1009
1017
|
integer_geos: bool = False,
|
|
1018
|
+
nonzero_shift: float = 0.0,
|
|
1010
1019
|
) -> xr.DataArray:
|
|
1011
|
-
"""Generates a sample `reach` DataArray.
|
|
1020
|
+
"""Generates a sample `reach` DataArray.
|
|
1021
|
+
|
|
1022
|
+
Args:
|
|
1023
|
+
n_geos: Number of geos
|
|
1024
|
+
n_times: Number of time periods
|
|
1025
|
+
n_media_times: Number of media time periods
|
|
1026
|
+
n_rf_channels: Number of reach and frequency channels
|
|
1027
|
+
seed: Random seed used by `np.random.seed()`
|
|
1028
|
+
explicit_rf_channel_names: If given, ignore `n_rf_channels` and use this as
|
|
1029
|
+
is
|
|
1030
|
+
array_name: The name of the array to be created
|
|
1031
|
+
channel_variable_name: The name of the channel variable
|
|
1032
|
+
channel_prefix: The prefix of the channel names
|
|
1033
|
+
integer_geos: If True, the geos will be integers.
|
|
1034
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
1035
|
+
|
|
1036
|
+
Returns:
|
|
1037
|
+
A DataArray containing random data.
|
|
1038
|
+
"""
|
|
1012
1039
|
|
|
1013
1040
|
np.random.seed(seed)
|
|
1014
1041
|
|
|
@@ -1016,12 +1043,15 @@ def random_reach_da(
|
|
|
1016
1043
|
if n_times < n_media_times:
|
|
1017
1044
|
start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
|
|
1018
1045
|
|
|
1019
|
-
reach =
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1046
|
+
reach = (
|
|
1047
|
+
np.round(
|
|
1048
|
+
abs(
|
|
1049
|
+
np.random.normal(
|
|
1050
|
+
3000, 100, size=(n_geos, n_media_times, n_rf_channels)
|
|
1051
|
+
)
|
|
1023
1052
|
)
|
|
1024
1053
|
)
|
|
1054
|
+
+ nonzero_shift
|
|
1025
1055
|
)
|
|
1026
1056
|
|
|
1027
1057
|
channels = (
|
|
@@ -1051,6 +1081,7 @@ def random_organic_reach_da(
|
|
|
1051
1081
|
seed: int = 0,
|
|
1052
1082
|
explicit_organic_rf_channel_names: Sequence[str] | None = None,
|
|
1053
1083
|
integer_geos: bool = False,
|
|
1084
|
+
nonzero_shift: float = 0.0,
|
|
1054
1085
|
) -> xr.DataArray:
|
|
1055
1086
|
"""Generates a sample `organic_reach` DataArray."""
|
|
1056
1087
|
return random_reach_da(
|
|
@@ -1064,6 +1095,7 @@ def random_organic_reach_da(
|
|
|
1064
1095
|
channel_variable_name='organic_rf_channel',
|
|
1065
1096
|
channel_prefix='organic_rf_ch_',
|
|
1066
1097
|
integer_geos=integer_geos,
|
|
1098
|
+
nonzero_shift=nonzero_shift,
|
|
1067
1099
|
)
|
|
1068
1100
|
|
|
1069
1101
|
|
|
@@ -1078,8 +1110,27 @@ def random_frequency_da(
|
|
|
1078
1110
|
channel_variable_name: str = 'rf_channel',
|
|
1079
1111
|
channel_prefix: str = 'rf_ch_',
|
|
1080
1112
|
integer_geos: bool = False,
|
|
1113
|
+
nonzero_shift: float = 0.0,
|
|
1081
1114
|
) -> xr.DataArray:
|
|
1082
|
-
"""Generates a sample `frequency` DataArray.
|
|
1115
|
+
"""Generates a sample `frequency` DataArray.
|
|
1116
|
+
|
|
1117
|
+
Args:
|
|
1118
|
+
n_geos: Number of geos
|
|
1119
|
+
n_times: Number of time periods
|
|
1120
|
+
n_media_times: Number of media time periods
|
|
1121
|
+
n_rf_channels: Number of reach and frequency channels
|
|
1122
|
+
seed: Random seed used by `np.random.seed()`
|
|
1123
|
+
explicit_rf_channel_names: If given, ignore `n_rf_channels` and use this as
|
|
1124
|
+
is
|
|
1125
|
+
array_name: The name of the array to be created
|
|
1126
|
+
channel_variable_name: The name of the channel variable
|
|
1127
|
+
channel_prefix: The prefix of the channel names
|
|
1128
|
+
integer_geos: If True, the geos will be integers.
|
|
1129
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
1130
|
+
|
|
1131
|
+
Returns:
|
|
1132
|
+
A DataArray containing random data.
|
|
1133
|
+
"""
|
|
1083
1134
|
|
|
1084
1135
|
np.random.seed(seed)
|
|
1085
1136
|
|
|
@@ -1087,8 +1138,9 @@ def random_frequency_da(
|
|
|
1087
1138
|
if n_times < n_media_times:
|
|
1088
1139
|
start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
|
|
1089
1140
|
|
|
1090
|
-
frequency =
|
|
1091
|
-
np.random.normal(3, 5, size=(n_geos, n_media_times, n_rf_channels))
|
|
1141
|
+
frequency = (
|
|
1142
|
+
abs(np.random.normal(3, 5, size=(n_geos, n_media_times, n_rf_channels)))
|
|
1143
|
+
+ nonzero_shift
|
|
1092
1144
|
)
|
|
1093
1145
|
|
|
1094
1146
|
channels = (
|
|
@@ -1119,6 +1171,7 @@ def random_organic_frequency_da(
|
|
|
1119
1171
|
seed: int = 0,
|
|
1120
1172
|
explicit_organic_rf_channel_names: Sequence[str] | None = None,
|
|
1121
1173
|
integer_geos: bool = False,
|
|
1174
|
+
nonzero_shift: float = 0.0,
|
|
1122
1175
|
) -> xr.DataArray:
|
|
1123
1176
|
"""Generates a sample `organic_frequency` DataArray."""
|
|
1124
1177
|
return random_frequency_da(
|
|
@@ -1132,6 +1185,7 @@ def random_organic_frequency_da(
|
|
|
1132
1185
|
channel_variable_name='organic_rf_channel',
|
|
1133
1186
|
channel_prefix='organic_rf_ch_',
|
|
1134
1187
|
integer_geos=integer_geos,
|
|
1188
|
+
nonzero_shift=nonzero_shift,
|
|
1135
1189
|
)
|
|
1136
1190
|
|
|
1137
1191
|
|
|
@@ -1141,6 +1195,7 @@ def random_rf_spend_nd_da(
|
|
|
1141
1195
|
n_rf_channels: int | None = None,
|
|
1142
1196
|
seed=0,
|
|
1143
1197
|
integer_geos: bool = False,
|
|
1198
|
+
nonzero_shift: float = 0.0,
|
|
1144
1199
|
) -> xr.DataArray:
|
|
1145
1200
|
"""Generates a sample N-dimensional `rf_spend` DataArray.
|
|
1146
1201
|
|
|
@@ -1157,6 +1212,7 @@ def random_rf_spend_nd_da(
|
|
|
1157
1212
|
n_rf_channels: Number of channels in the created `rf_spend` array.
|
|
1158
1213
|
seed: Random seed used by `np.random.seed()`.
|
|
1159
1214
|
integer_geos: If True, the geos will be integers.
|
|
1215
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
1160
1216
|
|
|
1161
1217
|
Returns:
|
|
1162
1218
|
A DataArray containing the generated `rf_spend` data with the given
|
|
@@ -1187,7 +1243,7 @@ def random_rf_spend_nd_da(
|
|
|
1187
1243
|
f'Shape {dims} not supported by the random_rf_spend_nd_da function.'
|
|
1188
1244
|
)
|
|
1189
1245
|
|
|
1190
|
-
rf_spend = abs(np.random.normal(1, 1, size=shape))
|
|
1246
|
+
rf_spend = abs(np.random.normal(1, 1, size=shape)) + nonzero_shift
|
|
1191
1247
|
|
|
1192
1248
|
return xr.DataArray(
|
|
1193
1249
|
rf_spend,
|
|
@@ -1206,6 +1262,7 @@ def random_non_media_treatments_da(
|
|
|
1206
1262
|
date_format: str = c.DATE_FORMAT,
|
|
1207
1263
|
explicit_time_index: Sequence[str] | None = None,
|
|
1208
1264
|
integer_geos: bool = False,
|
|
1265
|
+
nonzero_shift: float = 0.0,
|
|
1209
1266
|
) -> xr.DataArray:
|
|
1210
1267
|
"""Generates a sample `non_media_treatments` DataArray.
|
|
1211
1268
|
|
|
@@ -1218,6 +1275,7 @@ def random_non_media_treatments_da(
|
|
|
1218
1275
|
date_format: The date format to use for time coordinate labels
|
|
1219
1276
|
explicit_time_index: If given, ignore `date_format` and use this as is
|
|
1220
1277
|
integer_geos: If True, the geos will be integers.
|
|
1278
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
1221
1279
|
|
|
1222
1280
|
Returns:
|
|
1223
1281
|
A DataArray containing random non-media variable.
|
|
@@ -1232,6 +1290,8 @@ def random_non_media_treatments_da(
|
|
|
1232
1290
|
non_media_channel,
|
|
1233
1291
|
size=(n_geos, n_times, n_non_media_channels),
|
|
1234
1292
|
)
|
|
1293
|
+
non_media_treatments = non_media_treatments + nonzero_shift
|
|
1294
|
+
|
|
1235
1295
|
return xr.DataArray(
|
|
1236
1296
|
non_media_treatments,
|
|
1237
1297
|
dims=['geo', 'time', 'non_media_channel'],
|
|
@@ -1268,6 +1328,7 @@ def random_dataset(
|
|
|
1268
1328
|
remove_media_time: bool = False,
|
|
1269
1329
|
integer_geos: bool = False,
|
|
1270
1330
|
kpi_data_pattern: str = '',
|
|
1331
|
+
nonzero_shift: float = 0.0,
|
|
1271
1332
|
) -> xr.Dataset:
|
|
1272
1333
|
"""Generates a random dataset."""
|
|
1273
1334
|
if n_media_channels:
|
|
@@ -1280,6 +1341,7 @@ def random_dataset(
|
|
|
1280
1341
|
integer_geos=integer_geos,
|
|
1281
1342
|
explicit_media_channel_names=explicit_media_channel_names,
|
|
1282
1343
|
media_value_scales=media_value_scales,
|
|
1344
|
+
nonzero_shift=nonzero_shift,
|
|
1283
1345
|
)
|
|
1284
1346
|
media_spend = random_media_spend_nd_da(
|
|
1285
1347
|
n_geos=n_geos,
|
|
@@ -1288,6 +1350,7 @@ def random_dataset(
|
|
|
1288
1350
|
explicit_media_channel_names=explicit_media_channel_names,
|
|
1289
1351
|
seed=seed,
|
|
1290
1352
|
integer_geos=integer_geos,
|
|
1353
|
+
nonzero_shift=nonzero_shift,
|
|
1291
1354
|
)
|
|
1292
1355
|
else:
|
|
1293
1356
|
media = None
|
|
@@ -1301,6 +1364,7 @@ def random_dataset(
|
|
|
1301
1364
|
n_rf_channels=n_rf_channels,
|
|
1302
1365
|
seed=seed,
|
|
1303
1366
|
integer_geos=integer_geos,
|
|
1367
|
+
nonzero_shift=nonzero_shift,
|
|
1304
1368
|
)
|
|
1305
1369
|
frequency = random_frequency_da(
|
|
1306
1370
|
n_geos=n_geos,
|
|
@@ -1309,6 +1373,7 @@ def random_dataset(
|
|
|
1309
1373
|
n_rf_channels=n_rf_channels,
|
|
1310
1374
|
seed=seed,
|
|
1311
1375
|
integer_geos=integer_geos,
|
|
1376
|
+
nonzero_shift=nonzero_shift,
|
|
1312
1377
|
)
|
|
1313
1378
|
rf_spend = random_rf_spend_nd_da(
|
|
1314
1379
|
n_geos=n_geos,
|
|
@@ -1316,6 +1381,7 @@ def random_dataset(
|
|
|
1316
1381
|
n_rf_channels=n_rf_channels,
|
|
1317
1382
|
seed=seed,
|
|
1318
1383
|
integer_geos=integer_geos,
|
|
1384
|
+
nonzero_shift=nonzero_shift,
|
|
1319
1385
|
)
|
|
1320
1386
|
else:
|
|
1321
1387
|
reach = None
|
|
@@ -1352,6 +1418,7 @@ def random_dataset(
|
|
|
1352
1418
|
n_non_media_channels=n_non_media_channels,
|
|
1353
1419
|
seed=seed,
|
|
1354
1420
|
integer_geos=integer_geos,
|
|
1421
|
+
nonzero_shift=nonzero_shift,
|
|
1355
1422
|
)
|
|
1356
1423
|
else:
|
|
1357
1424
|
non_media_treatments = None
|
|
@@ -1364,6 +1431,7 @@ def random_dataset(
|
|
|
1364
1431
|
n_organic_media_channels=n_organic_media_channels,
|
|
1365
1432
|
seed=seed,
|
|
1366
1433
|
integer_geos=integer_geos,
|
|
1434
|
+
nonzero_shift=nonzero_shift,
|
|
1367
1435
|
)
|
|
1368
1436
|
else:
|
|
1369
1437
|
organic_media = None
|
|
@@ -1376,6 +1444,7 @@ def random_dataset(
|
|
|
1376
1444
|
n_organic_rf_channels=n_organic_rf_channels,
|
|
1377
1445
|
seed=seed,
|
|
1378
1446
|
integer_geos=integer_geos,
|
|
1447
|
+
nonzero_shift=nonzero_shift,
|
|
1379
1448
|
)
|
|
1380
1449
|
organic_frequency = random_organic_frequency_da(
|
|
1381
1450
|
n_geos=n_geos,
|
|
@@ -1384,6 +1453,7 @@ def random_dataset(
|
|
|
1384
1453
|
n_organic_rf_channels=n_organic_rf_channels,
|
|
1385
1454
|
seed=seed,
|
|
1386
1455
|
integer_geos=integer_geos,
|
|
1456
|
+
nonzero_shift=nonzero_shift,
|
|
1387
1457
|
)
|
|
1388
1458
|
else:
|
|
1389
1459
|
organic_reach = None
|
|
@@ -1794,6 +1864,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1794
1864
|
n_organic_media_channels: int | None = None,
|
|
1795
1865
|
n_organic_rf_channels: int | None = None,
|
|
1796
1866
|
seed: int = 0,
|
|
1867
|
+
nonzero_shift: float = 0.0,
|
|
1797
1868
|
) -> input_data.InputData:
|
|
1798
1869
|
"""Generates sample InputData for `non_revenue` KPI w/ revenue_per_kpi."""
|
|
1799
1870
|
dataset = random_dataset(
|
|
@@ -1807,6 +1878,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1807
1878
|
n_organic_media_channels=n_organic_media_channels,
|
|
1808
1879
|
n_organic_rf_channels=n_organic_rf_channels,
|
|
1809
1880
|
seed=seed,
|
|
1881
|
+
nonzero_shift=nonzero_shift,
|
|
1810
1882
|
)
|
|
1811
1883
|
return input_data.InputData(
|
|
1812
1884
|
kpi=dataset.kpi,
|
meridian/model/__init__.py
CHANGED
|
@@ -15,7 +15,9 @@
|
|
|
15
15
|
"""The Meridian API module that models the data."""
|
|
16
16
|
|
|
17
17
|
from meridian.model import adstock_hill
|
|
18
|
+
from meridian.model import context
|
|
18
19
|
from meridian.model import eda
|
|
20
|
+
from meridian.model import equations
|
|
19
21
|
from meridian.model import knots
|
|
20
22
|
from meridian.model import media
|
|
21
23
|
from meridian.model import model
|