google-meridian 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.1.dist-info}/METADATA +2 -2
- google_meridian-1.1.1.dist-info/RECORD +41 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.1.dist-info}/WHEEL +1 -1
- meridian/__init__.py +2 -2
- meridian/analysis/__init__.py +1 -1
- meridian/analysis/analyzer.py +18 -17
- meridian/analysis/formatter.py +1 -1
- meridian/analysis/optimizer.py +1 -1
- meridian/analysis/summarizer.py +1 -1
- meridian/analysis/summary_text.py +1 -1
- meridian/analysis/test_utils.py +1 -1
- meridian/analysis/visualizer.py +2 -3
- meridian/constants.py +3 -3
- meridian/data/__init__.py +1 -1
- meridian/data/arg_builder.py +1 -1
- meridian/data/input_data.py +12 -8
- meridian/data/load.py +53 -40
- meridian/data/test_utils.py +60 -43
- meridian/data/time_coordinates.py +1 -1
- meridian/model/__init__.py +1 -1
- meridian/model/adstock_hill.py +1 -1
- meridian/model/knots.py +1 -1
- meridian/model/media.py +1 -1
- meridian/model/model.py +47 -27
- meridian/model/model_test_data.py +75 -1
- meridian/model/posterior_sampler.py +19 -15
- meridian/model/prior_distribution.py +1 -1
- meridian/model/prior_sampler.py +32 -26
- meridian/model/spec.py +1 -1
- meridian/model/transformers.py +1 -1
- google_meridian-1.1.0.dist-info/RECORD +0 -41
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.1.dist-info}/top_level.txt +0 -0
meridian/data/test_utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -409,10 +409,7 @@ DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_MEDIA = xr.Dataset(
|
|
|
409
409
|
)
|
|
410
410
|
|
|
411
411
|
DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_REACH = xr.Dataset(
|
|
412
|
-
coords=_REQUIRED_COORDS
|
|
413
|
-
| _MEDIA_COORDS
|
|
414
|
-
| _RF_COORDS
|
|
415
|
-
| _ORGANIC_RF_COORDS,
|
|
412
|
+
coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _ORGANIC_RF_COORDS,
|
|
416
413
|
data_vars=_REQUIRED_DATA_VARS
|
|
417
414
|
| _MEDIA_DATA_VARS
|
|
418
415
|
| _RF_DATA_VARS
|
|
@@ -818,11 +815,11 @@ def random_controls_da(
|
|
|
818
815
|
|
|
819
816
|
def random_kpi_da(
|
|
820
817
|
media: xr.DataArray,
|
|
821
|
-
controls: xr.DataArray,
|
|
822
818
|
n_geos: int,
|
|
823
819
|
n_times: int,
|
|
824
820
|
n_media_channels: int,
|
|
825
|
-
n_controls: int,
|
|
821
|
+
n_controls: int | None = None,
|
|
822
|
+
controls: xr.DataArray | None = None,
|
|
826
823
|
seed: int = 0,
|
|
827
824
|
integer_geos: bool = False,
|
|
828
825
|
) -> xr.DataArray:
|
|
@@ -838,19 +835,26 @@ def random_kpi_da(
|
|
|
838
835
|
n_media_channels,
|
|
839
836
|
axis=2,
|
|
840
837
|
)
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
838
|
+
if n_controls:
|
|
839
|
+
control_geo_sd = abs(np.random.normal(0, 5, size=n_geos))
|
|
840
|
+
control_geo_sd = np.repeat(
|
|
841
|
+
np.repeat(control_geo_sd[:, np.newaxis], n_times, axis=1)[
|
|
842
|
+
..., np.newaxis
|
|
843
|
+
],
|
|
844
|
+
n_controls,
|
|
845
|
+
axis=2,
|
|
846
|
+
)
|
|
847
|
+
else:
|
|
848
|
+
control_geo_sd = 0
|
|
849
849
|
|
|
850
850
|
# Simulates outcome which is the dependent variable. Typically this is the
|
|
851
851
|
# number of units sold, but it can be any metric (e.g. revenue).
|
|
852
852
|
media_portion = np.random.normal(media_common, media_geo_sd).sum(axis=2)
|
|
853
|
-
|
|
853
|
+
if controls is not None:
|
|
854
|
+
control_portion = np.random.normal(controls, control_geo_sd).sum(axis=2)
|
|
855
|
+
else:
|
|
856
|
+
control_portion = 0
|
|
857
|
+
|
|
854
858
|
error = np.random.normal(0, 2, size=(n_geos, n_times))
|
|
855
859
|
kpi = abs(media_portion + control_portion + error)
|
|
856
860
|
|
|
@@ -1161,7 +1165,7 @@ def random_dataset(
|
|
|
1161
1165
|
n_geos: int,
|
|
1162
1166
|
n_times: int,
|
|
1163
1167
|
n_media_times: int,
|
|
1164
|
-
n_controls: int,
|
|
1168
|
+
n_controls: int | None = None,
|
|
1165
1169
|
n_non_media_channels: int | None = None,
|
|
1166
1170
|
n_organic_media_channels: int | None = None,
|
|
1167
1171
|
n_organic_rf_channels: int | None = None,
|
|
@@ -1171,7 +1175,7 @@ def random_dataset(
|
|
|
1171
1175
|
seed: int = 0,
|
|
1172
1176
|
remove_media_time: bool = False,
|
|
1173
1177
|
integer_geos: bool = False,
|
|
1174
|
-
):
|
|
1178
|
+
) -> xr.Dataset:
|
|
1175
1179
|
"""Generates a random dataset."""
|
|
1176
1180
|
if n_media_channels:
|
|
1177
1181
|
media = random_media_da(
|
|
@@ -1232,14 +1236,18 @@ def random_dataset(
|
|
|
1232
1236
|
else:
|
|
1233
1237
|
revenue_per_kpi = None
|
|
1234
1238
|
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1239
|
+
if n_controls:
|
|
1240
|
+
controls = random_controls_da(
|
|
1241
|
+
media=media if n_media_channels else reach,
|
|
1242
|
+
n_geos=n_geos,
|
|
1243
|
+
n_times=n_times,
|
|
1244
|
+
n_controls=n_controls,
|
|
1245
|
+
seed=seed,
|
|
1246
|
+
integer_geos=integer_geos,
|
|
1247
|
+
)
|
|
1248
|
+
else:
|
|
1249
|
+
controls = None
|
|
1250
|
+
|
|
1243
1251
|
if n_non_media_channels:
|
|
1244
1252
|
non_media_treatments = random_non_media_treatments_da(
|
|
1245
1253
|
media=media if n_media_channels else reach,
|
|
@@ -1298,7 +1306,9 @@ def random_dataset(
|
|
|
1298
1306
|
n_geos=n_geos, seed=seed, integer_geos=integer_geos
|
|
1299
1307
|
)
|
|
1300
1308
|
|
|
1301
|
-
dataset = xr.combine_by_coords(
|
|
1309
|
+
dataset = xr.combine_by_coords(
|
|
1310
|
+
[kpi, population] + ([controls] if controls is not None else [])
|
|
1311
|
+
)
|
|
1302
1312
|
if revenue_per_kpi is not None:
|
|
1303
1313
|
dataset = xr.combine_by_coords([dataset, revenue_per_kpi])
|
|
1304
1314
|
if media is not None:
|
|
@@ -1347,7 +1357,7 @@ def random_dataset(
|
|
|
1347
1357
|
|
|
1348
1358
|
def dataset_to_dataframe(
|
|
1349
1359
|
dataset: xr.Dataset,
|
|
1350
|
-
controls_column_names: list[str],
|
|
1360
|
+
controls_column_names: list[str] | None = None,
|
|
1351
1361
|
media_column_names: list[str] | None = None,
|
|
1352
1362
|
media_spend_column_names: list[str] | None = None,
|
|
1353
1363
|
reach_column_names: list[str] | None = None,
|
|
@@ -1396,10 +1406,15 @@ def dataset_to_dataframe(
|
|
|
1396
1406
|
)
|
|
1397
1407
|
population = dataset[c.POPULATION].to_dataframe(name=c.POPULATION)
|
|
1398
1408
|
|
|
1399
|
-
|
|
1400
|
-
|
|
1409
|
+
if controls_column_names is not None:
|
|
1410
|
+
controls = dataset[c.CONTROLS].to_dataframe(name=c.CONTROLS).unstack()
|
|
1411
|
+
controls.columns = controls_column_names
|
|
1412
|
+
else:
|
|
1413
|
+
controls = None
|
|
1401
1414
|
|
|
1402
|
-
result = kpi.join(revenue_per_kpi).join(population)
|
|
1415
|
+
result = kpi.join(revenue_per_kpi).join(population)
|
|
1416
|
+
if controls is not None:
|
|
1417
|
+
result = result.join(controls)
|
|
1403
1418
|
|
|
1404
1419
|
if non_media_column_names is not None:
|
|
1405
1420
|
non_media_treatments = (
|
|
@@ -1471,7 +1486,7 @@ def random_dataframe(
|
|
|
1471
1486
|
n_geos,
|
|
1472
1487
|
n_times,
|
|
1473
1488
|
n_media_times,
|
|
1474
|
-
n_controls,
|
|
1489
|
+
n_controls=None,
|
|
1475
1490
|
n_media_channels=None,
|
|
1476
1491
|
n_rf_channels=None,
|
|
1477
1492
|
seed=0,
|
|
@@ -1489,7 +1504,9 @@ def random_dataframe(
|
|
|
1489
1504
|
|
|
1490
1505
|
return dataset_to_dataframe(
|
|
1491
1506
|
dataset,
|
|
1492
|
-
controls_column_names=
|
|
1507
|
+
controls_column_names=(
|
|
1508
|
+
_sample_names('control_', n_controls) if n_controls else None
|
|
1509
|
+
),
|
|
1493
1510
|
media_column_names=_sample_names('media_', n_media_channels),
|
|
1494
1511
|
media_spend_column_names=_sample_names('media_spend_', n_media_channels),
|
|
1495
1512
|
reach_column_names=_sample_names('reach_', n_rf_channels),
|
|
@@ -1499,7 +1516,7 @@ def random_dataframe(
|
|
|
1499
1516
|
|
|
1500
1517
|
|
|
1501
1518
|
def sample_coord_to_columns(
|
|
1502
|
-
n_controls: int,
|
|
1519
|
+
n_controls: int | None = None,
|
|
1503
1520
|
n_media_channels: int | None = None,
|
|
1504
1521
|
n_rf_channels: int | None = None,
|
|
1505
1522
|
n_non_media_channels: int | None = None,
|
|
@@ -1550,7 +1567,7 @@ def sample_coord_to_columns(
|
|
|
1550
1567
|
kpi=c.KPI,
|
|
1551
1568
|
revenue_per_kpi=c.REVENUE_PER_KPI if include_revenue_per_kpi else None,
|
|
1552
1569
|
population=c.POPULATION,
|
|
1553
|
-
controls=_sample_names('control_', n_controls),
|
|
1570
|
+
controls=(_sample_names('control_', n_controls) if n_controls else None),
|
|
1554
1571
|
media=media,
|
|
1555
1572
|
media_spend=media_spend,
|
|
1556
1573
|
reach=reach,
|
|
@@ -1668,7 +1685,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1668
1685
|
n_geos: int = 10,
|
|
1669
1686
|
n_times: int = 50,
|
|
1670
1687
|
n_media_times: int = 53,
|
|
1671
|
-
n_controls: int = 2,
|
|
1688
|
+
n_controls: int | None = 2,
|
|
1672
1689
|
n_non_media_channels: int | None = None,
|
|
1673
1690
|
n_media_channels: int | None = None,
|
|
1674
1691
|
n_rf_channels: int | None = None,
|
|
@@ -1694,10 +1711,10 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1694
1711
|
kpi_type=c.NON_REVENUE,
|
|
1695
1712
|
revenue_per_kpi=dataset.revenue_per_kpi,
|
|
1696
1713
|
population=dataset.population,
|
|
1697
|
-
controls=dataset.controls,
|
|
1698
|
-
non_media_treatments=
|
|
1699
|
-
|
|
1700
|
-
|
|
1714
|
+
controls=(dataset.controls if n_controls else None),
|
|
1715
|
+
non_media_treatments=(
|
|
1716
|
+
dataset.non_media_treatments if n_non_media_channels else None
|
|
1717
|
+
),
|
|
1701
1718
|
media=dataset.media if n_media_channels else None,
|
|
1702
1719
|
media_spend=dataset.media_spend if n_media_channels else None,
|
|
1703
1720
|
reach=dataset.reach if n_rf_channels else None,
|
|
@@ -1705,9 +1722,9 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1705
1722
|
rf_spend=dataset.rf_spend if n_rf_channels else None,
|
|
1706
1723
|
organic_media=dataset.organic_media if n_organic_media_channels else None,
|
|
1707
1724
|
organic_reach=dataset.organic_reach if n_organic_rf_channels else None,
|
|
1708
|
-
organic_frequency=
|
|
1709
|
-
|
|
1710
|
-
|
|
1725
|
+
organic_frequency=(
|
|
1726
|
+
dataset.organic_frequency if n_organic_rf_channels else None
|
|
1727
|
+
),
|
|
1711
1728
|
)
|
|
1712
1729
|
|
|
1713
1730
|
|
meridian/model/__init__.py
CHANGED
meridian/model/adstock_hill.py
CHANGED
meridian/model/knots.py
CHANGED
meridian/model/media.py
CHANGED
meridian/model/model.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -214,7 +214,9 @@ class Meridian:
|
|
|
214
214
|
)
|
|
215
215
|
|
|
216
216
|
@functools.cached_property
|
|
217
|
-
def controls(self) -> tf.Tensor:
|
|
217
|
+
def controls(self) -> tf.Tensor | None:
|
|
218
|
+
if self.input_data.controls is None:
|
|
219
|
+
return None
|
|
218
220
|
return tf.convert_to_tensor(self.input_data.controls, dtype=tf.float32)
|
|
219
221
|
|
|
220
222
|
@functools.cached_property
|
|
@@ -271,6 +273,8 @@ class Meridian:
|
|
|
271
273
|
|
|
272
274
|
@property
|
|
273
275
|
def n_controls(self) -> int:
|
|
276
|
+
if self.input_data.control_variable is None:
|
|
277
|
+
return 0
|
|
274
278
|
return len(self.input_data.control_variable)
|
|
275
279
|
|
|
276
280
|
@property
|
|
@@ -304,7 +308,13 @@ class Meridian:
|
|
|
304
308
|
)
|
|
305
309
|
|
|
306
310
|
@functools.cached_property
|
|
307
|
-
def controls_transformer(
|
|
311
|
+
def controls_transformer(
|
|
312
|
+
self,
|
|
313
|
+
) -> transformers.CenteringAndScalingTransformer | None:
|
|
314
|
+
"""Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
|
|
315
|
+
if self.controls is None:
|
|
316
|
+
return None
|
|
317
|
+
|
|
308
318
|
if self.model_spec.control_population_scaling_id is not None:
|
|
309
319
|
controls_population_scaling_id = tf.convert_to_tensor(
|
|
310
320
|
self.model_spec.control_population_scaling_id, dtype=bool
|
|
@@ -343,8 +353,12 @@ class Meridian:
|
|
|
343
353
|
return transformers.KpiTransformer(self.kpi, self.population)
|
|
344
354
|
|
|
345
355
|
@functools.cached_property
|
|
346
|
-
def controls_scaled(self) -> tf.Tensor:
|
|
347
|
-
|
|
356
|
+
def controls_scaled(self) -> tf.Tensor | None:
|
|
357
|
+
if self.controls is not None:
|
|
358
|
+
# If `controls` is defined, then `controls_transformer` is also defined.
|
|
359
|
+
return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
|
|
360
|
+
else:
|
|
361
|
+
return None
|
|
348
362
|
|
|
349
363
|
@functools.cached_property
|
|
350
364
|
def non_media_treatments_normalized(self) -> tf.Tensor | None:
|
|
@@ -894,11 +908,12 @@ class Meridian:
|
|
|
894
908
|
if self.is_national:
|
|
895
909
|
return
|
|
896
910
|
|
|
897
|
-
self.
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
911
|
+
if self.input_data.controls is not None:
|
|
912
|
+
self._check_if_no_geo_variation(
|
|
913
|
+
self.controls_scaled,
|
|
914
|
+
constants.CONTROLS,
|
|
915
|
+
self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
|
|
916
|
+
)
|
|
902
917
|
if self.input_data.non_media_treatments is not None:
|
|
903
918
|
self._check_if_no_geo_variation(
|
|
904
919
|
self.non_media_treatments_normalized,
|
|
@@ -971,12 +986,12 @@ class Meridian:
|
|
|
971
986
|
|
|
972
987
|
def _validate_time_invariants(self):
|
|
973
988
|
"""Validates model time invariants."""
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
989
|
+
if self.input_data.controls is not None:
|
|
990
|
+
self._check_if_no_time_variation(
|
|
991
|
+
self.controls_scaled,
|
|
992
|
+
constants.CONTROLS,
|
|
993
|
+
self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
|
|
994
|
+
)
|
|
980
995
|
if self.input_data.non_media_treatments is not None:
|
|
981
996
|
self._check_if_no_time_variation(
|
|
982
997
|
self.non_media_treatments_normalized,
|
|
@@ -1356,31 +1371,36 @@ class Meridian:
|
|
|
1356
1371
|
self, n_chains: int, n_draws: int
|
|
1357
1372
|
) -> Mapping[str, np.ndarray | Sequence[str]]:
|
|
1358
1373
|
"""Creates data coordinates for inference data."""
|
|
1359
|
-
|
|
1374
|
+
media_channel_names = (
|
|
1360
1375
|
self.input_data.media_channel
|
|
1361
1376
|
if self.input_data.media_channel is not None
|
|
1362
1377
|
else np.array([])
|
|
1363
1378
|
)
|
|
1364
|
-
|
|
1379
|
+
rf_channel_names = (
|
|
1365
1380
|
self.input_data.rf_channel
|
|
1366
1381
|
if self.input_data.rf_channel is not None
|
|
1367
1382
|
else np.array([])
|
|
1368
1383
|
)
|
|
1369
|
-
|
|
1384
|
+
organic_media_channel_names = (
|
|
1370
1385
|
self.input_data.organic_media_channel
|
|
1371
1386
|
if self.input_data.organic_media_channel is not None
|
|
1372
1387
|
else np.array([])
|
|
1373
1388
|
)
|
|
1374
|
-
|
|
1389
|
+
organic_rf_channel_names = (
|
|
1375
1390
|
self.input_data.organic_rf_channel
|
|
1376
1391
|
if self.input_data.organic_rf_channel is not None
|
|
1377
1392
|
else np.array([])
|
|
1378
1393
|
)
|
|
1379
|
-
|
|
1394
|
+
non_media_channel_names = (
|
|
1380
1395
|
self.input_data.non_media_channel
|
|
1381
1396
|
if self.input_data.non_media_channel is not None
|
|
1382
1397
|
else np.array([])
|
|
1383
1398
|
)
|
|
1399
|
+
control_variable_names = (
|
|
1400
|
+
self.input_data.control_variable
|
|
1401
|
+
if self.input_data.control_variable is not None
|
|
1402
|
+
else np.array([])
|
|
1403
|
+
)
|
|
1384
1404
|
return {
|
|
1385
1405
|
constants.CHAIN: np.arange(n_chains),
|
|
1386
1406
|
constants.DRAW: np.arange(n_draws),
|
|
@@ -1388,12 +1408,12 @@ class Meridian:
|
|
|
1388
1408
|
constants.TIME: self.input_data.time,
|
|
1389
1409
|
constants.MEDIA_TIME: self.input_data.media_time,
|
|
1390
1410
|
constants.KNOTS: np.arange(self.knot_info.n_knots),
|
|
1391
|
-
constants.CONTROL_VARIABLE:
|
|
1392
|
-
constants.NON_MEDIA_CHANNEL:
|
|
1393
|
-
constants.MEDIA_CHANNEL:
|
|
1394
|
-
constants.RF_CHANNEL:
|
|
1395
|
-
constants.ORGANIC_MEDIA_CHANNEL:
|
|
1396
|
-
constants.ORGANIC_RF_CHANNEL:
|
|
1411
|
+
constants.CONTROL_VARIABLE: control_variable_names,
|
|
1412
|
+
constants.NON_MEDIA_CHANNEL: non_media_channel_names,
|
|
1413
|
+
constants.MEDIA_CHANNEL: media_channel_names,
|
|
1414
|
+
constants.RF_CHANNEL: rf_channel_names,
|
|
1415
|
+
constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
|
|
1416
|
+
constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
|
|
1397
1417
|
}
|
|
1398
1418
|
|
|
1399
1419
|
def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
|
|
meridian/model/model_test_data.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -70,6 +70,10 @@ class WithInputDataSamples:
|
|
|
70
70
|
_TEST_DIR,
|
|
71
71
|
"sample_prior_media_only.nc",
|
|
72
72
|
)
|
|
73
|
+
_TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
|
|
74
|
+
_TEST_DIR,
|
|
75
|
+
"sample_prior_media_only_no_controls.nc",
|
|
76
|
+
)
|
|
73
77
|
_TEST_SAMPLE_PRIOR_RF_ONLY_PATH = os.path.join(
|
|
74
78
|
_TEST_DIR,
|
|
75
79
|
"sample_prior_rf_only.nc",
|
|
@@ -82,6 +86,10 @@ class WithInputDataSamples:
|
|
|
82
86
|
_TEST_DIR,
|
|
83
87
|
"sample_posterior_media_only.nc",
|
|
84
88
|
)
|
|
89
|
+
_TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
|
|
90
|
+
_TEST_DIR,
|
|
91
|
+
"sample_posterior_media_only_no_controls.nc",
|
|
92
|
+
)
|
|
85
93
|
_TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH = os.path.join(
|
|
86
94
|
_TEST_DIR,
|
|
87
95
|
"sample_posterior_rf_only.nc",
|
|
@@ -172,6 +180,17 @@ class WithInputDataSamples:
|
|
|
172
180
|
seed=0,
|
|
173
181
|
)
|
|
174
182
|
)
|
|
183
|
+
self.input_data_with_media_and_rf_no_controls = (
|
|
184
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
185
|
+
n_geos=self._N_GEOS,
|
|
186
|
+
n_times=self._N_TIMES,
|
|
187
|
+
n_media_times=self._N_MEDIA_TIMES,
|
|
188
|
+
n_controls=None,
|
|
189
|
+
n_media_channels=self._N_MEDIA_CHANNELS,
|
|
190
|
+
n_rf_channels=self._N_RF_CHANNELS,
|
|
191
|
+
seed=0,
|
|
192
|
+
)
|
|
193
|
+
)
|
|
175
194
|
self.short_input_data_with_media_only = (
|
|
176
195
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
177
196
|
n_geos=self._N_GEOS,
|
|
@@ -182,6 +201,16 @@ class WithInputDataSamples:
|
|
|
182
201
|
seed=0,
|
|
183
202
|
)
|
|
184
203
|
)
|
|
204
|
+
self.short_input_data_with_media_only_no_controls = (
|
|
205
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
206
|
+
n_geos=self._N_GEOS,
|
|
207
|
+
n_times=self._N_TIMES_SHORT,
|
|
208
|
+
n_media_times=self._N_MEDIA_TIMES_SHORT,
|
|
209
|
+
n_controls=0,
|
|
210
|
+
n_media_channels=self._N_MEDIA_CHANNELS,
|
|
211
|
+
seed=0,
|
|
212
|
+
)
|
|
213
|
+
)
|
|
185
214
|
self.short_input_data_with_rf_only = (
|
|
186
215
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
187
216
|
n_geos=self._N_GEOS,
|
|
@@ -231,6 +260,9 @@ class WithInputDataSamples:
|
|
|
231
260
|
test_prior_media_only = xr.open_dataset(
|
|
232
261
|
self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
|
|
233
262
|
)
|
|
263
|
+
test_prior_media_only_no_controls = xr.open_dataset(
|
|
264
|
+
self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
|
|
265
|
+
)
|
|
234
266
|
test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
|
|
235
267
|
self.test_dist_media_and_rf = collections.OrderedDict({
|
|
236
268
|
param: tf.convert_to_tensor(test_prior_media_and_rf[param])
|
|
@@ -243,6 +275,18 @@ class WithInputDataSamples:
|
|
|
243
275
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
244
276
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
245
277
|
})
|
|
278
|
+
self.test_dist_media_only_no_controls = collections.OrderedDict({
|
|
279
|
+
param: tf.convert_to_tensor(test_prior_media_only_no_controls[param])
|
|
280
|
+
for param in (
|
|
281
|
+
set(
|
|
282
|
+
constants.COMMON_PARAMETER_NAMES
|
|
283
|
+
+ constants.MEDIA_PARAMETER_NAMES
|
|
284
|
+
)
|
|
285
|
+
- set(
|
|
286
|
+
constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
|
|
287
|
+
)
|
|
288
|
+
)
|
|
289
|
+
})
|
|
246
290
|
self.test_dist_rf_only = collections.OrderedDict({
|
|
247
291
|
param: tf.convert_to_tensor(test_prior_rf_only[param])
|
|
248
292
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
@@ -255,6 +299,9 @@ class WithInputDataSamples:
|
|
|
255
299
|
test_posterior_media_only = xr.open_dataset(
|
|
256
300
|
self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
|
|
257
301
|
)
|
|
302
|
+
test_posterior_media_only_no_controls = xr.open_dataset(
|
|
303
|
+
self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
|
|
304
|
+
)
|
|
258
305
|
test_posterior_rf_only = xr.open_dataset(
|
|
259
306
|
self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
|
|
260
307
|
)
|
|
@@ -273,6 +320,21 @@ class WithInputDataSamples:
|
|
|
273
320
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
274
321
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
275
322
|
}
|
|
323
|
+
posterior_params_to_tensors_media_only_no_controls = {
|
|
324
|
+
param: _convert_with_swap(
|
|
325
|
+
test_posterior_media_only_no_controls[param],
|
|
326
|
+
n_burnin=self._N_BURNIN,
|
|
327
|
+
)
|
|
328
|
+
for param in (
|
|
329
|
+
set(
|
|
330
|
+
constants.COMMON_PARAMETER_NAMES
|
|
331
|
+
+ constants.MEDIA_PARAMETER_NAMES
|
|
332
|
+
)
|
|
333
|
+
- set(
|
|
334
|
+
constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
|
|
335
|
+
)
|
|
336
|
+
)
|
|
337
|
+
}
|
|
276
338
|
posterior_params_to_tensors_rf_only = {
|
|
277
339
|
param: _convert_with_swap(
|
|
278
340
|
test_posterior_rf_only[param], n_burnin=self._N_BURNIN
|
|
@@ -290,6 +352,18 @@ class WithInputDataSamples:
|
|
|
290
352
|
"StructTuple",
|
|
291
353
|
constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
|
|
292
354
|
)(**posterior_params_to_tensors_media_only)
|
|
355
|
+
self.test_posterior_states_media_only_no_controls = collections.namedtuple(
|
|
356
|
+
"StructTuple",
|
|
357
|
+
(
|
|
358
|
+
set(
|
|
359
|
+
constants.COMMON_PARAMETER_NAMES
|
|
360
|
+
+ constants.MEDIA_PARAMETER_NAMES
|
|
361
|
+
)
|
|
362
|
+
- set(
|
|
363
|
+
constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
|
|
364
|
+
)
|
|
365
|
+
),
|
|
366
|
+
)(**posterior_params_to_tensors_media_only_no_controls)
|
|
293
367
|
self.test_posterior_states_rf_only = collections.namedtuple(
|
|
294
368
|
"StructTuple",
|
|
295
369
|
constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,
|
|
meridian/model/posterior_sampler.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -120,8 +120,6 @@ class PosteriorMCMCSampler:
|
|
|
120
120
|
def joint_dist_unpinned():
|
|
121
121
|
# Sample directly from prior.
|
|
122
122
|
knot_values = yield prior_broadcast.knot_values
|
|
123
|
-
gamma_c = yield prior_broadcast.gamma_c
|
|
124
|
-
xi_c = yield prior_broadcast.xi_c
|
|
125
123
|
sigma = yield prior_broadcast.sigma
|
|
126
124
|
|
|
127
125
|
tau_g_excl_baseline = yield tfp.distributions.Sample(
|
|
@@ -377,19 +375,25 @@ class PosteriorMCMCSampler:
|
|
|
377
375
|
combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1)
|
|
378
376
|
|
|
379
377
|
sigma_gt = tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos]))
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
[n_geos, n_controls],
|
|
383
|
-
name=constants.GAMMA_GC_DEV,
|
|
384
|
-
)
|
|
385
|
-
gamma_gc = yield tfp.distributions.Deterministic(
|
|
386
|
-
gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
|
|
387
|
-
)
|
|
388
|
-
y_pred_combined_media = (
|
|
389
|
-
tau_gt
|
|
390
|
-
+ tf.einsum("gtm,gm->gt", combined_media_transformed, combined_beta)
|
|
391
|
-
+ tf.einsum("gtc,gc->gt", controls_scaled, gamma_gc)
|
|
378
|
+
y_pred_combined_media = tau_gt + tf.einsum(
|
|
379
|
+
"gtm,gm->gt", combined_media_transformed, combined_beta
|
|
392
380
|
)
|
|
381
|
+
# Omit gamma_c, xi_c, and gamma_gc from joint distribution output if
|
|
382
|
+
# there are no control variables in the model.
|
|
383
|
+
if n_controls:
|
|
384
|
+
gamma_c = yield prior_broadcast.gamma_c
|
|
385
|
+
xi_c = yield prior_broadcast.xi_c
|
|
386
|
+
gamma_gc_dev = yield tfp.distributions.Sample(
|
|
387
|
+
tfp.distributions.Normal(0, 1),
|
|
388
|
+
[n_geos, n_controls],
|
|
389
|
+
name=constants.GAMMA_GC_DEV,
|
|
390
|
+
)
|
|
391
|
+
gamma_gc = yield tfp.distributions.Deterministic(
|
|
392
|
+
gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
|
|
393
|
+
)
|
|
394
|
+
y_pred_combined_media += tf.einsum(
|
|
395
|
+
"gtc,gc->gt", controls_scaled, gamma_gc
|
|
396
|
+
)
|
|
393
397
|
|
|
394
398
|
if mmm.non_media_treatments is not None:
|
|
395
399
|
xi_n = yield prior_broadcast.xi_n
|