google-meridian 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -409,10 +409,7 @@ DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_MEDIA = xr.Dataset(
409
409
  )
410
410
 
411
411
  DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_REACH = xr.Dataset(
412
- coords=_REQUIRED_COORDS
413
- | _MEDIA_COORDS
414
- | _RF_COORDS
415
- | _ORGANIC_RF_COORDS,
412
+ coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _ORGANIC_RF_COORDS,
416
413
  data_vars=_REQUIRED_DATA_VARS
417
414
  | _MEDIA_DATA_VARS
418
415
  | _RF_DATA_VARS
@@ -818,11 +815,11 @@ def random_controls_da(
818
815
 
819
816
  def random_kpi_da(
820
817
  media: xr.DataArray,
821
- controls: xr.DataArray,
822
818
  n_geos: int,
823
819
  n_times: int,
824
820
  n_media_channels: int,
825
- n_controls: int,
821
+ n_controls: int | None = None,
822
+ controls: xr.DataArray | None = None,
826
823
  seed: int = 0,
827
824
  integer_geos: bool = False,
828
825
  ) -> xr.DataArray:
@@ -838,19 +835,26 @@ def random_kpi_da(
838
835
  n_media_channels,
839
836
  axis=2,
840
837
  )
841
- control_geo_sd = abs(np.random.normal(0, 5, size=n_geos))
842
- control_geo_sd = np.repeat(
843
- np.repeat(control_geo_sd[:, np.newaxis], n_times, axis=1)[
844
- ..., np.newaxis
845
- ],
846
- n_controls,
847
- axis=2,
848
- )
838
+ if n_controls:
839
+ control_geo_sd = abs(np.random.normal(0, 5, size=n_geos))
840
+ control_geo_sd = np.repeat(
841
+ np.repeat(control_geo_sd[:, np.newaxis], n_times, axis=1)[
842
+ ..., np.newaxis
843
+ ],
844
+ n_controls,
845
+ axis=2,
846
+ )
847
+ else:
848
+ control_geo_sd = 0
849
849
 
850
850
  # Simulates outcome which is the dependent variable. Typically this is the
851
851
  # number of units sold, but it can be any metric (e.g. revenue).
852
852
  media_portion = np.random.normal(media_common, media_geo_sd).sum(axis=2)
853
- control_portion = np.random.normal(controls, control_geo_sd).sum(axis=2)
853
+ if controls is not None:
854
+ control_portion = np.random.normal(controls, control_geo_sd).sum(axis=2)
855
+ else:
856
+ control_portion = 0
857
+
854
858
  error = np.random.normal(0, 2, size=(n_geos, n_times))
855
859
  kpi = abs(media_portion + control_portion + error)
856
860
 
@@ -1161,7 +1165,7 @@ def random_dataset(
1161
1165
  n_geos: int,
1162
1166
  n_times: int,
1163
1167
  n_media_times: int,
1164
- n_controls: int,
1168
+ n_controls: int | None = None,
1165
1169
  n_non_media_channels: int | None = None,
1166
1170
  n_organic_media_channels: int | None = None,
1167
1171
  n_organic_rf_channels: int | None = None,
@@ -1171,7 +1175,7 @@ def random_dataset(
1171
1175
  seed: int = 0,
1172
1176
  remove_media_time: bool = False,
1173
1177
  integer_geos: bool = False,
1174
- ):
1178
+ ) -> xr.Dataset:
1175
1179
  """Generates a random dataset."""
1176
1180
  if n_media_channels:
1177
1181
  media = random_media_da(
@@ -1232,14 +1236,18 @@ def random_dataset(
1232
1236
  else:
1233
1237
  revenue_per_kpi = None
1234
1238
 
1235
- controls = random_controls_da(
1236
- media=media if n_media_channels else reach,
1237
- n_geos=n_geos,
1238
- n_times=n_times,
1239
- n_controls=n_controls,
1240
- seed=seed,
1241
- integer_geos=integer_geos,
1242
- )
1239
+ if n_controls:
1240
+ controls = random_controls_da(
1241
+ media=media if n_media_channels else reach,
1242
+ n_geos=n_geos,
1243
+ n_times=n_times,
1244
+ n_controls=n_controls,
1245
+ seed=seed,
1246
+ integer_geos=integer_geos,
1247
+ )
1248
+ else:
1249
+ controls = None
1250
+
1243
1251
  if n_non_media_channels:
1244
1252
  non_media_treatments = random_non_media_treatments_da(
1245
1253
  media=media if n_media_channels else reach,
@@ -1298,7 +1306,9 @@ def random_dataset(
1298
1306
  n_geos=n_geos, seed=seed, integer_geos=integer_geos
1299
1307
  )
1300
1308
 
1301
- dataset = xr.combine_by_coords([kpi, population, controls])
1309
+ dataset = xr.combine_by_coords(
1310
+ [kpi, population] + ([controls] if controls is not None else [])
1311
+ )
1302
1312
  if revenue_per_kpi is not None:
1303
1313
  dataset = xr.combine_by_coords([dataset, revenue_per_kpi])
1304
1314
  if media is not None:
@@ -1347,7 +1357,7 @@ def random_dataset(
1347
1357
 
1348
1358
  def dataset_to_dataframe(
1349
1359
  dataset: xr.Dataset,
1350
- controls_column_names: list[str],
1360
+ controls_column_names: list[str] | None = None,
1351
1361
  media_column_names: list[str] | None = None,
1352
1362
  media_spend_column_names: list[str] | None = None,
1353
1363
  reach_column_names: list[str] | None = None,
@@ -1396,10 +1406,15 @@ def dataset_to_dataframe(
1396
1406
  )
1397
1407
  population = dataset[c.POPULATION].to_dataframe(name=c.POPULATION)
1398
1408
 
1399
- controls = dataset[c.CONTROLS].to_dataframe(name=c.CONTROLS).unstack()
1400
- controls.columns = controls_column_names
1409
+ if controls_column_names is not None:
1410
+ controls = dataset[c.CONTROLS].to_dataframe(name=c.CONTROLS).unstack()
1411
+ controls.columns = controls_column_names
1412
+ else:
1413
+ controls = None
1401
1414
 
1402
- result = kpi.join(revenue_per_kpi).join(population).join(controls)
1415
+ result = kpi.join(revenue_per_kpi).join(population)
1416
+ if controls is not None:
1417
+ result = result.join(controls)
1403
1418
 
1404
1419
  if non_media_column_names is not None:
1405
1420
  non_media_treatments = (
@@ -1471,7 +1486,7 @@ def random_dataframe(
1471
1486
  n_geos,
1472
1487
  n_times,
1473
1488
  n_media_times,
1474
- n_controls,
1489
+ n_controls=None,
1475
1490
  n_media_channels=None,
1476
1491
  n_rf_channels=None,
1477
1492
  seed=0,
@@ -1489,7 +1504,9 @@ def random_dataframe(
1489
1504
 
1490
1505
  return dataset_to_dataframe(
1491
1506
  dataset,
1492
- controls_column_names=_sample_names('control_', n_controls),
1507
+ controls_column_names=(
1508
+ _sample_names('control_', n_controls) if n_controls else None
1509
+ ),
1493
1510
  media_column_names=_sample_names('media_', n_media_channels),
1494
1511
  media_spend_column_names=_sample_names('media_spend_', n_media_channels),
1495
1512
  reach_column_names=_sample_names('reach_', n_rf_channels),
@@ -1499,7 +1516,7 @@ def random_dataframe(
1499
1516
 
1500
1517
 
1501
1518
  def sample_coord_to_columns(
1502
- n_controls: int,
1519
+ n_controls: int | None = None,
1503
1520
  n_media_channels: int | None = None,
1504
1521
  n_rf_channels: int | None = None,
1505
1522
  n_non_media_channels: int | None = None,
@@ -1550,7 +1567,7 @@ def sample_coord_to_columns(
1550
1567
  kpi=c.KPI,
1551
1568
  revenue_per_kpi=c.REVENUE_PER_KPI if include_revenue_per_kpi else None,
1552
1569
  population=c.POPULATION,
1553
- controls=_sample_names('control_', n_controls),
1570
+ controls=(_sample_names('control_', n_controls) if n_controls else None),
1554
1571
  media=media,
1555
1572
  media_spend=media_spend,
1556
1573
  reach=reach,
@@ -1668,7 +1685,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1668
1685
  n_geos: int = 10,
1669
1686
  n_times: int = 50,
1670
1687
  n_media_times: int = 53,
1671
- n_controls: int = 2,
1688
+ n_controls: int | None = 2,
1672
1689
  n_non_media_channels: int | None = None,
1673
1690
  n_media_channels: int | None = None,
1674
1691
  n_rf_channels: int | None = None,
@@ -1694,10 +1711,10 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1694
1711
  kpi_type=c.NON_REVENUE,
1695
1712
  revenue_per_kpi=dataset.revenue_per_kpi,
1696
1713
  population=dataset.population,
1697
- controls=dataset.controls,
1698
- non_media_treatments=dataset.non_media_treatments
1699
- if n_non_media_channels
1700
- else None,
1714
+ controls=(dataset.controls if n_controls else None),
1715
+ non_media_treatments=(
1716
+ dataset.non_media_treatments if n_non_media_channels else None
1717
+ ),
1701
1718
  media=dataset.media if n_media_channels else None,
1702
1719
  media_spend=dataset.media_spend if n_media_channels else None,
1703
1720
  reach=dataset.reach if n_rf_channels else None,
@@ -1705,9 +1722,9 @@ def sample_input_data_non_revenue_revenue_per_kpi(
1705
1722
  rf_spend=dataset.rf_spend if n_rf_channels else None,
1706
1723
  organic_media=dataset.organic_media if n_organic_media_channels else None,
1707
1724
  organic_reach=dataset.organic_reach if n_organic_rf_channels else None,
1708
- organic_frequency=dataset.organic_frequency
1709
- if n_organic_rf_channels
1710
- else None,
1725
+ organic_frequency=(
1726
+ dataset.organic_frequency if n_organic_rf_channels else None
1727
+ ),
1711
1728
  )
1712
1729
 
1713
1730
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/model/knots.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/model/media.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
meridian/model/model.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -214,7 +214,9 @@ class Meridian:
214
214
  )
215
215
 
216
216
  @functools.cached_property
217
- def controls(self) -> tf.Tensor:
217
+ def controls(self) -> tf.Tensor | None:
218
+ if self.input_data.controls is None:
219
+ return None
218
220
  return tf.convert_to_tensor(self.input_data.controls, dtype=tf.float32)
219
221
 
220
222
  @functools.cached_property
@@ -271,6 +273,8 @@ class Meridian:
271
273
 
272
274
  @property
273
275
  def n_controls(self) -> int:
276
+ if self.input_data.control_variable is None:
277
+ return 0
274
278
  return len(self.input_data.control_variable)
275
279
 
276
280
  @property
@@ -304,7 +308,13 @@ class Meridian:
304
308
  )
305
309
 
306
310
  @functools.cached_property
307
- def controls_transformer(self) -> transformers.CenteringAndScalingTransformer:
311
+ def controls_transformer(
312
+ self,
313
+ ) -> transformers.CenteringAndScalingTransformer | None:
314
+ """Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
315
+ if self.controls is None:
316
+ return None
317
+
308
318
  if self.model_spec.control_population_scaling_id is not None:
309
319
  controls_population_scaling_id = tf.convert_to_tensor(
310
320
  self.model_spec.control_population_scaling_id, dtype=bool
@@ -343,8 +353,12 @@ class Meridian:
343
353
  return transformers.KpiTransformer(self.kpi, self.population)
344
354
 
345
355
  @functools.cached_property
346
- def controls_scaled(self) -> tf.Tensor:
347
- return self.controls_transformer.forward(self.controls)
356
+ def controls_scaled(self) -> tf.Tensor | None:
357
+ if self.controls is not None:
358
+ # If `controls` is defined, then `controls_transformer` is also defined.
359
+ return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
360
+ else:
361
+ return None
348
362
 
349
363
  @functools.cached_property
350
364
  def non_media_treatments_normalized(self) -> tf.Tensor | None:
@@ -894,11 +908,12 @@ class Meridian:
894
908
  if self.is_national:
895
909
  return
896
910
 
897
- self._check_if_no_geo_variation(
898
- self.controls_scaled,
899
- constants.CONTROLS,
900
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
901
- )
911
+ if self.input_data.controls is not None:
912
+ self._check_if_no_geo_variation(
913
+ self.controls_scaled,
914
+ constants.CONTROLS,
915
+ self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
916
+ )
902
917
  if self.input_data.non_media_treatments is not None:
903
918
  self._check_if_no_geo_variation(
904
919
  self.non_media_treatments_normalized,
@@ -971,12 +986,12 @@ class Meridian:
971
986
 
972
987
  def _validate_time_invariants(self):
973
988
  """Validates model time invariants."""
974
-
975
- self._check_if_no_time_variation(
976
- self.controls_scaled,
977
- constants.CONTROLS,
978
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
979
- )
989
+ if self.input_data.controls is not None:
990
+ self._check_if_no_time_variation(
991
+ self.controls_scaled,
992
+ constants.CONTROLS,
993
+ self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
994
+ )
980
995
  if self.input_data.non_media_treatments is not None:
981
996
  self._check_if_no_time_variation(
982
997
  self.non_media_treatments_normalized,
@@ -1356,31 +1371,36 @@ class Meridian:
1356
1371
  self, n_chains: int, n_draws: int
1357
1372
  ) -> Mapping[str, np.ndarray | Sequence[str]]:
1358
1373
  """Creates data coordinates for inference data."""
1359
- media_channel_values = (
1374
+ media_channel_names = (
1360
1375
  self.input_data.media_channel
1361
1376
  if self.input_data.media_channel is not None
1362
1377
  else np.array([])
1363
1378
  )
1364
- rf_channel_values = (
1379
+ rf_channel_names = (
1365
1380
  self.input_data.rf_channel
1366
1381
  if self.input_data.rf_channel is not None
1367
1382
  else np.array([])
1368
1383
  )
1369
- organic_media_channel_values = (
1384
+ organic_media_channel_names = (
1370
1385
  self.input_data.organic_media_channel
1371
1386
  if self.input_data.organic_media_channel is not None
1372
1387
  else np.array([])
1373
1388
  )
1374
- organic_rf_channel_values = (
1389
+ organic_rf_channel_names = (
1375
1390
  self.input_data.organic_rf_channel
1376
1391
  if self.input_data.organic_rf_channel is not None
1377
1392
  else np.array([])
1378
1393
  )
1379
- non_media_channel_values = (
1394
+ non_media_channel_names = (
1380
1395
  self.input_data.non_media_channel
1381
1396
  if self.input_data.non_media_channel is not None
1382
1397
  else np.array([])
1383
1398
  )
1399
+ control_variable_names = (
1400
+ self.input_data.control_variable
1401
+ if self.input_data.control_variable is not None
1402
+ else np.array([])
1403
+ )
1384
1404
  return {
1385
1405
  constants.CHAIN: np.arange(n_chains),
1386
1406
  constants.DRAW: np.arange(n_draws),
@@ -1388,12 +1408,12 @@ class Meridian:
1388
1408
  constants.TIME: self.input_data.time,
1389
1409
  constants.MEDIA_TIME: self.input_data.media_time,
1390
1410
  constants.KNOTS: np.arange(self.knot_info.n_knots),
1391
- constants.CONTROL_VARIABLE: self.input_data.control_variable,
1392
- constants.NON_MEDIA_CHANNEL: non_media_channel_values,
1393
- constants.MEDIA_CHANNEL: media_channel_values,
1394
- constants.RF_CHANNEL: rf_channel_values,
1395
- constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_values,
1396
- constants.ORGANIC_RF_CHANNEL: organic_rf_channel_values,
1411
+ constants.CONTROL_VARIABLE: control_variable_names,
1412
+ constants.NON_MEDIA_CHANNEL: non_media_channel_names,
1413
+ constants.MEDIA_CHANNEL: media_channel_names,
1414
+ constants.RF_CHANNEL: rf_channel_names,
1415
+ constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
1416
+ constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
1397
1417
  }
1398
1418
 
1399
1419
  def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -70,6 +70,10 @@ class WithInputDataSamples:
70
70
  _TEST_DIR,
71
71
  "sample_prior_media_only.nc",
72
72
  )
73
+ _TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
74
+ _TEST_DIR,
75
+ "sample_prior_media_only_no_controls.nc",
76
+ )
73
77
  _TEST_SAMPLE_PRIOR_RF_ONLY_PATH = os.path.join(
74
78
  _TEST_DIR,
75
79
  "sample_prior_rf_only.nc",
@@ -82,6 +86,10 @@ class WithInputDataSamples:
82
86
  _TEST_DIR,
83
87
  "sample_posterior_media_only.nc",
84
88
  )
89
+ _TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH = os.path.join(
90
+ _TEST_DIR,
91
+ "sample_posterior_media_only_no_controls.nc",
92
+ )
85
93
  _TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH = os.path.join(
86
94
  _TEST_DIR,
87
95
  "sample_posterior_rf_only.nc",
@@ -172,6 +180,17 @@ class WithInputDataSamples:
172
180
  seed=0,
173
181
  )
174
182
  )
183
+ self.input_data_with_media_and_rf_no_controls = (
184
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
185
+ n_geos=self._N_GEOS,
186
+ n_times=self._N_TIMES,
187
+ n_media_times=self._N_MEDIA_TIMES,
188
+ n_controls=None,
189
+ n_media_channels=self._N_MEDIA_CHANNELS,
190
+ n_rf_channels=self._N_RF_CHANNELS,
191
+ seed=0,
192
+ )
193
+ )
175
194
  self.short_input_data_with_media_only = (
176
195
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
177
196
  n_geos=self._N_GEOS,
@@ -182,6 +201,16 @@ class WithInputDataSamples:
182
201
  seed=0,
183
202
  )
184
203
  )
204
+ self.short_input_data_with_media_only_no_controls = (
205
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
206
+ n_geos=self._N_GEOS,
207
+ n_times=self._N_TIMES_SHORT,
208
+ n_media_times=self._N_MEDIA_TIMES_SHORT,
209
+ n_controls=0,
210
+ n_media_channels=self._N_MEDIA_CHANNELS,
211
+ seed=0,
212
+ )
213
+ )
185
214
  self.short_input_data_with_rf_only = (
186
215
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
187
216
  n_geos=self._N_GEOS,
@@ -231,6 +260,9 @@ class WithInputDataSamples:
231
260
  test_prior_media_only = xr.open_dataset(
232
261
  self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
233
262
  )
263
+ test_prior_media_only_no_controls = xr.open_dataset(
264
+ self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
265
+ )
234
266
  test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
235
267
  self.test_dist_media_and_rf = collections.OrderedDict({
236
268
  param: tf.convert_to_tensor(test_prior_media_and_rf[param])
@@ -243,6 +275,18 @@ class WithInputDataSamples:
243
275
  for param in constants.COMMON_PARAMETER_NAMES
244
276
  + constants.MEDIA_PARAMETER_NAMES
245
277
  })
278
+ self.test_dist_media_only_no_controls = collections.OrderedDict({
279
+ param: tf.convert_to_tensor(test_prior_media_only_no_controls[param])
280
+ for param in (
281
+ set(
282
+ constants.COMMON_PARAMETER_NAMES
283
+ + constants.MEDIA_PARAMETER_NAMES
284
+ )
285
+ - set(
286
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
287
+ )
288
+ )
289
+ })
246
290
  self.test_dist_rf_only = collections.OrderedDict({
247
291
  param: tf.convert_to_tensor(test_prior_rf_only[param])
248
292
  for param in constants.COMMON_PARAMETER_NAMES
@@ -255,6 +299,9 @@ class WithInputDataSamples:
255
299
  test_posterior_media_only = xr.open_dataset(
256
300
  self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
257
301
  )
302
+ test_posterior_media_only_no_controls = xr.open_dataset(
303
+ self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
304
+ )
258
305
  test_posterior_rf_only = xr.open_dataset(
259
306
  self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
260
307
  )
@@ -273,6 +320,21 @@ class WithInputDataSamples:
273
320
  for param in constants.COMMON_PARAMETER_NAMES
274
321
  + constants.MEDIA_PARAMETER_NAMES
275
322
  }
323
+ posterior_params_to_tensors_media_only_no_controls = {
324
+ param: _convert_with_swap(
325
+ test_posterior_media_only_no_controls[param],
326
+ n_burnin=self._N_BURNIN,
327
+ )
328
+ for param in (
329
+ set(
330
+ constants.COMMON_PARAMETER_NAMES
331
+ + constants.MEDIA_PARAMETER_NAMES
332
+ )
333
+ - set(
334
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
335
+ )
336
+ )
337
+ }
276
338
  posterior_params_to_tensors_rf_only = {
277
339
  param: _convert_with_swap(
278
340
  test_posterior_rf_only[param], n_burnin=self._N_BURNIN
@@ -290,6 +352,18 @@ class WithInputDataSamples:
290
352
  "StructTuple",
291
353
  constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
292
354
  )(**posterior_params_to_tensors_media_only)
355
+ self.test_posterior_states_media_only_no_controls = collections.namedtuple(
356
+ "StructTuple",
357
+ (
358
+ set(
359
+ constants.COMMON_PARAMETER_NAMES
360
+ + constants.MEDIA_PARAMETER_NAMES
361
+ )
362
+ - set(
363
+ constants.CONTROL_PARAMETERS + constants.GEO_CONTROL_PARAMETERS
364
+ )
365
+ ),
366
+ )(**posterior_params_to_tensors_media_only_no_controls)
293
367
  self.test_posterior_states_rf_only = collections.namedtuple(
294
368
  "StructTuple",
295
369
  constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -120,8 +120,6 @@ class PosteriorMCMCSampler:
120
120
  def joint_dist_unpinned():
121
121
  # Sample directly from prior.
122
122
  knot_values = yield prior_broadcast.knot_values
123
- gamma_c = yield prior_broadcast.gamma_c
124
- xi_c = yield prior_broadcast.xi_c
125
123
  sigma = yield prior_broadcast.sigma
126
124
 
127
125
  tau_g_excl_baseline = yield tfp.distributions.Sample(
@@ -377,19 +375,25 @@ class PosteriorMCMCSampler:
377
375
  combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1)
378
376
 
379
377
  sigma_gt = tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos]))
380
- gamma_gc_dev = yield tfp.distributions.Sample(
381
- tfp.distributions.Normal(0, 1),
382
- [n_geos, n_controls],
383
- name=constants.GAMMA_GC_DEV,
384
- )
385
- gamma_gc = yield tfp.distributions.Deterministic(
386
- gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
387
- )
388
- y_pred_combined_media = (
389
- tau_gt
390
- + tf.einsum("gtm,gm->gt", combined_media_transformed, combined_beta)
391
- + tf.einsum("gtc,gc->gt", controls_scaled, gamma_gc)
378
+ y_pred_combined_media = tau_gt + tf.einsum(
379
+ "gtm,gm->gt", combined_media_transformed, combined_beta
392
380
  )
381
+ # Omit gamma_c, xi_c, and gamma_gc from joint distribution output if
382
+ # there are no control variables in the model.
383
+ if n_controls:
384
+ gamma_c = yield prior_broadcast.gamma_c
385
+ xi_c = yield prior_broadcast.xi_c
386
+ gamma_gc_dev = yield tfp.distributions.Sample(
387
+ tfp.distributions.Normal(0, 1),
388
+ [n_geos, n_controls],
389
+ name=constants.GAMMA_GC_DEV,
390
+ )
391
+ gamma_gc = yield tfp.distributions.Deterministic(
392
+ gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
393
+ )
394
+ y_pred_combined_media += tf.einsum(
395
+ "gtc,gc->gt", controls_scaled, gamma_gc
396
+ )
393
397
 
394
398
  if mmm.non_media_treatments is not None:
395
399
  xi_n = yield prior_broadcast.xi_n
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.