google-meridian: google_meridian-1.1.0-py3-none-any.whl → google_meridian-1.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/METADATA +6 -2
  2. google_meridian-1.1.2.dist-info/RECORD +46 -0
  3. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/WHEEL +1 -1
  4. meridian/__init__.py +2 -2
  5. meridian/analysis/__init__.py +1 -1
  6. meridian/analysis/analyzer.py +29 -22
  7. meridian/analysis/formatter.py +1 -1
  8. meridian/analysis/optimizer.py +70 -44
  9. meridian/analysis/summarizer.py +1 -1
  10. meridian/analysis/summary_text.py +1 -1
  11. meridian/analysis/test_utils.py +1 -1
  12. meridian/analysis/visualizer.py +17 -8
  13. meridian/constants.py +3 -3
  14. meridian/data/__init__.py +4 -1
  15. meridian/data/arg_builder.py +1 -1
  16. meridian/data/data_frame_input_data_builder.py +614 -0
  17. meridian/data/input_data.py +12 -8
  18. meridian/data/input_data_builder.py +817 -0
  19. meridian/data/load.py +121 -428
  20. meridian/data/nd_array_input_data_builder.py +509 -0
  21. meridian/data/test_utils.py +60 -43
  22. meridian/data/time_coordinates.py +1 -1
  23. meridian/mlflow/__init__.py +17 -0
  24. meridian/mlflow/autolog.py +54 -0
  25. meridian/model/__init__.py +1 -1
  26. meridian/model/adstock_hill.py +1 -1
  27. meridian/model/knots.py +1 -1
  28. meridian/model/media.py +1 -1
  29. meridian/model/model.py +65 -37
  30. meridian/model/model_test_data.py +75 -1
  31. meridian/model/posterior_sampler.py +19 -15
  32. meridian/model/prior_distribution.py +1 -1
  33. meridian/model/prior_sampler.py +32 -26
  34. meridian/model/spec.py +18 -8
  35. meridian/model/transformers.py +1 -1
  36. google_meridian-1.1.0.dist-info/RECORD +0 -41
  37. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/licenses/LICENSE +0 -0
  38. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/top_level.txt +0 -0
meridian/data/test_utils.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -409,10 +409,7 @@ DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_MEDIA = xr.Dataset(
409
409
  )
410
410
 
