google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
meridian/data/test_utils.py
CHANGED
|
@@ -642,6 +642,7 @@ def random_media_da(
|
|
|
642
642
|
channel_variable_name: str = 'media_channel',
|
|
643
643
|
channel_prefix: str = 'ch_',
|
|
644
644
|
integer_geos: bool = False,
|
|
645
|
+
nonzero_shift: float = 0.0,
|
|
645
646
|
) -> xr.DataArray:
|
|
646
647
|
"""Generates a sample `media` DataArray.
|
|
647
648
|
|
|
@@ -662,6 +663,7 @@ def random_media_da(
|
|
|
662
663
|
channel_variable_name: The name of the channel variable
|
|
663
664
|
channel_prefix: The prefix of the channel names
|
|
664
665
|
integer_geos: If True, the geos will be integers.
|
|
666
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
665
667
|
|
|
666
668
|
Returns:
|
|
667
669
|
A DataArray containing random data.
|
|
@@ -695,6 +697,8 @@ def random_media_da(
|
|
|
695
697
|
)
|
|
696
698
|
)
|
|
697
699
|
|
|
700
|
+
media = media + nonzero_shift
|
|
701
|
+
|
|
698
702
|
if explicit_geo_names is None:
|
|
699
703
|
geos = sample_geos(n_geos, integer_geos)
|
|
700
704
|
else:
|
|
@@ -736,6 +740,7 @@ def random_organic_media_da(
|
|
|
736
740
|
explicit_time_index: Sequence[str] | None = None,
|
|
737
741
|
explicit_media_channel_names: Sequence[str] | None = None,
|
|
738
742
|
integer_geos: bool = False,
|
|
743
|
+
nonzero_shift: float = 0.0,
|
|
739
744
|
) -> xr.DataArray:
|
|
740
745
|
"""Generates a sample `organic_media` DataArray."""
|
|
741
746
|
return random_media_da(
|
|
@@ -751,6 +756,7 @@ def random_organic_media_da(
|
|
|
751
756
|
channel_variable_name='organic_media_channel',
|
|
752
757
|
channel_prefix='organic_media_',
|
|
753
758
|
integer_geos=integer_geos,
|
|
759
|
+
nonzero_shift=nonzero_shift,
|
|
754
760
|
)
|
|
755
761
|
|
|
756
762
|
|
|
@@ -761,6 +767,7 @@ def random_media_spend_nd_da(
|
|
|
761
767
|
seed=0,
|
|
762
768
|
integer_geos: bool = False,
|
|
763
769
|
explicit_media_channel_names: Sequence[str] | None = None,
|
|
770
|
+
nonzero_shift: float = 0.0,
|
|
764
771
|
) -> xr.DataArray:
|
|
765
772
|
"""Generates a sample N-dimensional `media_spend` DataArray.
|
|
766
773
|
|
|
@@ -781,6 +788,7 @@ def random_media_spend_nd_da(
|
|
|
781
788
|
integer_geos: If True, the geos will be integers.
|
|
782
789
|
explicit_media_channel_names: If given, ignore `n_media_channels` and use
|
|
783
790
|
this as is.
|
|
791
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
784
792
|
|
|
785
793
|
Returns:
|
|
786
794
|
A DataArray containing the generated `media_spend` data with the given
|
|
@@ -818,7 +826,7 @@ def random_media_spend_nd_da(
|
|
|
818
826
|
f'Shape {dims} not supported by the random_media_spend_nd_da function.'
|
|
819
827
|
)
|
|
820
828
|
|
|
821
|
-
media_spend = abs(np.random.normal(1, 1, size=shape))
|
|
829
|
+
media_spend = abs(np.random.normal(1, 1, size=shape)) + nonzero_shift
|
|
822
830
|
|
|
823
831
|
return xr.DataArray(
|
|
824
832
|
media_spend,
|
|
@@ -1007,8 +1015,27 @@ def random_reach_da(
|
|
|
1007
1015
|
channel_variable_name: str = 'rf_channel',
|
|
1008
1016
|
channel_prefix: str = 'rf_ch_',
|
|
1009
1017
|
integer_geos: bool = False,
|
|
1018
|
+
nonzero_shift: float = 0.0,
|
|
1010
1019
|
) -> xr.DataArray:
|
|
1011
|
-
"""Generates a sample `reach` DataArray.
|
|
1020
|
+
"""Generates a sample `reach` DataArray.
|
|
1021
|
+
|
|
1022
|
+
Args:
|
|
1023
|
+
n_geos: Number of geos
|
|
1024
|
+
n_times: Number of time periods
|
|
1025
|
+
n_media_times: Number of media time periods
|
|
1026
|
+
n_rf_channels: Number of reach and frequency channels
|
|
1027
|
+
seed: Random seed used by `np.random.seed()`
|
|
1028
|
+
explicit_rf_channel_names: If given, ignore `n_rf_channels` and use this as
|
|
1029
|
+
is
|
|
1030
|
+
array_name: The name of the array to be created
|
|
1031
|
+
channel_variable_name: The name of the channel variable
|
|
1032
|
+
channel_prefix: The prefix of the channel names
|
|
1033
|
+
integer_geos: If True, the geos will be integers.
|
|
1034
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
1035
|
+
|
|
1036
|
+
Returns:
|
|
1037
|
+
A DataArray containing random data.
|
|
1038
|
+
"""
|
|
1012
1039
|
|
|
1013
1040
|
np.random.seed(seed)
|
|
1014
1041
|
|
|
@@ -1016,12 +1043,15 @@ def random_reach_da(
|
|
|
1016
1043
|
if n_times < n_media_times:
|
|
1017
1044
|
start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
|
|
1018
1045
|
|
|
1019
|
-
reach =
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1046
|
+
reach = (
|
|
1047
|
+
np.round(
|
|
1048
|
+
abs(
|
|
1049
|
+
np.random.normal(
|
|
1050
|
+
3000, 100, size=(n_geos, n_media_times, n_rf_channels)
|
|
1051
|
+
)
|
|
1023
1052
|
)
|
|
1024
1053
|
)
|
|
1054
|
+
+ nonzero_shift
|
|
1025
1055
|
)
|
|
1026
1056
|
|
|
1027
1057
|
channels = (
|
|
@@ -1051,6 +1081,7 @@ def random_organic_reach_da(
|
|
|
1051
1081
|
seed: int = 0,
|
|
1052
1082
|
explicit_organic_rf_channel_names: Sequence[str] | None = None,
|
|
1053
1083
|
integer_geos: bool = False,
|
|
1084
|
+
nonzero_shift: float = 0.0,
|
|
1054
1085
|
) -> xr.DataArray:
|
|
1055
1086
|
"""Generates a sample `organic_reach` DataArray."""
|
|
1056
1087
|
return random_reach_da(
|
|
@@ -1064,6 +1095,7 @@ def random_organic_reach_da(
|
|
|
1064
1095
|
channel_variable_name='organic_rf_channel',
|
|
1065
1096
|
channel_prefix='organic_rf_ch_',
|
|
1066
1097
|
integer_geos=integer_geos,
|
|
1098
|
+
nonzero_shift=nonzero_shift,
|
|
1067
1099
|
)
|
|
1068
1100
|
|
|
1069
1101
|
|
|
@@ -1078,8 +1110,27 @@ def random_frequency_da(
|
|
|
1078
1110
|
channel_variable_name: str = 'rf_channel',
|
|
1079
1111
|
channel_prefix: str = 'rf_ch_',
|
|
1080
1112
|
integer_geos: bool = False,
|
|
1113
|
+
nonzero_shift: float = 0.0,
|
|
1081
1114
|
) -> xr.DataArray:
|
|
1082
|
-
"""Generates a sample `frequency` DataArray.
|
|
1115
|
+
"""Generates a sample `frequency` DataArray.
|
|
1116
|
+
|
|
1117
|
+
Args:
|
|
1118
|
+
n_geos: Number of geos
|
|
1119
|
+
n_times: Number of time periods
|
|
1120
|
+
n_media_times: Number of media time periods
|
|
1121
|
+
n_rf_channels: Number of reach and frequency channels
|
|
1122
|
+
seed: Random seed used by `np.random.seed()`
|
|
1123
|
+
explicit_rf_channel_names: If given, ignore `n_rf_channels` and use this as
|
|
1124
|
+
is
|
|
1125
|
+
array_name: The name of the array to be created
|
|
1126
|
+
channel_variable_name: The name of the channel variable
|
|
1127
|
+
channel_prefix: The prefix of the channel names
|
|
1128
|
+
integer_geos: If True, the geos will be integers.
|
|
1129
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
1130
|
+
|
|
1131
|
+
Returns:
|
|
1132
|
+
A DataArray containing random data.
|
|
1133
|
+
"""
|
|
1083
1134
|
|
|
1084
1135
|
np.random.seed(seed)
|
|
1085
1136
|
|
|
@@ -1087,8 +1138,9 @@ def random_frequency_da(
|
|
|
1087
1138
|
if n_times < n_media_times:
|
|
1088
1139
|
start_date -= datetime.timedelta(weeks=(n_media_times - n_times))
|
|
1089
1140
|
|
|
1090
|
-
frequency =
|
|
1091
|
-
np.random.normal(3, 5, size=(n_geos, n_media_times, n_rf_channels))
|
|
1141
|
+
frequency = (
|
|
1142
|
+
abs(np.random.normal(3, 5, size=(n_geos, n_media_times, n_rf_channels)))
|
|
1143
|
+
+ nonzero_shift
|
|
1092
1144
|
)
|
|
1093
1145
|
|
|
1094
1146
|
channels = (
|
|
@@ -1119,6 +1171,7 @@ def random_organic_frequency_da(
|
|
|
1119
1171
|
seed: int = 0,
|
|
1120
1172
|
explicit_organic_rf_channel_names: Sequence[str] | None = None,
|
|
1121
1173
|
integer_geos: bool = False,
|
|
1174
|
+
nonzero_shift: float = 0.0,
|
|
1122
1175
|
) -> xr.DataArray:
|
|
1123
1176
|
"""Generates a sample `organic_frequency` DataArray."""
|
|
1124
1177
|
return random_frequency_da(
|
|
@@ -1132,6 +1185,7 @@ def random_organic_frequency_da(
|
|
|
1132
1185
|
channel_variable_name='organic_rf_channel',
|
|
1133
1186
|
channel_prefix='organic_rf_ch_',
|
|
1134
1187
|
integer_geos=integer_geos,
|
|
1188
|
+
nonzero_shift=nonzero_shift,
|
|
1135
1189
|
)
|
|
1136
1190
|
|
|
1137
1191
|
|
|
@@ -1141,6 +1195,7 @@ def random_rf_spend_nd_da(
|
|
|
1141
1195
|
n_rf_channels: int | None = None,
|
|
1142
1196
|
seed=0,
|
|
1143
1197
|
integer_geos: bool = False,
|
|
1198
|
+
nonzero_shift: float = 0.0,
|
|
1144
1199
|
) -> xr.DataArray:
|
|
1145
1200
|
"""Generates a sample N-dimensional `rf_spend` DataArray.
|
|
1146
1201
|
|
|
@@ -1157,6 +1212,7 @@ def random_rf_spend_nd_da(
|
|
|
1157
1212
|
n_rf_channels: Number of channels in the created `rf_spend` array.
|
|
1158
1213
|
seed: Random seed used by `np.random.seed()`.
|
|
1159
1214
|
integer_geos: If True, the geos will be integers.
|
|
1215
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
1160
1216
|
|
|
1161
1217
|
Returns:
|
|
1162
1218
|
A DataArray containing the generated `rf_spend` data with the given
|
|
@@ -1187,7 +1243,7 @@ def random_rf_spend_nd_da(
|
|
|
1187
1243
|
f'Shape {dims} not supported by the random_rf_spend_nd_da function.'
|
|
1188
1244
|
)
|
|
1189
1245
|
|
|
1190
|
-
rf_spend = abs(np.random.normal(1, 1, size=shape))
|
|
1246
|
+
rf_spend = abs(np.random.normal(1, 1, size=shape)) + nonzero_shift
|
|
1191
1247
|
|
|
1192
1248
|
return xr.DataArray(
|
|
1193
1249
|
rf_spend,
|
|
@@ -1206,6 +1262,7 @@ def random_non_media_treatments_da(
|
|
|
1206
1262
|
date_format: str = c.DATE_FORMAT,
|
|
1207
1263
|
explicit_time_index: Sequence[str] | None = None,
|
|
1208
1264
|
integer_geos: bool = False,
|
|
1265
|
+
nonzero_shift: float = 0.0,
|
|
1209
1266
|
) -> xr.DataArray:
|
|
1210
1267
|
"""Generates a sample `non_media_treatments` DataArray.
|
|
1211
1268
|
|
|
@@ -1218,6 +1275,7 @@ def random_non_media_treatments_da(
|
|
|
1218
1275
|
date_format: The date format to use for time coordinate labels
|
|
1219
1276
|
explicit_time_index: If given, ignore `date_format` and use this as is
|
|
1220
1277
|
integer_geos: If True, the geos will be integers.
|
|
1278
|
+
nonzero_shift: A scalar value to add to the generated data.
|
|
1221
1279
|
|
|
1222
1280
|
Returns:
|
|
1223
1281
|
A DataArray containing random non-media variable.
|
|
@@ -1232,6 +1290,8 @@ def random_non_media_treatments_da(
|
|
|
1232
1290
|
non_media_channel,
|
|
1233
1291
|
size=(n_geos, n_times, n_non_media_channels),
|
|
1234
1292
|
)
|
|
1293
|
+
non_media_treatments = non_media_treatments + nonzero_shift
|
|
1294
|
+
|
|
1235
1295
|
return xr.DataArray(
|
|
1236
1296
|
non_media_treatments,
|
|
1237
1297
|
dims=['geo', 'time', 'non_media_channel'],
|
|
@@ -1268,6 +1328,7 @@ def random_dataset(
|
|
|
1268
1328
|
remove_media_time: bool = False,
|
|
1269
1329
|
integer_geos: bool = False,
|
|
1270
1330
|
kpi_data_pattern: str = '',
|
|
1331
|
+
nonzero_shift: float = 0.0,
|
|
1271
1332
|
) -> xr.Dataset:
|
|
1272
1333
|
"""Generates a random dataset."""
|
|
1273
1334
|
if n_media_channels:
|
|
@@ -1280,6 +1341,7 @@ def random_dataset(
|
|
|
1280
1341
|
integer_geos=integer_geos,
|
|
1281
1342
|
explicit_media_channel_names=explicit_media_channel_names,
|
|
1282
1343
|
media_value_scales=media_value_scales,
|
|
1344
|
+
nonzero_shift=nonzero_shift,
|
|
1283
1345
|
)
|
|
1284
1346
|
media_spend = random_media_spend_nd_da(
|
|
1285
1347
|
n_geos=n_geos,
|
|
@@ -1288,6 +1350,7 @@ def random_dataset(
|
|
|
1288
1350
|
explicit_media_channel_names=explicit_media_channel_names,
|
|
1289
1351
|
seed=seed,
|
|
1290
1352
|
integer_geos=integer_geos,
|
|
1353
|
+
nonzero_shift=nonzero_shift,
|
|
1291
1354
|
)
|
|
1292
1355
|
else:
|
|
1293
1356
|
media = None
|
|
@@ -1301,6 +1364,7 @@ def random_dataset(
|
|
|
1301
1364
|
n_rf_channels=n_rf_channels,
|
|
1302
1365
|
seed=seed,
|
|
1303
1366
|
integer_geos=integer_geos,
|
|
1367
|
+
nonzero_shift=nonzero_shift,
|
|
1304
1368
|
)
|
|
1305
1369
|
frequency = random_frequency_da(
|
|
1306
1370
|
n_geos=n_geos,
|
|
@@ -1309,6 +1373,7 @@ def random_dataset(
|
|
|
1309
1373
|
n_rf_channels=n_rf_channels,
|
|
1310
1374
|
seed=seed,
|
|
1311
1375
|
integer_geos=integer_geos,
|
|
1376
|
+
nonzero_shift=nonzero_shift,
|
|
1312
1377
|
)
|
|
1313
1378
|
rf_spend = random_rf_spend_nd_da(
|
|
1314
1379
|
n_geos=n_geos,
|
|
@@ -1316,6 +1381,7 @@ def random_dataset(
|
|
|
1316
1381
|
n_rf_channels=n_rf_channels,
|
|
1317
1382
|
seed=seed,
|
|
1318
1383
|
integer_geos=integer_geos,
|
|
1384
|
+
nonzero_shift=nonzero_shift,
|
|
1319
1385
|
)
|
|
1320
1386
|
else:
|
|
1321
1387
|
reach = None
|
|
@@ -1352,6 +1418,7 @@ def random_dataset(
|
|
|
1352
1418
|
n_non_media_channels=n_non_media_channels,
|
|
1353
1419
|
seed=seed,
|
|
1354
1420
|
integer_geos=integer_geos,
|
|
1421
|
+
nonzero_shift=nonzero_shift,
|
|
1355
1422
|
)
|
|
1356
1423
|
else:
|
|
1357
1424
|
non_media_treatments = None
|
|
@@ -1364,6 +1431,7 @@ def random_dataset(
|
|
|
1364
1431
|
n_organic_media_channels=n_organic_media_channels,
|
|
1365
1432
|
seed=seed,
|
|
1366
1433
|
integer_geos=integer_geos,
|
|
1434
|
+
nonzero_shift=nonzero_shift,
|
|
1367
1435
|
)
|
|
1368
1436
|
else:
|
|
1369
1437
|
organic_media = None
|
|
@@ -1376,6 +1444,7 @@ def random_dataset(
|
|
|
1376
1444
|
n_organic_rf_channels=n_organic_rf_channels,
|
|
1377
1445
|
seed=seed,
|
|
1378
1446
|
integer_geos=integer_geos,
|
|
1447
|
+
nonzero_shift=nonzero_shift,
|
|
1379
1448
|
)
|
|
1380
1449
|
organic_frequency = random_organic_frequency_da(
|
|
1381
1450
|
n_geos=n_geos,
|
|
@@ -1384,6 +1453,7 @@ def random_dataset(
|
|
|
1384
1453
|
n_organic_rf_channels=n_organic_rf_channels,
|
|
1385
1454
|
seed=seed,
|
|
1386
1455
|
integer_geos=integer_geos,
|
|
1456
|
+
nonzero_shift=nonzero_shift,
|
|
1387
1457
|
)
|
|
1388
1458
|
else:
|
|
1389
1459
|
organic_reach = None
|
|
@@ -1406,53 +1476,37 @@ def random_dataset(
|
|
|
1406
1476
|
constant_value=constant_population_value,
|
|
1407
1477
|
)
|
|
1408
1478
|
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
1479
|
+
to_merge = [kpi, population]
|
|
1480
|
+
if controls is not None:
|
|
1481
|
+
to_merge.append(controls)
|
|
1412
1482
|
if revenue_per_kpi is not None:
|
|
1413
|
-
|
|
1483
|
+
to_merge.append(revenue_per_kpi)
|
|
1414
1484
|
if media is not None:
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
)
|
|
1418
|
-
|
|
1419
|
-
dataset = xr.combine_by_coords([dataset, media_renamed, media_spend])
|
|
1485
|
+
if remove_media_time:
|
|
1486
|
+
media = media.rename({'media_time': 'time'})
|
|
1487
|
+
to_merge.append(media)
|
|
1488
|
+
to_merge.append(media_spend)
|
|
1420
1489
|
if reach is not None:
|
|
1421
|
-
|
|
1422
|
-
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
|
|
1427
|
-
else frequency
|
|
1428
|
-
)
|
|
1429
|
-
dataset = xr.combine_by_coords(
|
|
1430
|
-
[dataset, reach_renamed, frequency_renamed, rf_spend]
|
|
1431
|
-
)
|
|
1490
|
+
if remove_media_time:
|
|
1491
|
+
reach = reach.rename({'media_time': 'time'})
|
|
1492
|
+
frequency = frequency.rename({'media_time': 'time'})
|
|
1493
|
+
to_merge.append(reach)
|
|
1494
|
+
to_merge.append(frequency)
|
|
1495
|
+
to_merge.append(rf_spend)
|
|
1432
1496
|
if non_media_treatments is not None:
|
|
1433
|
-
|
|
1497
|
+
to_merge.append(non_media_treatments)
|
|
1434
1498
|
if organic_media is not None:
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
else organic_media
|
|
1439
|
-
)
|
|
1440
|
-
dataset = xr.combine_by_coords([dataset, organic_media_renamed])
|
|
1499
|
+
if remove_media_time:
|
|
1500
|
+
organic_media = organic_media.rename({'media_time': 'time'})
|
|
1501
|
+
to_merge.append(organic_media)
|
|
1441
1502
|
if organic_reach is not None:
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
)
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
if remove_media_time
|
|
1450
|
-
else organic_frequency
|
|
1451
|
-
)
|
|
1452
|
-
dataset = xr.combine_by_coords(
|
|
1453
|
-
[dataset, organic_reach_renamed, organic_frequency_renamed]
|
|
1454
|
-
)
|
|
1455
|
-
return dataset
|
|
1503
|
+
if remove_media_time:
|
|
1504
|
+
organic_reach = organic_reach.rename({'media_time': 'time'})
|
|
1505
|
+
organic_frequency = organic_frequency.rename({'media_time': 'time'})
|
|
1506
|
+
to_merge.append(organic_reach)
|
|
1507
|
+
to_merge.append(organic_frequency)
|
|
1508
|
+
|
|
1509
|
+
return xr.merge(to_merge, join='outer', compat='no_conflicts')
|
|
1456
1510
|
|
|
1457
1511
|
|
|
1458
1512
|
def dataset_to_dataframe(
|
|
@@ -1794,6 +1848,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1794
1848
|
n_organic_media_channels: int | None = None,
|
|
1795
1849
|
n_organic_rf_channels: int | None = None,
|
|
1796
1850
|
seed: int = 0,
|
|
1851
|
+
nonzero_shift: float = 0.0,
|
|
1797
1852
|
) -> input_data.InputData:
|
|
1798
1853
|
"""Generates sample InputData for `non_revenue` KPI w/ revenue_per_kpi."""
|
|
1799
1854
|
dataset = random_dataset(
|
|
@@ -1807,6 +1862,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1807
1862
|
n_organic_media_channels=n_organic_media_channels,
|
|
1808
1863
|
n_organic_rf_channels=n_organic_rf_channels,
|
|
1809
1864
|
seed=seed,
|
|
1865
|
+
nonzero_shift=nonzero_shift,
|
|
1810
1866
|
)
|
|
1811
1867
|
return input_data.InputData(
|
|
1812
1868
|
kpi=dataset.kpi,
|
|
meridian/data/validator.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""This module contains common validation functions for Meridian data."""
|
|
16
|
+
|
|
17
|
+
import datetime as dt
|
|
18
|
+
from meridian import constants
|
|
19
|
+
import xarray as xr
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def validate_time_coord_format(array: xr.DataArray | None):
|
|
23
|
+
"""Validates the `time` dimensions format of the selected DataArray.
|
|
24
|
+
|
|
25
|
+
The `time` dimension of the selected array must have labels that are
|
|
26
|
+
formatted in the Meridian conventional `"yyyy-mm-dd"` format.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
array: An optional DataArray to validate.
|
|
30
|
+
"""
|
|
31
|
+
if array is None:
|
|
32
|
+
return
|
|
33
|
+
|
|
34
|
+
# The component data arrays from the input data builders that call this helper
|
|
35
|
+
# method should only have one of either `media_time` or `time` as its time
|
|
36
|
+
# dimension.
|
|
37
|
+
target_coords = [constants.TIME, constants.MEDIA_TIME]
|
|
38
|
+
|
|
39
|
+
for coord_name in target_coords:
|
|
40
|
+
if (values := array.coords.get(coord_name)) is not None:
|
|
41
|
+
for time in values:
|
|
42
|
+
try:
|
|
43
|
+
dt.datetime.strptime(time.item(), constants.DATE_FORMAT)
|
|
44
|
+
except (TypeError, ValueError) as exc:
|
|
45
|
+
raise ValueError(
|
|
46
|
+
f"Invalid {coord_name} label: {time.item()!r}. "
|
|
47
|
+
f"Expected format: '{constants.DATE_FORMAT}'"
|
|
48
|
+
) from exc
|
meridian/mlflow/autolog.py
CHANGED
|
@@ -70,6 +70,7 @@ import dataclasses
|
|
|
70
70
|
import inspect
|
|
71
71
|
import json
|
|
72
72
|
from typing import Any, Callable
|
|
73
|
+
import warnings
|
|
73
74
|
|
|
74
75
|
import arviz as az
|
|
75
76
|
from meridian import backend
|
|
@@ -180,16 +181,25 @@ def autolog(
|
|
|
180
181
|
f"sample_posterior.{param}", kwargs.get(param, "default")
|
|
181
182
|
)
|
|
182
183
|
|
|
183
|
-
original(self, *args, **kwargs)
|
|
184
|
+
result = original(self, *args, **kwargs)
|
|
184
185
|
if log_metrics:
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
186
|
+
# TODO: Direct injection of `model.Meridian` object into
|
|
187
|
+
# `PosteriorMCMCSampler` is deprecated. Revisit patching method here.
|
|
188
|
+
if self.model is not None:
|
|
189
|
+
model_diagnostics = visualizer.ModelDiagnostics(self.model)
|
|
190
|
+
df_diag = model_diagnostics.predictive_accuracy_table()
|
|
191
|
+
|
|
192
|
+
get_metric = lambda n: df_diag[df_diag.metric == n].value.to_list()[0]
|
|
193
|
+
|
|
194
|
+
mlflow.log_metric("R_Squared", get_metric("R_Squared"))
|
|
195
|
+
mlflow.log_metric("MAPE", get_metric("MAPE"))
|
|
196
|
+
mlflow.log_metric("wMAPE", get_metric("wMAPE"))
|
|
197
|
+
else:
|
|
198
|
+
warnings.warn(
|
|
199
|
+
"log_metrics=True is not supported when PosteriorMCMCSampler is"
|
|
200
|
+
" initialized with model_context."
|
|
201
|
+
)
|
|
202
|
+
return result
|
|
193
203
|
|
|
194
204
|
safe_patch(FLAVOR_NAME, model.Meridian, "__init__", patch_meridian_init)
|
|
195
205
|
safe_patch(
|
meridian/model/__init__.py
CHANGED
|
@@ -15,7 +15,9 @@
|
|
|
15
15
|
"""The Meridian API module that models the data."""
|
|
16
16
|
|
|
17
17
|
from meridian.model import adstock_hill
|
|
18
|
+
from meridian.model import context
|
|
18
19
|
from meridian.model import eda
|
|
20
|
+
from meridian.model import equations
|
|
19
21
|
from meridian.model import knots
|
|
20
22
|
from meridian.model import media
|
|
21
23
|
from meridian.model import model
|
meridian/model/adstock_hill.py
CHANGED
|
@@ -279,10 +279,6 @@ def _adstock(
|
|
|
279
279
|
media = backend.concatenate([backend.zeros(pad_shape), media], axis=-2)
|
|
280
280
|
|
|
281
281
|
# Adstock calculation.
|
|
282
|
-
window_list = [None] * window_size
|
|
283
|
-
for i in range(window_size):
|
|
284
|
-
window_list[i] = media[..., i : i + n_times_output, :]
|
|
285
|
-
windowed = backend.stack(window_list)
|
|
286
282
|
l_range = backend.arange(window_size - 1, -1, -1, dtype=backend.float32)
|
|
287
283
|
weights = compute_decay_weights(
|
|
288
284
|
alpha=alpha,
|
|
@@ -291,7 +287,9 @@ def _adstock(
|
|
|
291
287
|
decay_functions=decay_functions,
|
|
292
288
|
normalize=True,
|
|
293
289
|
)
|
|
294
|
-
return backend.
|
|
290
|
+
return backend.adstock_process(
|
|
291
|
+
media=media, weights=weights, n_times_output=n_times_output
|
|
292
|
+
)
|
|
295
293
|
|
|
296
294
|
|
|
297
295
|
def _map_alpha_for_binomial_decay(x: backend.Tensor):
|