google-meridian 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/METADATA +6 -2
- google_meridian-1.1.2.dist-info/RECORD +46 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/WHEEL +1 -1
- meridian/__init__.py +2 -2
- meridian/analysis/__init__.py +1 -1
- meridian/analysis/analyzer.py +29 -22
- meridian/analysis/formatter.py +1 -1
- meridian/analysis/optimizer.py +70 -44
- meridian/analysis/summarizer.py +1 -1
- meridian/analysis/summary_text.py +1 -1
- meridian/analysis/test_utils.py +1 -1
- meridian/analysis/visualizer.py +17 -8
- meridian/constants.py +3 -3
- meridian/data/__init__.py +4 -1
- meridian/data/arg_builder.py +1 -1
- meridian/data/data_frame_input_data_builder.py +614 -0
- meridian/data/input_data.py +12 -8
- meridian/data/input_data_builder.py +817 -0
- meridian/data/load.py +121 -428
- meridian/data/nd_array_input_data_builder.py +509 -0
- meridian/data/test_utils.py +60 -43
- meridian/data/time_coordinates.py +1 -1
- meridian/mlflow/__init__.py +17 -0
- meridian/mlflow/autolog.py +54 -0
- meridian/model/__init__.py +1 -1
- meridian/model/adstock_hill.py +1 -1
- meridian/model/knots.py +1 -1
- meridian/model/media.py +1 -1
- meridian/model/model.py +65 -37
- meridian/model/model_test_data.py +75 -1
- meridian/model/posterior_sampler.py +19 -15
- meridian/model/prior_distribution.py +1 -1
- meridian/model/prior_sampler.py +32 -26
- meridian/model/spec.py +18 -8
- meridian/model/transformers.py +1 -1
- google_meridian-1.1.0.dist-info/RECORD +0 -41
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/top_level.txt +0 -0
meridian/data/test_utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -409,10 +409,7 @@ DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_MEDIA = xr.Dataset(
|
|
|
409
409
|
)
|
|
410
410
|
|
|
411
411
|
DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_REACH = xr.Dataset(
|
|
412
|
-
coords=_REQUIRED_COORDS
|
|
413
|
-
| _MEDIA_COORDS
|
|
414
|
-
| _RF_COORDS
|
|
415
|
-
| _ORGANIC_RF_COORDS,
|
|
412
|
+
coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _ORGANIC_RF_COORDS,
|
|
416
413
|
data_vars=_REQUIRED_DATA_VARS
|
|
417
414
|
| _MEDIA_DATA_VARS
|
|
418
415
|
| _RF_DATA_VARS
|
|
@@ -818,11 +815,11 @@ def random_controls_da(
|
|
|
818
815
|
|
|
819
816
|
def random_kpi_da(
|
|
820
817
|
media: xr.DataArray,
|
|
821
|
-
controls: xr.DataArray,
|
|
822
818
|
n_geos: int,
|
|
823
819
|
n_times: int,
|
|
824
820
|
n_media_channels: int,
|
|
825
|
-
n_controls: int,
|
|
821
|
+
n_controls: int | None = None,
|
|
822
|
+
controls: xr.DataArray | None = None,
|
|
826
823
|
seed: int = 0,
|
|
827
824
|
integer_geos: bool = False,
|
|
828
825
|
) -> xr.DataArray:
|
|
@@ -838,19 +835,26 @@ def random_kpi_da(
|
|
|
838
835
|
n_media_channels,
|
|
839
836
|
axis=2,
|
|
840
837
|
)
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
838
|
+
if n_controls:
|
|
839
|
+
control_geo_sd = abs(np.random.normal(0, 5, size=n_geos))
|
|
840
|
+
control_geo_sd = np.repeat(
|
|
841
|
+
np.repeat(control_geo_sd[:, np.newaxis], n_times, axis=1)[
|
|
842
|
+
..., np.newaxis
|
|
843
|
+
],
|
|
844
|
+
n_controls,
|
|
845
|
+
axis=2,
|
|
846
|
+
)
|
|
847
|
+
else:
|
|
848
|
+
control_geo_sd = 0
|
|
849
849
|
|
|
850
850
|
# Simulates outcome which is the dependent variable. Typically this is the
|
|
851
851
|
# number of units sold, but it can be any metric (e.g. revenue).
|
|
852
852
|
media_portion = np.random.normal(media_common, media_geo_sd).sum(axis=2)
|
|
853
|
-
|
|
853
|
+
if controls is not None:
|
|
854
|
+
control_portion = np.random.normal(controls, control_geo_sd).sum(axis=2)
|
|
855
|
+
else:
|
|
856
|
+
control_portion = 0
|
|
857
|
+
|
|
854
858
|
error = np.random.normal(0, 2, size=(n_geos, n_times))
|
|
855
859
|
kpi = abs(media_portion + control_portion + error)
|
|
856
860
|
|
|
@@ -1161,7 +1165,7 @@ def random_dataset(
|
|
|
1161
1165
|
n_geos: int,
|
|
1162
1166
|
n_times: int,
|
|
1163
1167
|
n_media_times: int,
|
|
1164
|
-
n_controls: int,
|
|
1168
|
+
n_controls: int | None = None,
|
|
1165
1169
|
n_non_media_channels: int | None = None,
|
|
1166
1170
|
n_organic_media_channels: int | None = None,
|
|
1167
1171
|
n_organic_rf_channels: int | None = None,
|
|
@@ -1171,7 +1175,7 @@ def random_dataset(
|
|
|
1171
1175
|
seed: int = 0,
|
|
1172
1176
|
remove_media_time: bool = False,
|
|
1173
1177
|
integer_geos: bool = False,
|
|
1174
|
-
):
|
|
1178
|
+
) -> xr.Dataset:
|
|
1175
1179
|
"""Generates a random dataset."""
|
|
1176
1180
|
if n_media_channels:
|
|
1177
1181
|
media = random_media_da(
|
|
@@ -1232,14 +1236,18 @@ def random_dataset(
|
|
|
1232
1236
|
else:
|
|
1233
1237
|
revenue_per_kpi = None
|
|
1234
1238
|
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1239
|
+
if n_controls:
|
|
1240
|
+
controls = random_controls_da(
|
|
1241
|
+
media=media if n_media_channels else reach,
|
|
1242
|
+
n_geos=n_geos,
|
|
1243
|
+
n_times=n_times,
|
|
1244
|
+
n_controls=n_controls,
|
|
1245
|
+
seed=seed,
|
|
1246
|
+
integer_geos=integer_geos,
|
|
1247
|
+
)
|
|
1248
|
+
else:
|
|
1249
|
+
controls = None
|
|
1250
|
+
|
|
1243
1251
|
if n_non_media_channels:
|
|
1244
1252
|
non_media_treatments = random_non_media_treatments_da(
|
|
1245
1253
|
media=media if n_media_channels else reach,
|
|
@@ -1298,7 +1306,9 @@ def random_dataset(
|
|
|
1298
1306
|
n_geos=n_geos, seed=seed, integer_geos=integer_geos
|
|
1299
1307
|
)
|
|
1300
1308
|
|
|
1301
|
-
dataset = xr.combine_by_coords(
|
|
1309
|
+
dataset = xr.combine_by_coords(
|
|
1310
|
+
[kpi, population] + ([controls] if controls is not None else [])
|
|
1311
|
+
)
|
|
1302
1312
|
if revenue_per_kpi is not None:
|
|
1303
1313
|
dataset = xr.combine_by_coords([dataset, revenue_per_kpi])
|
|
1304
1314
|
if media is not None:
|
|
@@ -1347,7 +1357,7 @@ def random_dataset(
|
|
|
1347
1357
|
|
|
1348
1358
|
def dataset_to_dataframe(
|
|
1349
1359
|
dataset: xr.Dataset,
|
|
1350
|
-
controls_column_names: list[str],
|
|
1360
|
+
controls_column_names: list[str] | None = None,
|
|
1351
1361
|
media_column_names: list[str] | None = None,
|
|
1352
1362
|
media_spend_column_names: list[str] | None = None,
|
|
1353
1363
|
reach_column_names: list[str] | None = None,
|
|
@@ -1396,10 +1406,15 @@ def dataset_to_dataframe(
|
|
|
1396
1406
|
)
|
|
1397
1407
|
population = dataset[c.POPULATION].to_dataframe(name=c.POPULATION)
|
|
1398
1408
|
|
|
1399
|
-
|
|
1400
|
-
|
|
1409
|
+
if controls_column_names is not None:
|
|
1410
|
+
controls = dataset[c.CONTROLS].to_dataframe(name=c.CONTROLS).unstack()
|
|
1411
|
+
controls.columns = controls_column_names
|
|
1412
|
+
else:
|
|
1413
|
+
controls = None
|
|
1401
1414
|
|
|
1402
|
-
result = kpi.join(revenue_per_kpi).join(population)
|
|
1415
|
+
result = kpi.join(revenue_per_kpi).join(population)
|
|
1416
|
+
if controls is not None:
|
|
1417
|
+
result = result.join(controls)
|
|
1403
1418
|
|
|
1404
1419
|
if non_media_column_names is not None:
|
|
1405
1420
|
non_media_treatments = (
|
|
@@ -1471,7 +1486,7 @@ def random_dataframe(
|
|
|
1471
1486
|
n_geos,
|
|
1472
1487
|
n_times,
|
|
1473
1488
|
n_media_times,
|
|
1474
|
-
n_controls,
|
|
1489
|
+
n_controls=None,
|
|
1475
1490
|
n_media_channels=None,
|
|
1476
1491
|
n_rf_channels=None,
|
|
1477
1492
|
seed=0,
|
|
@@ -1489,7 +1504,9 @@ def random_dataframe(
|
|
|
1489
1504
|
|
|
1490
1505
|
return dataset_to_dataframe(
|
|
1491
1506
|
dataset,
|
|
1492
|
-
controls_column_names=
|
|
1507
|
+
controls_column_names=(
|
|
1508
|
+
_sample_names('control_', n_controls) if n_controls else None
|
|
1509
|
+
),
|
|
1493
1510
|
media_column_names=_sample_names('media_', n_media_channels),
|
|
1494
1511
|
media_spend_column_names=_sample_names('media_spend_', n_media_channels),
|
|
1495
1512
|
reach_column_names=_sample_names('reach_', n_rf_channels),
|
|
@@ -1499,7 +1516,7 @@ def random_dataframe(
|
|
|
1499
1516
|
|
|
1500
1517
|
|
|
1501
1518
|
def sample_coord_to_columns(
|
|
1502
|
-
n_controls: int,
|
|
1519
|
+
n_controls: int | None = None,
|
|
1503
1520
|
n_media_channels: int | None = None,
|
|
1504
1521
|
n_rf_channels: int | None = None,
|
|
1505
1522
|
n_non_media_channels: int | None = None,
|
|
@@ -1550,7 +1567,7 @@ def sample_coord_to_columns(
|
|
|
1550
1567
|
kpi=c.KPI,
|
|
1551
1568
|
revenue_per_kpi=c.REVENUE_PER_KPI if include_revenue_per_kpi else None,
|
|
1552
1569
|
population=c.POPULATION,
|
|
1553
|
-
controls=_sample_names('control_', n_controls),
|
|
1570
|
+
controls=(_sample_names('control_', n_controls) if n_controls else None),
|
|
1554
1571
|
media=media,
|
|
1555
1572
|
media_spend=media_spend,
|
|
1556
1573
|
reach=reach,
|
|
@@ -1668,7 +1685,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1668
1685
|
n_geos: int = 10,
|
|
1669
1686
|
n_times: int = 50,
|
|
1670
1687
|
n_media_times: int = 53,
|
|
1671
|
-
n_controls: int = 2,
|
|
1688
|
+
n_controls: int | None = 2,
|
|
1672
1689
|
n_non_media_channels: int | None = None,
|
|
1673
1690
|
n_media_channels: int | None = None,
|
|
1674
1691
|
n_rf_channels: int | None = None,
|
|
@@ -1694,10 +1711,10 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1694
1711
|
kpi_type=c.NON_REVENUE,
|
|
1695
1712
|
revenue_per_kpi=dataset.revenue_per_kpi,
|
|
1696
1713
|
population=dataset.population,
|
|
1697
|
-
controls=dataset.controls,
|
|
1698
|
-
non_media_treatments=
|
|
1699
|
-
|
|
1700
|
-
|
|
1714
|
+
controls=(dataset.controls if n_controls else None),
|
|
1715
|
+
non_media_treatments=(
|
|
1716
|
+
dataset.non_media_treatments if n_non_media_channels else None
|
|
1717
|
+
),
|
|
1701
1718
|
media=dataset.media if n_media_channels else None,
|
|
1702
1719
|
media_spend=dataset.media_spend if n_media_channels else None,
|
|
1703
1720
|
reach=dataset.reach if n_rf_channels else None,
|
|
@@ -1705,9 +1722,9 @@ def sample_input_data_non_revenue_revenue_per_kpi(
|
|
|
1705
1722
|
rf_spend=dataset.rf_spend if n_rf_channels else None,
|
|
1706
1723
|
organic_media=dataset.organic_media if n_organic_media_channels else None,
|
|
1707
1724
|
organic_reach=dataset.organic_reach if n_organic_rf_channels else None,
|
|
1708
|
-
organic_frequency=
|
|
1709
|
-
|
|
1710
|
-
|
|
1725
|
+
organic_frequency=(
|
|
1726
|
+
dataset.organic_frequency if n_organic_rf_channels else None
|
|
1727
|
+
),
|
|
1711
1728
|
)
|
|
1712
1729
|
|
|
1713
1730
|
|
meridian/mlflow/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Meridian MLflow module that contains autologging functionality."""
|
|
16
|
+
|
|
17
|
+
from meridian.mlflow import autolog
|
meridian/mlflow/autolog.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""MLflow autologging integration for Meridian."""
|
|
16
|
+
|
|
17
|
+
from typing import Any, Callable
|
|
18
|
+
|
|
19
|
+
import arviz as az
|
|
20
|
+
import meridian
|
|
21
|
+
import mlflow
|
|
22
|
+
from mlflow.utils.autologging_utils import autologging_integration, safe_patch
|
|
23
|
+
from meridian.model import model
|
|
24
|
+
|
|
25
|
+
FLAVOR_NAME = "meridian"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _log_versions() -> None:
|
|
29
|
+
"""Logs Meridian and ArviZ versions."""
|
|
30
|
+
mlflow.log_param("meridian_version", meridian.__version__)
|
|
31
|
+
mlflow.log_param("arviz_version", az.__version__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@autologging_integration(FLAVOR_NAME)
|
|
35
|
+
def autolog(
|
|
36
|
+
disable: bool = False, # pylint: disable=unused-argument
|
|
37
|
+
silent: bool = False, # pylint: disable=unused-argument
|
|
38
|
+
) -> None:
|
|
39
|
+
"""Enables MLflow tracking for Meridian.
|
|
40
|
+
|
|
41
|
+
See https://mlflow.org/docs/latest/tracking/
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
disable: Whether to disable autologging.
|
|
45
|
+
silent: Whether to suppress all event logs and warnings from MLflow.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
def patch_meridian_init(
|
|
49
|
+
original: Callable[..., Any], *args, **kwargs
|
|
50
|
+
) -> Callable[..., Any]:
|
|
51
|
+
_log_versions()
|
|
52
|
+
return original(*args, **kwargs)
|
|
53
|
+
|
|
54
|
+
safe_patch(FLAVOR_NAME, model.Meridian, "__init__", patch_meridian_init)
|
meridian/model/__init__.py
CHANGED
meridian/model/adstock_hill.py
CHANGED
meridian/model/knots.py
CHANGED
meridian/model/media.py
CHANGED
meridian/model/model.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -214,7 +214,9 @@ class Meridian:
|
|
|
214
214
|
)
|
|
215
215
|
|
|
216
216
|
@functools.cached_property
|
|
217
|
-
def controls(self) -> tf.Tensor:
|
|
217
|
+
def controls(self) -> tf.Tensor | None:
|
|
218
|
+
if self.input_data.controls is None:
|
|
219
|
+
return None
|
|
218
220
|
return tf.convert_to_tensor(self.input_data.controls, dtype=tf.float32)
|
|
219
221
|
|
|
220
222
|
@functools.cached_property
|
|
@@ -271,6 +273,8 @@ class Meridian:
|
|
|
271
273
|
|
|
272
274
|
@property
|
|
273
275
|
def n_controls(self) -> int:
|
|
276
|
+
if self.input_data.control_variable is None:
|
|
277
|
+
return 0
|
|
274
278
|
return len(self.input_data.control_variable)
|
|
275
279
|
|
|
276
280
|
@property
|
|
@@ -304,7 +308,13 @@ class Meridian:
|
|
|
304
308
|
)
|
|
305
309
|
|
|
306
310
|
@functools.cached_property
|
|
307
|
-
def controls_transformer(
|
|
311
|
+
def controls_transformer(
|
|
312
|
+
self,
|
|
313
|
+
) -> transformers.CenteringAndScalingTransformer | None:
|
|
314
|
+
"""Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
|
|
315
|
+
if self.controls is None:
|
|
316
|
+
return None
|
|
317
|
+
|
|
308
318
|
if self.model_spec.control_population_scaling_id is not None:
|
|
309
319
|
controls_population_scaling_id = tf.convert_to_tensor(
|
|
310
320
|
self.model_spec.control_population_scaling_id, dtype=bool
|
|
@@ -343,8 +353,12 @@ class Meridian:
|
|
|
343
353
|
return transformers.KpiTransformer(self.kpi, self.population)
|
|
344
354
|
|
|
345
355
|
@functools.cached_property
|
|
346
|
-
def controls_scaled(self) -> tf.Tensor:
|
|
347
|
-
|
|
356
|
+
def controls_scaled(self) -> tf.Tensor | None:
|
|
357
|
+
if self.controls is not None:
|
|
358
|
+
# If `controls` is defined, then `controls_transformer` is also defined.
|
|
359
|
+
return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
|
|
360
|
+
else:
|
|
361
|
+
return None
|
|
348
362
|
|
|
349
363
|
@functools.cached_property
|
|
350
364
|
def non_media_treatments_normalized(self) -> tf.Tensor | None:
|
|
@@ -894,11 +908,12 @@ class Meridian:
|
|
|
894
908
|
if self.is_national:
|
|
895
909
|
return
|
|
896
910
|
|
|
897
|
-
self.
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
911
|
+
if self.input_data.controls is not None:
|
|
912
|
+
self._check_if_no_geo_variation(
|
|
913
|
+
self.controls_scaled,
|
|
914
|
+
constants.CONTROLS,
|
|
915
|
+
self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
|
|
916
|
+
)
|
|
902
917
|
if self.input_data.non_media_treatments is not None:
|
|
903
918
|
self._check_if_no_geo_variation(
|
|
904
919
|
self.non_media_treatments_normalized,
|
|
@@ -971,12 +986,12 @@ class Meridian:
|
|
|
971
986
|
|
|
972
987
|
def _validate_time_invariants(self):
|
|
973
988
|
"""Validates model time invariants."""
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
989
|
+
if self.input_data.controls is not None:
|
|
990
|
+
self._check_if_no_time_variation(
|
|
991
|
+
self.controls_scaled,
|
|
992
|
+
constants.CONTROLS,
|
|
993
|
+
self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
|
|
994
|
+
)
|
|
980
995
|
if self.input_data.non_media_treatments is not None:
|
|
981
996
|
self._check_if_no_time_variation(
|
|
982
997
|
self.non_media_treatments_normalized,
|
|
@@ -1031,16 +1046,24 @@ class Meridian:
|
|
|
1031
1046
|
mask = tf.equal(counts, self.n_geos)
|
|
1032
1047
|
col_idx_bad = tf.boolean_mask(col_idx_unique, mask)
|
|
1033
1048
|
dims_bad = tf.gather(data_dims, col_idx_bad)
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1049
|
+
if col_idx_bad.shape[0]:
|
|
1050
|
+
if self.is_national:
|
|
1051
|
+
raise ValueError(
|
|
1052
|
+
f"The following {data_name} variables do not vary across time,"
|
|
1053
|
+
" which is equivalent to no signal at all in a national model:"
|
|
1054
|
+
f" {dims_bad}. This can lead to poor model convergence. To address"
|
|
1055
|
+
" this, drop the listed variables that do not vary across time."
|
|
1056
|
+
)
|
|
1057
|
+
else:
|
|
1058
|
+
raise ValueError(
|
|
1059
|
+
f"The following {data_name} variables do not vary across time,"
|
|
1060
|
+
f" making a model with geo main effects unidentifiable: {dims_bad}."
|
|
1061
|
+
" This can lead to poor model convergence. Since these variables"
|
|
1062
|
+
" only vary across geo and not across time, they are collinear"
|
|
1063
|
+
" with geo and redundant in a model with geo main effects. To"
|
|
1064
|
+
" address this, drop the listed variables that do not vary across"
|
|
1065
|
+
" time."
|
|
1066
|
+
)
|
|
1044
1067
|
|
|
1045
1068
|
def _validate_kpi_transformer(self):
|
|
1046
1069
|
"""Validates the KPI transformer."""
|
|
@@ -1356,31 +1379,36 @@ class Meridian:
|
|
|
1356
1379
|
self, n_chains: int, n_draws: int
|
|
1357
1380
|
) -> Mapping[str, np.ndarray | Sequence[str]]:
|
|
1358
1381
|
"""Creates data coordinates for inference data."""
|
|
1359
|
-
|
|
1382
|
+
media_channel_names = (
|
|
1360
1383
|
self.input_data.media_channel
|
|
1361
1384
|
if self.input_data.media_channel is not None
|
|
1362
1385
|
else np.array([])
|
|
1363
1386
|
)
|
|
1364
|
-
|
|
1387
|
+
rf_channel_names = (
|
|
1365
1388
|
self.input_data.rf_channel
|
|
1366
1389
|
if self.input_data.rf_channel is not None
|
|
1367
1390
|
else np.array([])
|
|
1368
1391
|
)
|
|
1369
|
-
|
|
1392
|
+
organic_media_channel_names = (
|
|
1370
1393
|
self.input_data.organic_media_channel
|
|
1371
1394
|
if self.input_data.organic_media_channel is not None
|
|
1372
1395
|
else np.array([])
|
|
1373
1396
|
)
|
|
1374
|
-
|
|
1397
|
+
organic_rf_channel_names = (
|
|
1375
1398
|
self.input_data.organic_rf_channel
|
|
1376
1399
|
if self.input_data.organic_rf_channel is not None
|
|
1377
1400
|
else np.array([])
|
|
1378
1401
|
)
|
|
1379
|
-
|
|
1402
|
+
non_media_channel_names = (
|
|
1380
1403
|
self.input_data.non_media_channel
|
|
1381
1404
|
if self.input_data.non_media_channel is not None
|
|
1382
1405
|
else np.array([])
|
|
1383
1406
|
)
|
|
1407
|
+
control_variable_names = (
|
|
1408
|
+
self.input_data.control_variable
|
|
1409
|
+
if self.input_data.control_variable is not None
|
|
1410
|
+
else np.array([])
|
|
1411
|
+
)
|
|
1384
1412
|
return {
|
|
1385
1413
|
constants.CHAIN: np.arange(n_chains),
|
|
1386
1414
|
constants.DRAW: np.arange(n_draws),
|
|
@@ -1388,12 +1416,12 @@ class Meridian:
|
|
|
1388
1416
|
constants.TIME: self.input_data.time,
|
|
1389
1417
|
constants.MEDIA_TIME: self.input_data.media_time,
|
|
1390
1418
|
constants.KNOTS: np.arange(self.knot_info.n_knots),
|
|
1391
|
-
constants.CONTROL_VARIABLE:
|
|
1392
|
-
constants.NON_MEDIA_CHANNEL:
|
|
1393
|
-
constants.MEDIA_CHANNEL:
|
|
1394
|
-
constants.RF_CHANNEL:
|
|
1395
|
-
constants.ORGANIC_MEDIA_CHANNEL:
|
|
1396
|
-
constants.ORGANIC_RF_CHANNEL:
|
|
1419
|
+
constants.CONTROL_VARIABLE: control_variable_names,
|
|
1420
|
+
constants.NON_MEDIA_CHANNEL: non_media_channel_names,
|
|
1421
|
+
constants.MEDIA_CHANNEL: media_channel_names,
|
|
1422
|
+
constants.RF_CHANNEL: rf_channel_names,
|
|
1423
|
+
constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
|
|
1424
|
+
constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
|
|
1397
1425
|
}
|
|
1398
1426
|
|
|
1399
1427
|
def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
|
meridian/model/model_test_data.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -70,6 +70,10 @@ class WithInputDataSamples:
|
|
|
70
70
|
_TEST_DIR,
|
|
71
71
|
"sample_prior_media_only.nc",
|
|
72
72
|
)
|
|
73
|
+
_TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
|
|
74
|
+
_TEST_DIR,
|
|
75
|
+
"sample_prior_media_only_no_controls.nc",
|
|
76
|
+
)
|
|
73
77
|
_TEST_SAMPLE_PRIOR_RF_ONLY_PATH = os.path.join(
|
|
74
78
|
_TEST_DIR,
|
|
75
79
|
"sample_prior_rf_only.nc",
|
|
@@ -82,6 +86,10 @@ class WithInputDataSamples:
|
|
|
82
86
|
_TEST_DIR,
|
|
83
87
|
"sample_posterior_media_only.nc",
|
|
84
88
|
)
|
|
89
|
+
_TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
|
|
90
|
+
_TEST_DIR,
|
|
91
|
+
"sample_posterior_media_only_no_controls.nc",
|
|
92
|
+
)
|
|
85
93
|
_TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH = os.path.join(
|
|
86
94
|
_TEST_DIR,
|
|
87
95
|
"sample_posterior_rf_only.nc",
|
|
@@ -172,6 +180,17 @@ class WithInputDataSamples:
|
|
|
172
180
|
seed=0,
|
|
173
181
|
)
|
|
174
182
|
)
|
|
183
|
+
self.input_data_with_media_and_rf_no_controls = (
|
|
184
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
185
|
+
n_geos=self._N_GEOS,
|
|
186
|
+
n_times=self._N_TIMES,
|
|
187
|
+
n_media_times=self._N_MEDIA_TIMES,
|
|
188
|
+
n_controls=None,
|
|
189
|
+
n_media_channels=self._N_MEDIA_CHANNELS,
|
|
190
|
+
n_rf_channels=self._N_RF_CHANNELS,
|
|
191
|
+
seed=0,
|
|
192
|
+
)
|
|
193
|
+
)
|
|
175
194
|
self.short_input_data_with_media_only = (
|
|
176
195
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
177
196
|
n_geos=self._N_GEOS,
|
|
@@ -182,6 +201,16 @@ class WithInputDataSamples:
|
|
|
182
201
|
seed=0,
|
|
183
202
|
)
|
|
184
203
|
)
|
|
204
|
+
self.short_input_data_with_media_only_no_controls = (
|
|
205
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
206
|
+
n_geos=self._N_GEOS,
|
|
207
|
+
n_times=self._N_TIMES_SHORT,
|
|
208
|
+
n_media_times=self._N_MEDIA_TIMES_SHORT,
|
|
209
|
+
n_controls=0,
|
|
210
|
+
n_media_channels=self._N_MEDIA_CHANNELS,
|
|
211
|
+
seed=0,
|
|
212
|
+
)
|
|
213
|
+
)
|
|
185
214
|
self.short_input_data_with_rf_only = (
|
|
186
215
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
187
216
|
n_geos=self._N_GEOS,
|
|
@@ -231,6 +260,9 @@ class WithInputDataSamples:
|
|
|
231
260
|
test_prior_media_only = xr.open_dataset(
|
|
232
261
|
self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
|
|
233
262
|
)
|
|
263
|
+
test_prior_media_only_no_controls = xr.open_dataset(
|
|
264
|
+
self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
|
|
265
|
+
)
|
|
234
266
|
test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
|
|
235
267
|
self.test_dist_media_and_rf = collections.OrderedDict({
|
|
236
268
|
param: tf.convert_to_tensor(test_prior_media_and_rf[param])
|
|
@@ -243,6 +275,18 @@ class WithInputDataSamples:
|
|
|
243
275
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
244
276
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
245
277
|
})
|
|
278
|
+
self.test_dist_media_only_no_controls = collections.OrderedDict({
|
|
279
|
+
param: tf.convert_to_tensor(test_prior_media_only_no_controls[param])
|
|
280
|
+
for param in (
|
|
281
|
+
set(
|
|
282
|
+
constants.COMMON_PARAMETER_NAMES
|
|
283
|
+
+ constants.MEDIA_PARAMETER_NAMES
|
|
284
|
+
)
|
|
285
|
+
- set(
|
|
286
|
+
constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
|
|
287
|
+
)
|
|
288
|
+
)
|
|
289
|
+
})
|
|
246
290
|
self.test_dist_rf_only = collections.OrderedDict({
|
|
247
291
|
param: tf.convert_to_tensor(test_prior_rf_only[param])
|
|
248
292
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
@@ -255,6 +299,9 @@ class WithInputDataSamples:
|
|
|
255
299
|
test_posterior_media_only = xr.open_dataset(
|
|
256
300
|
self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
|
|
257
301
|
)
|
|
302
|
+
test_posterior_media_only_no_controls = xr.open_dataset(
|
|
303
|
+
self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
|
|
304
|
+
)
|
|
258
305
|
test_posterior_rf_only = xr.open_dataset(
|
|
259
306
|
self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
|
|
260
307
|
)
|
|
@@ -273,6 +320,21 @@ class WithInputDataSamples:
|
|
|
273
320
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
274
321
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
275
322
|
}
|
|
323
|
+
posterior_params_to_tensors_media_only_no_controls = {
|
|
324
|
+
param: _convert_with_swap(
|
|
325
|
+
test_posterior_media_only_no_controls[param],
|
|
326
|
+
n_burnin=self._N_BURNIN,
|
|
327
|
+
)
|
|
328
|
+
for param in (
|
|
329
|
+
set(
|
|
330
|
+
constants.COMMON_PARAMETER_NAMES
|
|
331
|
+
+ constants.MEDIA_PARAMETER_NAMES
|
|
332
|
+
)
|
|
333
|
+
- set(
|
|
334
|
+
constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
|
|
335
|
+
)
|
|
336
|
+
)
|
|
337
|
+
}
|
|
276
338
|
posterior_params_to_tensors_rf_only = {
|
|
277
339
|
param: _convert_with_swap(
|
|
278
340
|
test_posterior_rf_only[param], n_burnin=self._N_BURNIN
|
|
@@ -290,6 +352,18 @@ class WithInputDataSamples:
|
|
|
290
352
|
"StructTuple",
|
|
291
353
|
constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
|
|
292
354
|
)(**posterior_params_to_tensors_media_only)
|
|
355
|
+
self.test_posterior_states_media_only_no_controls = collections.namedtuple(
|
|
356
|
+
"StructTuple",
|
|
357
|
+
(
|
|
358
|
+
set(
|
|
359
|
+
constants.COMMON_PARAMETER_NAMES
|
|
360
|
+
+ constants.MEDIA_PARAMETER_NAMES
|
|
361
|
+
)
|
|
362
|
+
- set(
|
|
363
|
+
constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
|
|
364
|
+
)
|
|
365
|
+
),
|
|
366
|
+
)(**posterior_params_to_tensors_media_only_no_controls)
|
|
293
367
|
self.test_posterior_states_rf_only = collections.namedtuple(
|
|
294
368
|
"StructTuple",
|
|
295
369
|
constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,
|