411
411
  DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_REACH = xr.Dataset(
412
- coords=_REQUIRED_COORDS
413
- | _MEDIA_COORDS
414
- | _RF_COORDS
415
- | _ORGANIC_RF_COORDS,
412
+ coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _ORGANIC_RF_COORDS,
416
413
  data_vars=_REQUIRED_DATA_VARS
417
414
  | _MEDIA_DATA_VARS
418
415
  | _RF_DATA_VARS
@@ -818,11 +815,11 @@ def random_controls_da(
818
815
 
819
816
  def random_kpi_da(
820
817
  media: xr.DataArray,
821
- controls: xr.DataArray,
822
818
  n_geos: int,
823
819
  n_times: int,
824
820
  n_media_channels: int,
825
- n_controls: int,
821
+ n_controls: int | None = None,
822
+ controls: xr.DataArray | None = None,
826
823
  seed: int = 0,
827
824
  integer_geos: bool = False,
828
825
  ) -> xr.DataArray:
@@ -838,19 +835,26 @@ def random_kpi_da(
838
835
  n_media_channels,
839
836
  axis=2,
840
837
  )
841
- control_geo_sd = abs(np.random.normal(0, 5, size=n_geos))
842
- control_geo_sd = np.repeat(
843
- np.repeat(control_geo_sd[:, np.newaxis], n_times, axis=1)[
844
- ..., np.newaxis
845
- ],
846
- n_controls,
847
- axis=2,
848
- )
838
+ if n_controls:
839
+ control_geo_sd = abs(np.random.normal(0, 5, size=n_geos))
840
+ control_geo_sd = np.repeat(
841
+ np.repeat(control_geo_sd[:, np.newaxis], n_times, axis=1)[
842
+ ..., np.newaxis
843
+ ],
844
+ n_controls,
845
+ axis=2,
846
+ )
847
+ else:
848
+ control_geo_sd = 0
849
849
 
850
850
  # Simulates outcome which is the dependent variable. Typically this is the
851
851
  # number of units sold, but it can be any metric (e.g. revenue).
852
852
  media_portion = np.random.normal(media_common, media_geo_sd).sum(axis=2)
853
- control_portion = np.random.normal(controls, control_geo_sd).sum(axis=2)
853
+ if controls is not None:
854
+ control_portion = np.random.normal(controls, control_geo_sd).sum(axis=2)
855
+ else:
856
+ control_portion = 0
857
+
854
858
  error = np.random.normal(0, 2, size=(n_geos, n_times))
855
859
  kpi = abs(media_portion + control_portion + error)
856
860
 
@@ -1161,7 +1165,7 @@ def random_dataset(
1161
1165
  n_geos: int,
1162
1166
  n_times: int,
1163
1167
  n_media_times: int,
1164
- n_controls: int,
1168
+ n_controls: int | None = None,
1165
1169
  n_non_media_channels: int | None = None,
1166
1170
  n_organic_media_channels: int | None = None,
1167
1171
  n_organic_rf_channels: int | None = None,
@@ -1171,7 +1175,7 @@ def random_dataset(
1171
1175
  seed: int = 0,
1172
1176
  remove_media_time: bool = False,
1173
1177
  integer_geos: bool = False,
1174
- ):
1178
+ ) -> xr.Dataset:
1175
1179
  """Generates a random dataset."""
1176
1180
  if n_media_channels:
1177
1181
  media = random_media_da(
@@ -1232,14 +1236,18 @@ def random_dataset(
1232
1236
  else:
1233
1237
  revenue_per_kpi = None
1234
1238
 
1235
- controls = random_controls_da(
1236
- media=media if n_media_channels else reach,
1237
- n_geos=n_geos,
1238
- n_times=n_times,
1239
- n_controls=n_controls,
1240
- seed=seed,
1241
- integer_geos=integer_geos,
1242
- )
1239
+ if n_controls:
1240
+ controls = random_controls_da(
1241
+ media=media if n_media_channels else reach,
1242
+ n_geos=n_geos,
1243
+ n_times=n_times,
1244
+ n_controls=n_controls,
1245
+ seed=seed,
1246
+ integer_geos=integer_geos,
1247
+ )
1248
+ else:
1249
+ controls = None
1250
+
1243
1251
  if n_non_media_channels:
1244
1252
  non_media_treatments = random_non_media_treatments_da(
1245
1253
  media=media if n_media_channels else reach,
@@ -1298,7 +1306,9 @@ def random_dataset(
1298
1306
  n_geos=n_geos, seed=seed, integer_geos=integer_geos
1299
1307
  )
1300
1308
 
1301
- dataset = xr.combine_by_coords([kpi, population, controls])
1309
+ dataset = xr.combine_by_coords(
1310
+ [kpi, population] + ([controls] if controls is not None else [])
1311
+ )
1302
1312
  if revenue_per_kpi is not None:
1303
1313
  dataset = xr.combine_by_coords([dataset, revenue_per_kpi])
1304
1314
  if media is not None:
@@ -1347,7 +1357,7 @@ def random_dataset(
1347
1357
 
1348
1358
  def dataset_to_dataframe(
1349
1359
  dataset: xr.Dataset,
1350
- controls_column_names: list[str],
1360
+ controls_column_names: list[str] | None = None,
1351
1361
  media_column_names: list[str] | None = None,
1352
1362
  media_spend_column_names: list[str] | None = None,
1353
1363
  reach_column_names: list[str] | None = None,
@@ -1396,10 +1406,15 @@ def dataset_to_dataframe(
1396
1406
  )
1397
1407
  population = dataset[c.POPULATION].to_dataframe(name=c.POPULATION)
1398
1408
 
1399
- controls = dataset[c.CONTROLS].to_dataframe(name=c.CONTROLS).unstack()
1400
- controls.columns = controls_column_names
1409
+ if controls_column_names is not None:
1410
+ controls = dataset[c.CONTROLS].to_dataframe(name=c.CONTROLS).unstack()
1411
+ controls.columns = controls_column_names
1412
+ else:
1413
+ controls = None
1401
1414
 
1402
- result = kpi.join(revenue_per_kpi).join(population).join(controls)
1415
+ result = kpi.join(revenue_per_kpi).join(population)
1416
+ if controls is not None:
1417
+ result = result.join(controls)
1403
1418
 
1404
1419
  if non_media_column_names is not None:
1405
1420
  non_media_treatments = (
@@ -1471,7 +1486,7 @@ def random_dataframe(
1471
1486
  n_geos,
1472
1487
  n_times,
1473
1488
  n_media_times,
1474
- n_controls,
1489
+ n_controls=None,
1475
1490
  n_media_channels=None,
1476
1491
  n_rf_channels=None,
1477
1492
  seed=0,
@@ -1489,7 +1504,9 @@ def random_dataframe(
1489
1504
 
1490
1505
  return dataset_to_dataframe(
1491
1506
  dataset,
1492
- controls_column_names=_sample_names('control_', n_controls),
1507
+ controls_column_names=(
1508
+ _sample_names('control_', n_controls) if n_controls else None
1509
+ ),
1493
1510
  media_column_names=_sample_names('media_', n_media_channels),
1494
1511
  media_spend_column_names=_sample_names('media_spend_', n_media_channels),
1495
1512
  reach_column_names=_sample_names('reach_', n_rf_channels),
@@ -1499,7 +1516,7 @@ def random_dataframe(
1499
1516
 
1500
1517
 
1501
1518
  def sample_coord_to_columns(
1502
- n_controls: int,
1519
+ n_controls: int | None = None,
1503
1520
  n_media_channels: int | None = None,
1504
1521
  n_rf_channels: int | None = None,
1505
1522
  n_non_media_channels: int | None = None,
@@ -1550,7 +1567,7 @@ def sample_coord_to_columns(
1550
1567
  kpi=c.KPI,
1551
1568
  revenue_per_kpi=c.REVENUE_PER_KPI if include_revenue_per_kpi else None,
1552
1569
  population=c.POPULATION,
1553
- controls=_sample_names('control_', n_controls),
1570
+ controls=(_sample_names('control_', n_controls) if n_controls else None),
1554
1571
  media=media,
1555
1572
  media_spend=media_spend,
1556
1573
  reach=reach,
@@ -1668,7 +1685,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1668
1685
  n_geos: int = 10,
1669
1686
  n_times: int = 50,
1670
1687
  n_media_times: int = 53,
1671
- n_controls: int = 2,
1688
+ n_controls: int | None = 2,
1672
1689
  n_non_media_channels: int | None = None,
1673
1690
  n_media_channels: int | None = None,
1674
1691
  n_rf_channels: int | None = None,
@@ -1694,10 +1711,10 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1694
1711
  kpi_type=c.NON_REVENUE,
1695
1712
  revenue_per_kpi=dataset.revenue_per_kpi,
1696
1713
  population=dataset.population,
1697
- controls=dataset.controls,
1698
- non_media_treatments=dataset.non_media_treatments
1699
- if n_non_media_channels
1700
- else None,
1714
+ controls=(dataset.controls if n_controls else None),
1715
+ non_media_treatments=(
1716
+ dataset.non_media_treatments if n_non_media_channels else None
1717
+ ),
1701
1718
  media=dataset.media if n_media_channels else None,
1702
1719
  media_spend=dataset.media_spend if n_media_channels else None,
1703
1720
  reach=dataset.reach if n_rf_channels else None,
@@ -1705,9 +1722,9 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1705
1722
  rf_spend=dataset.rf_spend if n_rf_channels else None,
1706
1723
  organic_media=dataset.organic_media if n_organic_media_channels else None,
1707
1724
  organic_reach=dataset.organic_reach if n_organic_rf_channels else None,
1708
- organic_frequency=dataset.organic_frequency
1709
- if n_organic_rf_channels
1710
- else None,
1725
+ organic_frequency=(
1726
+ dataset.organic_frequency if n_organic_rf_channels else None
1727
+ ),
1711
1728
  )
1712
1729
 
1713
1730
 
meridian/data/time_coordinates.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/mlflow/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Meridian MLflow module that contains autologging functionality."""
16
+
17
+ from meridian.mlflow import autolog
meridian/mlflow/autolog.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """MLflow autologging integration for Meridian."""
16
+
17
+ from typing import Any, Callable
18
+
19
+ import arviz as az
20
+ import meridian
21
+ import mlflow
22
+ from mlflow.utils.autologging_utils import autologging_integration, safe_patch
23
+ from meridian.model import model
24
+
25
+ FLAVOR_NAME = "meridian"
26
+
27
+
28
+ def _log_versions() -> None:
29
+ """Logs Meridian and ArviZ versions."""
30
+ mlflow.log_param("meridian_version", meridian.__version__)
31
+ mlflow.log_param("arviz_version", az.__version__)
32
+
33
+
34
+ @autologging_integration(FLAVOR_NAME)
35
+ def autolog(
36
+ disable: bool = False, # pylint: disable=unused-argument
37
+ silent: bool = False, # pylint: disable=unused-argument
38
+ ) -> None:
39
+ """Enables MLflow tracking for Meridian.
40
+
41
+ See https://mlflow.org/docs/latest/tracking/
42
+
43
+ Args:
44
+ disable: Whether to disable autologging.
45
+ silent: Whether to suppress all event logs and warnings from MLflow.
46
+ """
47
+
48
+ def patch_meridian_init(
49
+ original: Callable[..., Any], *args, **kwargs
50
+ ) -> Callable[..., Any]:
51
+ _log_versions()
52
+ return original(*args, **kwargs)
53
+
54
+ safe_patch(FLAVOR_NAME, model.Meridian, "__init__", patch_meridian_init)
meridian/model/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/model/adstock_hill.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/model/knots.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/model/media.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/model/model.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -214,7 +214,9 @@ class Meridian:
214
214
  )
215
215
 
216
216
  @functools.cached_property
217
- def controls(self) -> tf.Tensor:
217
+ def controls(self) -> tf.Tensor | None:
218
+ if self.input_data.controls is None:
219
+ return None
218
220
  return tf.convert_to_tensor(self.input_data.controls, dtype=tf.float32)
219
221
 
220
222
  @functools.cached_property
@@ -271,6 +273,8 @@ class Meridian:
271
273
 
272
274
  @property
273
275
  def n_controls(self) -> int:
276
+ if self.input_data.control_variable is None:
277
+ return 0
274
278
  return len(self.input_data.control_variable)
275
279
 
276
280
  @property
@@ -304,7 +308,13 @@ class Meridian:
304
308
  )
305
309
 
306
310
  @functools.cached_property
307
- def controls_transformer(self) -> transformers.CenteringAndScalingTransformer:
311
+ def controls_transformer(
312
+ self,
313
+ ) -> transformers.CenteringAndScalingTransformer | None:
314
+ """Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
315
+ if self.controls is None:
316
+ return None
317
+
308
318
  if self.model_spec.control_population_scaling_id is not None:
309
319
  controls_population_scaling_id = tf.convert_to_tensor(
310
320
  self.model_spec.control_population_scaling_id, dtype=bool
@@ -343,8 +353,12 @@ class Meridian:
343
353
  return transformers.KpiTransformer(self.kpi, self.population)
344
354
 
345
355
  @functools.cached_property
346
- def controls_scaled(self) -> tf.Tensor:
347
- return self.controls_transformer.forward(self.controls)
356
+ def controls_scaled(self) -> tf.Tensor | None:
357
+ if self.controls is not None:
358
+ # If `controls` is defined, then `controls_transformer` is also defined.
359
+ return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
360
+ else:
361
+ return None
348
362
 
349
363
  @functools.cached_property
350
364
  def non_media_treatments_normalized(self) -> tf.Tensor | None:
@@ -894,11 +908,12 @@ class Meridian:
894
908
  if self.is_national:
895
909
  return
896
910
 
897
- self._check_if_no_geo_variation(
898
- self.controls_scaled,
899
- constants.CONTROLS,
900
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
901
- )
911
+ if self.input_data.controls is not None:
912
+ self._check_if_no_geo_variation(
913
+ self.controls_scaled,
914
+ constants.CONTROLS,
915
+ self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
916
+ )
902
917
  if self.input_data.non_media_treatments is not None:
903
918
  self._check_if_no_geo_variation(
904
919
  self.non_media_treatments_normalized,
@@ -971,12 +986,12 @@ class Meridian:
971
986
 
972
987
  def _validate_time_invariants(self):
973
988
  """Validates model time invariants."""
974
-
975
- self._check_if_no_time_variation(
976
- self.controls_scaled,
977
- constants.CONTROLS,
978
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
979
- )
989
+ if self.input_data.controls is not None:
990
+ self._check_if_no_time_variation(
991
+ self.controls_scaled,
992
+ constants.CONTROLS,
993
+ self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
994
+ )
980
995
  if self.input_data.non_media_treatments is not None:
981
996
  self._check_if_no_time_variation(
982
997
  self.non_media_treatments_normalized,
@@ -1031,16 +1046,24 @@ class Meridian:
1031
1046
  mask = tf.equal(counts, self.n_geos)
1032
1047
  col_idx_bad = tf.boolean_mask(col_idx_unique, mask)
1033
1048
  dims_bad = tf.gather(data_dims, col_idx_bad)
1034
-
1035
- if col_idx_bad.shape[0] and not self.is_national:
1036
- raise ValueError(
1037
- f"The following {data_name} variables do not vary across time, making"
1038
- f" a model with geo main effects unidentifiable: {dims_bad}. This can"
1039
- " lead to poor model convergence. Since these variables only vary"
1040
- " across geo and not across time, they are collinear with geo and"
1041
- " redundant in a model with geo main effects. To address this, drop"
1042
- " the listed variables that do not vary across time."
1043
- )
1049
+ if col_idx_bad.shape[0]:
1050
+ if self.is_national:
1051
+ raise ValueError(
1052
+ f"The following {data_name} variables do not vary across time,"
1053
+ " which is equivalent to no signal at all in a national model:"
1054
+ f" {dims_bad}. This can lead to poor model convergence. To address"
1055
+ " this, drop the listed variables that do not vary across time."
1056
+ )
1057
+ else:
1058
+ raise ValueError(
1059
+ f"The following {data_name} variables do not vary across time,"
1060
+ f" making a model with geo main effects unidentifiable: {dims_bad}."
1061
+ " This can lead to poor model convergence. Since these variables"
1062
+ " only vary across geo and not across time, they are collinear"
1063
+ " with geo and redundant in a model with geo main effects. To"
1064
+ " address this, drop the listed variables that do not vary across"
1065
+ " time."
1066
+ )
1044
1067
 
1045
1068
  def _validate_kpi_transformer(self):
1046
1069
  """Validates the KPI transformer."""
@@ -1356,31 +1379,36 @@ class Meridian:
1356
1379
  self, n_chains: int, n_draws: int
1357
1380
  ) -> Mapping[str, np.ndarray | Sequence[str]]:
1358
1381
  """Creates data coordinates for inference data."""
1359
- media_channel_values = (
1382
+ media_channel_names = (
1360
1383
  self.input_data.media_channel
1361
1384
  if self.input_data.media_channel is not None
1362
1385
  else np.array([])
1363
1386
  )
1364
- rf_channel_values = (
1387
+ rf_channel_names = (
1365
1388
  self.input_data.rf_channel
1366
1389
  if self.input_data.rf_channel is not None
1367
1390
  else np.array([])
1368
1391
  )
1369
- organic_media_channel_values = (
1392
+ organic_media_channel_names = (
1370
1393
  self.input_data.organic_media_channel
1371
1394
  if self.input_data.organic_media_channel is not None
1372
1395
  else np.array([])
1373
1396
  )
1374
- organic_rf_channel_values = (
1397
+ organic_rf_channel_names = (
1375
1398
  self.input_data.organic_rf_channel
1376
1399
  if self.input_data.organic_rf_channel is not None
1377
1400
  else np.array([])
1378
1401
  )
1379
- non_media_channel_values = (
1402
+ non_media_channel_names = (
1380
1403
  self.input_data.non_media_channel
1381
1404
  if self.input_data.non_media_channel is not None
1382
1405
  else np.array([])
1383
1406
  )
1407
+ control_variable_names = (
1408
+ self.input_data.control_variable
1409
+ if self.input_data.control_variable is not None
1410
+ else np.array([])
1411
+ )
1384
1412
  return {
1385
1413
  constants.CHAIN: np.arange(n_chains),
1386
1414
  constants.DRAW: np.arange(n_draws),
@@ -1388,12 +1416,12 @@ class Meridian:
1388
1416
  constants.TIME: self.input_data.time,
1389
1417
  constants.MEDIA_TIME: self.input_data.media_time,
1390
1418
  constants.KNOTS: np.arange(self.knot_info.n_knots),
1391
- constants.CONTROL_VARIABLE: self.input_data.control_variable,
1392
- constants.NON_MEDIA_CHANNEL: non_media_channel_values,
1393
- constants.MEDIA_CHANNEL: media_channel_values,
1394
- constants.RF_CHANNEL: rf_channel_values,
1395
- constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_values,
1396
- constants.ORGANIC_RF_CHANNEL: organic_rf_channel_values,
1419
+ constants.CONTROL_VARIABLE: control_variable_names,
1420
+ constants.NON_MEDIA_CHANNEL: non_media_channel_names,
1421
+ constants.MEDIA_CHANNEL: media_channel_names,
1422
+ constants.RF_CHANNEL: rf_channel_names,
1423
+ constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
1424
+ constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
1397
1425
  }
1398
1426
 
1399
1427
  def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
meridian/model/model_test_data.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -70,6 +70,10 @@ class WithInputDataSamples:
70
70
  _TEST_DIR,
71
71
  "sample_prior_media_only.nc",
72
72
  )
73
+ _TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
74
+ _TEST_DIR,
75
+ "sample_prior_media_only_no_controls.nc",
76
+ )
73
77
  _TEST_SAMPLE_PRIOR_RF_ONLY_PATH = os.path.join(
74
78
  _TEST_DIR,
75
79
  "sample_prior_rf_only.nc",
@@ -82,6 +86,10 @@ class WithInputDataSamples:
82
86
  _TEST_DIR,
83
87
  "sample_posterior_media_only.nc",
84
88
  )
89
+ _TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
90
+ _TEST_DIR,
91
+ "sample_posterior_media_only_no_controls.nc",
92
+ )
85
93
  _TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH = os.path.join(
86
94
  _TEST_DIR,
87
95
  "sample_posterior_rf_only.nc",
@@ -172,6 +180,17 @@ class WithInputDataSamples:
172
180
  seed=0,
173
181
  )
174
182
  )
183
+ self.input_data_with_media_and_rf_no_controls = (
184
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
185
+ n_geos=self._N_GEOS,
186
+ n_times=self._N_TIMES,
187
+ n_media_times=self._N_MEDIA_TIMES,
188
+ n_controls=None,
189
+ n_media_channels=self._N_MEDIA_CHANNELS,
190
+ n_rf_channels=self._N_RF_CHANNELS,
191
+ seed=0,
192
+ )
193
+ )
175
194
  self.short_input_data_with_media_only = (
176
195
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
177
196
  n_geos=self._N_GEOS,
@@ -182,6 +201,16 @@ class WithInputDataSamples:
182
201
  seed=0,
183
202
  )
184
203
  )
204
+ self.short_input_data_with_media_only_no_controls = (
205
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
206
+ n_geos=self._N_GEOS,
207
+ n_times=self._N_TIMES_SHORT,
208
+ n_media_times=self._N_MEDIA_TIMES_SHORT,
209
+ n_controls=0,
210
+ n_media_channels=self._N_MEDIA_CHANNELS,
211
+ seed=0,
212
+ )
213
+ )
185
214
  self.short_input_data_with_rf_only = (
186
215
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
187
216
  n_geos=self._N_GEOS,
@@ -231,6 +260,9 @@ class WithInputDataSamples:
231
260
  test_prior_media_only = xr.open_dataset(
232
261
  self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
233
262
  )
263
+ test_prior_media_only_no_controls = xr.open_dataset(
264
+ self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
265
+ )
234
266
  test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
235
267
  self.test_dist_media_and_rf = collections.OrderedDict({
236
268
  param: tf.convert_to_tensor(test_prior_media_and_rf[param])
@@ -243,6 +275,18 @@ class WithInputDataSamples:
243
275
  for param in constants.COMMON_PARAMETER_NAMES
244
276
  + constants.MEDIA_PARAMETER_NAMES
245
277
  })
278
+ self.test_dist_media_only_no_controls = collections.OrderedDict({
279
+ param: tf.convert_to_tensor(test_prior_media_only_no_controls[param])
280
+ for param in (
281
+ set(
282
+ constants.COMMON_PARAMETER_NAMES
283
+ + constants.MEDIA_PARAMETER_NAMES
284
+ )
285
+ - set(
286
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
287
+ )
288
+ )
289
+ })
246
290
  self.test_dist_rf_only = collections.OrderedDict({
247
291
  param: tf.convert_to_tensor(test_prior_rf_only[param])
248
292
  for param in constants.COMMON_PARAMETER_NAMES
@@ -255,6 +299,9 @@ class WithInputDataSamples:
255
299
  test_posterior_media_only = xr.open_dataset(
256
300
  self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
257
301
  )
302
+ test_posterior_media_only_no_controls = xr.open_dataset(
303
+ self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
304
+ )
258
305
  test_posterior_rf_only = xr.open_dataset(
259
306
  self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
260
307
  )
@@ -273,6 +320,21 @@ class WithInputDataSamples:
273
320
  for param in constants.COMMON_PARAMETER_NAMES
274
321
  + constants.MEDIA_PARAMETER_NAMES
275
322
  }
323
+ posterior_params_to_tensors_media_only_no_controls = {
324
+ param: _convert_with_swap(
325
+ test_posterior_media_only_no_controls[param],
326
+ n_burnin=self._N_BURNIN,
327
+ )
328
+ for param in (
329
+ set(
330
+ constants.COMMON_PARAMETER_NAMES
331
+ + constants.MEDIA_PARAMETER_NAMES
332
+ )
333
+ - set(
334
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
335
+ )
336
+ )
337
+ }
276
338
  posterior_params_to_tensors_rf_only = {
277
339
  param: _convert_with_swap(
278
340
  test_posterior_rf_only[param], n_burnin=self._N_BURNIN
@@ -290,6 +352,18 @@ class WithInputDataSamples:
290
352
  "StructTuple",
291
353
  constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
292
354
  )(**posterior_params_to_tensors_media_only)
355
+ self.test_posterior_states_media_only_no_controls = collections.namedtuple(
356
+ "StructTuple",
357
+ (
358
+ set(
359
+ constants.COMMON_PARAMETER_NAMES
360
+ + constants.MEDIA_PARAMETER_NAMES
361
+ )
362
+ - set(
363
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
364
+ )
365
+ ),
366
+ )(**posterior_params_to_tensors_media_only_no_controls)
293
367
  self.test_posterior_states_rf_only = collections.namedtuple(
294
368
  "StructTuple",
295
369
  constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,