google-meridian 1.1.2__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: google-meridian
- Version: 1.1.2
+ Version: 1.1.3
  Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
  Author-email: The Meridian Authors <no-reply@google.com>
  License:
@@ -397,7 +397,7 @@ To cite this repository:
  author = {Google Meridian Marketing Mix Modeling Team},
  title = {Meridian: Marketing Mix Modeling},
  url = {https://github.com/google/meridian},
- version = {1.1.2},
+ version = {1.1.3},
  year = {2025},
  }
  ```
@@ -1,10 +1,11 @@
- google_meridian-1.1.2.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
- meridian/__init__.py,sha256=rWkSMlr2TiRmH7Xf9z1Bj3grQiSbmrxl3dtGB9YGn9o,714
- meridian/constants.py,sha256=AWhDEP9VcyQtPCbZhM6cPXHeWuz19wjaqB5lGz6qBsw,17161
+ google_meridian-1.1.3.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+ meridian/__init__.py,sha256=XROKwHNVQvEa371QCXAHik5wN_YKObOdJQX9bJ2c4M4,832
+ meridian/constants.py,sha256=VAVHyGfm9FyDd0dWomfqK5XYDUt9qJx7SAM4rzDh3RQ,17195
+ meridian/version.py,sha256=CUTXDDaOfXFTukX_ywPK6Q3PiK9hMyJbmJRBeb5ez7c,644
  meridian/analysis/__init__.py,sha256=nGBYz7k9FVdadO_WVGMKJcfq7Yy_TuuP8zgee4i9pSA,836
- meridian/analysis/analyzer.py,sha256=AP2YJpM2R2qMJ-rwtMmgu-cM-xJLJCFodSaP9K8f0Do,204458
+ meridian/analysis/analyzer.py,sha256=FY_SvnkmEqqCIS37UXB3bvaQi-U3BwLcSWhH1puTzdQ,206003
  meridian/analysis/formatter.py,sha256=ENIdR1CRiaVqIGEXx1HcnsA4ewgDD_nhsYCweJAThaw,7270
- meridian/analysis/optimizer.py,sha256=Se6_sg0O3A4p80vdVnRtDeyNaE5s-ywxKoU0CODQsWM,107608
+ meridian/analysis/optimizer.py,sha256=P4uMcV9ByqMapqa1TEqcnu-3NyTH9fR8QLszdKxRAFc,107801
  meridian/analysis/summarizer.py,sha256=IthOUTMufGvAvbxiDhaKwe7uYCyiTyiQ8vgdmUtdevs,18855
  meridian/analysis/summary_text.py,sha256=I_smDkZJYp2j77ea-9AIbgeraDa7-qUYyb-IthP2qO4,12438
  meridian/analysis/test_utils.py,sha256=ES1r1akhRjD4pf2oTaGqzDfGNu9weAcLv6UZRuIkfEc,77699
@@ -22,25 +23,25 @@ meridian/data/__init__.py,sha256=StIe-wfYnnbfUbKtZHwnAQcRQUS8XCZk_PCaEzw90Ww,929
  meridian/data/arg_builder.py,sha256=Kqlt88bOqFj6D3xNwvWo4MBwNwcDFHzd-wMfEOmLoPU,3741
  meridian/data/data_frame_input_data_builder.py,sha256=3m6wrcC0psmD2ijsXk3R4uByA0Tu2gJxZBGaTS6Z7Io,22040
  meridian/data/input_data.py,sha256=teJPKTBfW-AzBWgf_fEO_S_Z1J_veqQkCvctINaid6I,39749
- meridian/data/input_data_builder.py,sha256=fFJTmUuIdTnTnZPtZNTiEf4_fsqR_haY7O9ZOFj47bE,25409
- meridian/data/load.py,sha256=cvvesjL6Dc7pYu2nOl558gUOZVAW_B69GirzHocyY3Q,42855
+ meridian/data/input_data_builder.py,sha256=08E_MZLrCzwfjvjPWFVs7o_094vVJ5o6VmbTfrg4NUM,25602
+ meridian/data/load.py,sha256=B-12fBhsghN7wj0A9IWyT7BVogIXjuUDDvR34JJFwPM,45157
  meridian/data/nd_array_input_data_builder.py,sha256=lfpmnENGuSGKyUd7bDGAwoLqHqteOKmHdKl0VI2wCQA,16341
  meridian/data/test_utils.py,sha256=6GJrPmeaF4uzMxxRgzERGv4g1XMUHwI0s7qDVMZUjuI,55565
  meridian/data/time_coordinates.py,sha256=C5A5fscSLjPH6G9YT8OspgIlCrkMY7y8dMFEt3tNSnE,9874
  meridian/mlflow/__init__.py,sha256=elwXUqPQYi7VF9PYjelU1tydfcUrmtuoq6eJCOnV9bk,693
- meridian/mlflow/autolog.py,sha256=Duubd_Z2Exlk_MJqGTfMIfFjaDUqG_YnsRGjzY4Hn84,1696
+ meridian/mlflow/autolog.py,sha256=s240eLGAurzaNsulwRlyM1ZdBLvUzyr2eOMYgOyWAzk,6393
  meridian/model/__init__.py,sha256=9NFfqUE5WgFc-9lQMkbfkwwV-bQIz0tsQ_3Jyq0A4SU,982
  meridian/model/adstock_hill.py,sha256=20A_6rbDUAADEkkHspB7JpCm5tYfYS1FQ6hJMLu21Pk,9283
  meridian/model/knots.py,sha256=KPEgnb-UdQQ4QBugOYEke-zBgEghgTmeCMoeiJ30meY,8054
- meridian/model/media.py,sha256=R0LnMUNTuGzXD2lzNRRORA4-p21xpdhkVVsvFaWtEK0,13819
- meridian/model/model.py,sha256=KM2EU7eAK5UHDAn1jbUEI_SBrDkz-Bc93R8qRBEiic8,61500
+ meridian/model/media.py,sha256=3BaPX8xYAFMEvf0mz3mBSCIDWViIs7M218nrCklc6Fk,14099
+ meridian/model/model.py,sha256=BlLPyskHrEx5D71mUZFbNxS2VjkQgaiaE6hLKvQ5D3A,61489
  meridian/model/model_test_data.py,sha256=hDDTEzm72LknW9c5E_dNsy4Mm4Tfs6AirhGf_QxykFs,15552
- meridian/model/posterior_sampler.py,sha256=jjLqcYEAorVJ_2nmhpkVUjCGAyNUZYPTEXVTDHufbqA,27727
+ meridian/model/posterior_sampler.py,sha256=K49zWTTelME2rL1JLeFAdMPzL0OwrBvyAXA3oR-kgSI,27801
  meridian/model/prior_distribution.py,sha256=IEDU1rabcmKNY8lxwbbO4OUAlMHPIMa7flM_zsu3DLM,42417
- meridian/model/prior_sampler.py,sha256=jSaxFmJzyN2OKqKyU059Ar4Yr565w4zlInPl4zxjGZk,23212
+ meridian/model/prior_sampler.py,sha256=cmu6jG-bSEkYDkjVUxl3iSxrL7r-LN7a77cb2Vc0LoA,23218
  meridian/model/spec.py,sha256=0HNiMQUWQpYvWYOZr1_fj2ah8tH-bEyfEjoqgBZ9Lc0,18049
  meridian/model/transformers.py,sha256=nRjzq1fQG0ypldxboM7Gqok6WSAXAS1witRXoAzeH9Q,7763
- google_meridian-1.1.2.dist-info/METADATA,sha256=qoLl6RDBz8LxrnJ3c4-hDiIMcL5OC8G8X61rotn6PGs,22201
- google_meridian-1.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- google_meridian-1.1.2.dist-info/top_level.txt,sha256=nwaCebZvvU34EopTKZsjK0OMTFjVnkf4FfnBN_TAc0g,9
- google_meridian-1.1.2.dist-info/RECORD,,
+ google_meridian-1.1.3.dist-info/METADATA,sha256=5W_XWui7q5gH68OC3Z-PXbDOeBftDbWuhqznNv7fOAk,22201
+ google_meridian-1.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ google_meridian-1.1.3.dist-info/top_level.txt,sha256=nwaCebZvvU34EopTKZsjK0OMTFjVnkf4FfnBN_TAc0g,9
+ google_meridian-1.1.3.dist-info/RECORD,,
meridian/__init__.py CHANGED
@@ -13,10 +13,12 @@
  # limitations under the License.

  """Meridian API."""
-
- __version__ = "1.1.2"
-
-
  from meridian import analysis
  from meridian import data
  from meridian import model
+ from meridian.version import __version__
+
+ try:
+ from meridian import mlflow # pylint: disable=g-import-not-at-top
+ except ImportError:
+ pass
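The net effect of this `__init__.py` change is that the package version is now sourced from the new `meridian.version` module, and `meridian.mlflow` is imported lazily so the core package still imports when the optional `mlflow` dependency is absent. A minimal sketch of what this looks like from the caller's side (assuming meridian 1.1.3 is installed; `mlflow` may or may not be):

```python
# Sketch only: check the re-exported version and the optional mlflow subpackage.
import meridian

print(meridian.__version__)  # expected to print "1.1.3"

# meridian.mlflow is only attached when the optional mlflow dependency imports cleanly.
if hasattr(meridian, "mlflow"):
    print("meridian.mlflow is available")
else:
    print("mlflow not installed; autologging helpers unavailable")
```

meridian/analysis/analyzer.py CHANGED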
@@ -65,6 +65,8 @@ class DataTensors(tf.experimental.ExtensionType):
  time dimension `T`.
  frequency: Optional tensor with dimensions `(n_geos, T, n_rf_channels)` for
  any time dimension `T`.
+ rf_impressions: Optional tensor with dimensions `(n_geos, T, n_rf_channels)`
+ for any time dimension `T`.
  rf_spend: Optional tensor with dimensions `(n_geos, T, n_rf_channels)` for
  any time dimension `T`.
  organic_media: Optional tensor with dimensions `(n_geos, T,
@@ -86,6 +88,7 @@ class DataTensors(tf.experimental.ExtensionType):
  media_spend: Optional[tf.Tensor]
  reach: Optional[tf.Tensor]
  frequency: Optional[tf.Tensor]
+ rf_impressions: Optional[tf.Tensor]
  rf_spend: Optional[tf.Tensor]
  organic_media: Optional[tf.Tensor]
  organic_reach: Optional[tf.Tensor]
@@ -101,6 +104,7 @@ class DataTensors(tf.experimental.ExtensionType):
  media_spend: Optional[tf.Tensor] = None,
  reach: Optional[tf.Tensor] = None,
  frequency: Optional[tf.Tensor] = None,
+ rf_impressions: Optional[tf.Tensor] = None,
  rf_spend: Optional[tf.Tensor] = None,
  organic_media: Optional[tf.Tensor] = None,
  organic_reach: Optional[tf.Tensor] = None,
@@ -118,6 +122,11 @@ class DataTensors(tf.experimental.ExtensionType):
  self.frequency = (
  tf.cast(frequency, tf.float32) if frequency is not None else None
  )
+ self.rf_impressions = (
+ tf.cast(rf_impressions, tf.float32)
+ if rf_impressions is not None
+ else None
+ )
  self.rf_spend = (
  tf.cast(rf_spend, tf.float32) if rf_spend is not None else None
  )
@@ -189,7 +198,10 @@ class DataTensors(tf.experimental.ExtensionType):
  """
  for field in self._tf_extension_type_fields():
  new_tensor = getattr(self, field.name)
- old_tensor = getattr(meridian.input_data, field.name)
+ if field.name == constants.RF_IMPRESSIONS:
+ old_tensor = getattr(meridian.rf_tensors, field.name)
+ else:
+ old_tensor = getattr(meridian.input_data, field.name)
  # The time dimension is always the second dimension, except for when spend
  # data is provided with only one dimension of (n_channels).
  if (
@@ -293,7 +305,13 @@ class DataTensors(tf.experimental.ExtensionType):
  "This is not supported and will be ignored."
  )
  if field.name in required_variables:
- if getattr(meridian.input_data, field.name) is None:
+ if field.name == constants.RF_IMPRESSIONS:
+ if meridian.n_rf_channels == 0:
+ raise ValueError(
+ "New `rf_impressions` is not allowed because there are no R&F"
+ " channels in the Meridian model."
+ )
+ elif getattr(meridian.input_data, field.name) is None:
  raise ValueError(
  f"New `{field.name}` is not allowed because the input data to the"
  f" Meridian model does not contain `{field.name}`."
@@ -322,7 +340,10 @@ class DataTensors(tf.experimental.ExtensionType):
  if var_name in [constants.REVENUE_PER_KPI, constants.TIME]:
  continue
  new_tensor = getattr(self, var_name)
- old_tensor = getattr(meridian.input_data, var_name)
+ if var_name == constants.RF_IMPRESSIONS:
+ old_tensor = getattr(meridian.rf_tensors, var_name)
+ else:
+ old_tensor = getattr(meridian.input_data, var_name)
  if new_tensor is not None:
  assert old_tensor is not None
  if new_tensor.shape[-1] != old_tensor.shape[-1]:
@@ -337,7 +358,10 @@ class DataTensors(tf.experimental.ExtensionType):
  """Validates the time dimension of the specified data variables."""
  for var_name in required_fields:
  new_tensor = getattr(self, var_name)
- old_tensor = getattr(meridian.input_data, var_name)
+ if var_name == constants.RF_IMPRESSIONS:
+ old_tensor = getattr(meridian.rf_tensors, var_name)
+ else:
+ old_tensor = getattr(meridian.input_data, var_name)

  # Skip spend data with only 1 dimension of (n_channels).
  if (
@@ -375,7 +399,10 @@ class DataTensors(tf.experimental.ExtensionType):
  missing_params = []
  for var_name in required_fields:
  new_tensor = getattr(self, var_name)
- old_tensor = getattr(meridian.input_data, var_name)
+ if var_name == constants.RF_IMPRESSIONS:
+ old_tensor = getattr(meridian.rf_tensors, var_name)
+ else:
+ old_tensor = getattr(meridian.input_data, var_name)

  if old_tensor is None:
  continue
@@ -3415,6 +3442,7 @@ class Analyzer:
  def optimal_freq(
  self,
  new_data: DataTensors | None = None,
+ max_frequency: float | None = None,
  freq_grid: Sequence[float] | None = None,
  use_posterior: bool = True,
  use_kpi: bool = False,
@@ -3443,7 +3471,7 @@ class Analyzer:
  ROI numerator is KPI units.

  Args:
- new_data: Optional `DataTensors` object containing `reach`, `frequency`,
+ new_data: Optional `DataTensors` object containing `rf_impressions`,
  `rf_spend`, and `revenue_per_kpi`. If provided, the optimal frequency is
  calculated using the values of the tensors passed in `new_data` and the
  original values of all the remaining tensors. If `None`, the historical
@@ -3451,6 +3479,10 @@ class Analyzer:
  tensors in `new_data` is provided with a different number of time
  periods than in `InputData`, then all tensors must be provided with the
  same number of time periods.
+ max_frequency: Maximum frequency value used to calculate the frequency
+ grid. If `None`, the maximum frequency value is calculated from the
+ historic frequency (maximum value of Meridian.input_data, not
+ `new_data`). If `freq_grid` is provided, this argument has no effect.
  freq_grid: List of frequency values. The ROI of each channel is calculated
  for each frequency value in the list. By default, the list includes
  numbers from `1.0` to the maximum frequency in increments of `0.1`.
@@ -3506,7 +3538,11 @@ class Analyzer:
  )

  filled_data = new_data.validate_and_fill_missing_data(
- constants.RF_DATA,
+ [
+ constants.RF_IMPRESSIONS,
+ constants.RF_SPEND,
+ constants.REVENUE_PER_KPI,
+ ],
  self._meridian,
  )
  # TODO: Once treatment type filtering is added, remove adding
@@ -3527,7 +3563,9 @@ class Analyzer:
  (self._meridian.n_geos, n_times, self._meridian.n_media_channels)
  )

- max_freq = np.max(np.array(filled_data.frequency))
+ max_freq = max_frequency or np.max(
+ np.array(self._meridian.rf_tensors.frequency)
+ )
  if freq_grid is None:
  freq_grid = np.arange(1, max_freq, 0.1)

@@ -3537,8 +3575,8 @@
  metric_grid = np.zeros((len(freq_grid), self._meridian.n_rf_channels, 4))

  for i, freq in enumerate(freq_grid):
- new_frequency = tf.ones_like(filled_data.frequency) * freq
- new_reach = filled_data.frequency * filled_data.reach / new_frequency
+ new_frequency = tf.ones_like(filled_data.rf_impressions) * freq
+ new_reach = filled_data.rf_impressions / new_frequency
  new_roi_data = DataTensors(
  reach=new_reach,
  frequency=new_frequency,
@@ -3568,12 +3606,10 @@

  optimal_frequency = [freq_grid[i] for i in optimal_freq_idx]
  optimal_frequency_tensor = tf.convert_to_tensor(
- tf.ones_like(filled_data.frequency) * optimal_frequency,
+ tf.ones_like(filled_data.rf_impressions) * optimal_frequency,
  tf.float32,
  )
- optimal_reach = (
- filled_data.frequency * filled_data.reach / optimal_frequency_tensor
- )
+ optimal_reach = filled_data.rf_impressions / optimal_frequency_tensor

  new_summary_metrics_data = DataTensors(
  reach=optimal_reach,
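Taken together, these `Analyzer.optimal_freq` changes drive the frequency grid from a single `rf_impressions` tensor (reach times frequency) instead of separate `reach`/`frequency` inputs, and expose a `max_frequency` override for the grid. A hedged usage sketch, assuming an already-fitted `Meridian` model `mmm` with R&F channels; the cap of `10.0` is an arbitrary illustrative value:

```python
from meridian.analysis import analyzer

an = analyzer.Analyzer(mmm)  # `mmm` is assumed to be a fitted Meridian model

# Prior behaviour: grid runs from 1.0 up to the maximum observed frequency.
opt_freq = an.optimal_freq(use_posterior=True)

# 1.1.3: pass impressions directly and/or cap the grid explicitly.
rf_impressions = mmm.rf_tensors.reach * mmm.rf_tensors.frequency
new_data = analyzer.DataTensors(rf_impressions=rf_impressions)
opt_freq_capped = an.optimal_freq(new_data=new_data, max_frequency=10.0)
```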
meridian/analysis/optimizer.py CHANGED
@@ -1863,9 +1863,14 @@ class BudgetOptimizer:
  )
  )
  if self._meridian.n_rf_channels > 0 and use_optimal_frequency:
+ opt_freq_data = analyzer.DataTensors(
+ rf_impressions=filled_data.reach * filled_data.frequency,
+ rf_spend=filled_data.rf_spend,
+ revenue_per_kpi=filled_data.revenue_per_kpi,
+ )
  optimal_frequency = tf.convert_to_tensor(
  self._analyzer.optimal_freq(
- new_data=filled_data.filter_fields(c.RF_DATA),
+ new_data=opt_freq_data,
  use_posterior=use_posterior,
  selected_times=selected_times,
  use_kpi=use_kpi,
meridian/constants.py CHANGED
@@ -63,6 +63,7 @@ CONTROLS = 'controls'
  POPULATION = 'population'
  REACH = 'reach'
  FREQUENCY = 'frequency'
+ RF_IMPRESSIONS = 'rf_impressions'
  RF_SPEND = 'rf_spend'
  ORGANIC_MEDIA = 'organic_media'
  ORGANIC_REACH = 'organic_reach'
meridian/data/input_data_builder.py CHANGED
@@ -646,12 +646,13 @@ class InputDataBuilder(abc.ABC):
  """Normalizes the given `DataArray`'s coordinates in Meridian convention.

  Validates that time values are in the conventional Meridian format and
- that geos have national name if national.
+ that geos have national name if national. If geo coordinates are not string-
+ typed, they will be converted to strings.

  Args:
  da: The DataArray to normalize.
- time_dimension_name: The name of the time dimension. If None, the
- will skip time normalization.
+ time_dimension_name: The name of the time dimension. If None, the will
+ skip time normalization.

  Returns:
  The normalized DataArray.
@@ -686,6 +687,11 @@ class InputDataBuilder(abc.ABC):
  da = da.assign_coords(
  {constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]},
  )
+ else:
+ da = da.assign_coords(
+ {constants.GEO: da.coords[constants.GEO].astype(str)}
+ )
+
  return da

  def _validate_set(self, component: str, da: xr.DataArray):
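The new `else` branch above means geo coordinates that are not already strings (for example integer geo IDs) are coerced to strings during normalization. A small standalone sketch of the same coercion, using a hypothetical `xarray.DataArray` rather than the builder itself:

```python
import numpy as np
import xarray as xr

# Hypothetical input with integer geo IDs.
da = xr.DataArray(
    np.zeros((2, 3)),
    dims=("geo", "time"),
    coords={"geo": [101, 202], "time": ["2024-01-01", "2024-01-08", "2024-01-15"]},
)

# Equivalent of the builder's new behaviour: cast the geo coordinate to str.
da = da.assign_coords(geo=da.coords["geo"].astype(str))
print(da.coords["geo"].values)  # ['101' '202']
```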
meridian/data/load.py CHANGED
@@ -816,12 +816,55 @@ class DataFrameDataLoader(InputDataLoader):
  'organic_frequency': 'organic_frequency_to_channel',
  })
  for coord_name, channel_dict in required_mappings.items():
+ if getattr(self.coord_to_columns, coord_name, None) is not None:
+ if getattr(self, channel_dict, None) is None:
+ raise ValueError(
+ f"When {coord_name} data is provided, '{channel_dict}' is"
+ ' required.'
+ )
+ else:
+ if set(getattr(self, channel_dict)) != set(
+ getattr(self.coord_to_columns, coord_name)
+ ):
+ raise ValueError(
+ f'The {channel_dict} keys must have the same set of values as'
+ f' the {coord_name} columns.'
+ )
+ if (
+ self.media_to_channel is not None
+ and self.media_spend_to_channel is not None
+ ):
+ if set(self.media_to_channel.values()) != set(
+ self.media_spend_to_channel.values()
+ ):
+ raise ValueError(
+ 'The media and media_spend columns must have the same set of'
+ ' channels.'
+ )
+ if (
+ self.reach_to_channel is not None
+ and self.frequency_to_channel is not None
+ and self.rf_spend_to_channel is not None
+ ):
  if (
- getattr(self.coord_to_columns, coord_name, None) is not None
- and getattr(self, channel_dict, None) is None
+ set(self.reach_to_channel.values())
+ != set(self.frequency_to_channel.values())
+ != set(self.rf_spend_to_channel.values())
  ):
  raise ValueError(
- f"When {coord_name} data is provided, '{channel_dict}' is required."
+ 'The reach, frequency, and rf_spend columns must have the same set'
+ ' of channels.'
+ )
+ if (
+ self.organic_reach_to_channel is not None
+ and self.organic_frequency_to_channel is not None
+ ):
+ if set(self.organic_reach_to_channel.values()) != set(
+ self.organic_frequency_to_channel.values()
+ ):
+ raise ValueError(
+ 'The organic_reach and organic_frequency columns must have the'
+ ' same set of channels.'
  )

  def load(self) -> input_data.InputData:
@@ -861,28 +904,36 @@
  self.coord_to_columns.geo,
  )
  if (
- self.coord_to_columns.media is not None
- and self.media_to_channel is not None
+ self.media_to_channel is not None
+ and self.media_spend_to_channel is not None
  ):
+ sorted_channels = sorted(self.media_to_channel.values())
+ inv_media_map = {v: k for k, v in self.media_to_channel.items()}
+ inv_spend_map = {v: k for k, v in self.media_spend_to_channel.items()}
+
  builder.with_media(
  self.df,
- list(self.coord_to_columns.media),
- list(self.coord_to_columns.media_spend),
- list(self.media_to_channel.values()),
+ [inv_media_map[ch] for ch in sorted_channels],
+ [inv_spend_map[ch] for ch in sorted_channels],
+ sorted_channels,
  self.coord_to_columns.time,
  self.coord_to_columns.geo,
  )
-
  if (
- self.coord_to_columns.reach is not None
- and self.reach_to_channel is not None
+ self.reach_to_channel is not None
+ and self.frequency_to_channel is not None
+ and self.rf_spend_to_channel is not None
  ):
+ sorted_channels = sorted(self.reach_to_channel.values())
+ inv_reach_map = {v: k for k, v in self.reach_to_channel.items()}
+ inv_freq_map = {v: k for k, v in self.frequency_to_channel.items()}
+ inv_rf_spend_map = {v: k for k, v in self.rf_spend_to_channel.items()}
  builder.with_reach(
  self.df,
- list(self.coord_to_columns.reach),
- list(self.coord_to_columns.frequency),
- list(self.coord_to_columns.rf_spend),
- list(self.reach_to_channel.values()),
+ [inv_reach_map[ch] for ch in sorted_channels],
+ [inv_freq_map[ch] for ch in sorted_channels],
+ [inv_rf_spend_map[ch] for ch in sorted_channels],
+ sorted_channels,
  self.coord_to_columns.time,
  self.coord_to_columns.geo,
  )
@@ -895,14 +946,19 @@
  self.coord_to_columns.geo,
  )
  if (
- self.coord_to_columns.organic_reach is not None
- and self.organic_reach_to_channel is not None
+ self.organic_reach_to_channel is not None
+ and self.organic_frequency_to_channel is not None
  ):
+ sorted_channels = sorted(self.organic_reach_to_channel.values())
+ inv_reach_map = {v: k for k, v in self.organic_reach_to_channel.items()}
+ inv_freq_map = {
+ v: k for k, v in self.organic_frequency_to_channel.items()
+ }
  builder.with_organic_reach(
  self.df,
- list(self.coord_to_columns.organic_reach),
- list(self.coord_to_columns.organic_frequency),
- list(self.organic_reach_to_channel.values()),
+ [inv_reach_map[ch] for ch in sorted_channels],
+ [inv_freq_map[ch] for ch in sorted_channels],
+ sorted_channels,
  self.coord_to_columns.time,
  self.coord_to_columns.geo,
  )
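The expanded `DataFrameDataLoader` validation now also requires that paired channel mappings resolve to the same set of channel names (media vs. media_spend, reach/frequency/rf_spend, organic_reach vs. organic_frequency), and the loader passes channels to the builder in sorted order via inverted column maps. An illustrative check only (not a call into the library API), with hypothetical column and channel names:

```python
# Hypothetical *_to_channel mappings as passed to DataFrameDataLoader.
media_to_channel = {"impressions_tv": "tv", "impressions_search": "search"}
media_spend_to_channel = {"spend_tv": "tv", "spend_search": "search"}

# 1.1.3 rejects mismatched channel sets between paired mappings.
assert set(media_to_channel.values()) == set(media_spend_to_channel.values())

# Channels are processed in sorted order, with column names looked up through
# an inverted map, mirroring the loader's new logic.
sorted_channels = sorted(media_to_channel.values())
inv_media_map = {v: k for k, v in media_to_channel.items()}
media_columns = [inv_media_map[ch] for ch in sorted_channels]
print(sorted_channels, media_columns)
# ['search', 'tv'] ['impressions_search', 'impressions_tv']
```

meridian/mlflow/autolog.py CHANGED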
@@ -12,29 +12,130 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- """MLflow autologging integration for Meridian."""
+ """MLflow autologging integration for Meridian.

+ This module enables MLflow tracking for Meridian. When enabled via `autolog()`,
+ parameters, metrics, and other metadata will be automatically logged to MLflow,
+ allowing for improved experiment tracking and analysis.
+
+ To enable MLflow autologging for your Meridian workflows, simply call
+ `autolog.autolog()` once before your model run.
+
+ Example usage:
+
+ ```python
+ import mlflow
+ from meridian.data import load
+ from meridian.mlflow import autolog
+ from meridian.model import model
+
+ # Enable autologging (call this once per session)
+ autolog.autolog(log_metrics=True)
+
+ # Start an MLflow run (optionally name it for better grouping)
+ with mlflow.start_run(run_name="my_run"):
+ # Load data
+ data = load.CsvDataLoader(...).load()
+
+ # Initialize Meridian model
+ mmm = model.Meridian(input_data=data)
+
+ # Run Meridian sampling processes
+ mmm.sample_prior(n_draws=100, seed=123)
+ mmm.sample_posterior(n_chains=7, n_adapt=500, n_burnin=500, n_keep=1000,
+ seed=1)
+
+ # After the run completes, you can retrieve run results using the MLflow client.
+ client = mlflow.tracking.MlflowClient()
+
+ # Get the experiment ID for the run you just launched
+ experiment_id = "0"
+
+ # Search for runs matching the run name
+ runs = client.search_runs(
+ experiment_id,
+ max_results=1000,
+ filter_string=f"attributes.run_name = 'my_run'"
+ )
+
+ # Print details of the run
+ if runs:
+ print(runs[0])
+ else:
+ print("No runs found.")
+ ```
+ """
+
+ import dataclasses
+ import inspect
+ import json
  from typing import Any, Callable

  import arviz as az
- import meridian
+ from meridian.analysis import visualizer
  import mlflow
  from mlflow.utils.autologging_utils import autologging_integration, safe_patch
  from meridian.model import model
+ from meridian.model import posterior_sampler
+ from meridian.model import prior_sampler
+ from meridian.model import spec
+ from meridian.version import __version__
+ import numpy as np
+ import tensorflow_probability as tfp
+

  FLAVOR_NAME = "meridian"


+ __all__ = ["autolog"]
+
+
  def _log_versions() -> None:
  """Logs Meridian and ArviZ versions."""
- mlflow.log_param("meridian_version", meridian.__version__)
+ mlflow.log_param("meridian_version", __version__)
  mlflow.log_param("arviz_version", az.__version__)


+ def _log_model_spec(model_spec: spec.ModelSpec) -> None:
+ """Logs the `ModelSpec` object."""
+ # TODO: Replace with serde api when it's available.
+ # PriorDistribution is logged separately.
+ excluded_fields = ["prior"]
+
+ for field in dataclasses.fields(model_spec):
+ if field.name in excluded_fields:
+ continue
+
+ field_value = getattr(model_spec, field.name)
+
+ # Stringify numpy arrays before logging.
+ if isinstance(field_value, np.ndarray):
+ field_value = json.dumps(field_value.tolist())
+
+ mlflow.log_param(f"spec.{field.name}", field_value)
+
+
+ def _log_priors(model_spec: spec.ModelSpec) -> None:
+ """Logs the `PriorDistribution` object."""
+ # TODO: Replace with serde api when it's available.
+ priors = model_spec.prior
+ for field in dataclasses.fields(priors):
+ field_value = getattr(priors, field.name)
+
+ # Stringify Distributions and numpy arrays.
+ if isinstance(field_value, tfp.distributions.Distribution):
+ field_value = str(field_value)
+ elif isinstance(field_value, np.ndarray):
+ field_value = json.dumps(field_value.tolist())
+
+ mlflow.log_param(f"prior.{field.name}", field_value)
+
+
  @autologging_integration(FLAVOR_NAME)
  def autolog(
  disable: bool = False, # pylint: disable=unused-argument
  silent: bool = False, # pylint: disable=unused-argument
+ log_metrics: bool = False,
  ) -> None:
  """Enables MLflow tracking for Meridian.

@@ -43,12 +144,63 @@ def autolog(
  Args:
  disable: Whether to disable autologging.
  silent: Whether to suppress all event logs and warnings from MLflow.
+ log_metrics: Whether model metrics should be logged. Enabling this option
+ involves the creation of post-modeling objects to compute relevant
+ performance metrics. Metrics include R-Squared, MAPE, and wMAPE values.
  """

  def patch_meridian_init(
- original: Callable[..., Any], *args, **kwargs
- ) -> Callable[..., Any]:
+ original: Callable[..., Any], self, *args, **kwargs
+ ) -> model.Meridian:
  _log_versions()
- return original(*args, **kwargs)
+ mmm = original(self, *args, **kwargs)
+ _log_model_spec(self.model_spec)
+ _log_priors(self.model_spec)
+ return mmm
+
+ def patch_prior_sampling(original: Callable[..., Any], self, *args, **kwargs):
+ mlflow.log_param("sample_prior.n_draws", kwargs.get("n_draws", "default"))
+ mlflow.log_param("sample_prior.seed", kwargs.get("seed", "default"))
+ return original(self, *args, **kwargs)
+
+ def patch_posterior_sampling(
+ original: Callable[..., Any], self, *args, **kwargs
+ ):
+ excluded_fields = ["current_state", "pins"]
+ params = [
+ name
+ for name, value in inspect.signature(original).parameters.items()
+ if name != "self"
+ and value.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
+ and name not in excluded_fields
+ ]
+
+ for param in params:
+ mlflow.log_param(
+ f"sample_posterior.{param}", kwargs.get(param, "default")
+ )
+
+ original(self, *args, **kwargs)
+ if log_metrics:
+ model_diagnostics = visualizer.ModelDiagnostics(self.model)
+ df_diag = model_diagnostics.predictive_accuracy_table()
+
+ get_metric = lambda n: df_diag[df_diag.metric == n].value.to_list()[0]
+
+ mlflow.log_metric("R_Squared", get_metric("R_Squared"))
+ mlflow.log_metric("MAPE", get_metric("MAPE"))
+ mlflow.log_metric("wMAPE", get_metric("wMAPE"))

  safe_patch(FLAVOR_NAME, model.Meridian, "__init__", patch_meridian_init)
+ safe_patch(
+ FLAVOR_NAME,
+ prior_sampler.PriorDistributionSampler,
+ "__call__",
+ patch_prior_sampling,
+ )
+ safe_patch(
+ FLAVOR_NAME,
+ posterior_sampler.PosteriorMCMCSampler,
+ "__call__",
+ patch_posterior_sampling,
+ )
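With the expanded autologging above, a run now carries the Meridian and ArviZ versions, the `ModelSpec` fields (as `spec.*` params), the prior distributions (as `prior.*` params), the sampling arguments, and, when `log_metrics=True`, the R_Squared/MAPE/wMAPE metrics. A hedged sketch of reading these back with the standard MLflow client API; the specific keys shown (`spec.max_lag`, `prior.roi_m`) are examples and depend on the model configuration:

```python
import mlflow

run_id = "..."  # placeholder: the run created while autologging was enabled
run = mlflow.get_run(run_id)

print(run.data.params.get("meridian_version"))
print(run.data.params.get("spec.max_lag"))    # one ModelSpec field, if logged
print(run.data.params.get("prior.roi_m"))     # one PriorDistribution field, if logged
print(run.data.params.get("sample_posterior.n_chains"))

# Only present when autolog(log_metrics=True) was used.
print({k: v for k, v in run.data.metrics.items() if k in ("R_Squared", "MAPE", "wMAPE")})
```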
meridian/model/media.py CHANGED
@@ -207,6 +207,8 @@ class RfTensors:
  Attributes:
  reach: A tensor constructed from `InputData.reach`.
  frequency: A tensor constructed from `InputData.frequency`.
+ rf_impressions: A tensor constructed from `InputData.reach` *
+ `InputData.frequency`.
  rf_spend: A tensor constructed from `InputData.rf_spend`.
  reach_transformer: A `MediaTransformer` to scale RF tensors using the
  model's RF data.
@@ -233,6 +235,7 @@ class RfTensors:

  reach: tf.Tensor | None = None
  frequency: tf.Tensor | None = None
+ rf_impressions: tf.Tensor | None = None
  rf_spend: tf.Tensor | None = None
  reach_transformer: transformers.MediaTransformer | None = None
  reach_scaled: tf.Tensor | None = None
@@ -250,6 +253,9 @@ def build_rf_tensors(

  reach = tf.convert_to_tensor(input_data.reach, dtype=tf.float32)
  frequency = tf.convert_to_tensor(input_data.frequency, dtype=tf.float32)
+ rf_impressions = (
+ reach * frequency if reach is not None and frequency is not None else None
+ )
  rf_spend = tf.convert_to_tensor(input_data.rf_spend, dtype=tf.float32)
  reach_transformer = transformers.MediaTransformer(
  reach, tf.convert_to_tensor(input_data.population, dtype=tf.float32)
@@ -292,6 +298,7 @@ def build_rf_tensors(
  return RfTensors(
  reach=reach,
  frequency=frequency,
+ rf_impressions=rf_impressions,
  rf_spend=rf_spend,
  reach_transformer=reach_transformer,
  reach_scaled=reach_scaled,
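`RfTensors` now precomputes `rf_impressions = reach * frequency`, which is the quantity `optimal_freq` divides by a candidate frequency to get the corresponding reach. A tiny numeric sketch of that identity (illustrative numbers only, using NumPy rather than TensorFlow):

```python
import numpy as np

reach = np.array([1000.0, 2000.0])   # hypothetical per-channel reach
frequency = np.array([2.5, 3.0])     # hypothetical average frequencies

rf_impressions = reach * frequency   # what build_rf_tensors now stores
new_frequency = 2.0                  # a candidate grid value
new_reach = rf_impressions / new_frequency

# Impressions stay fixed while frequency is swept over the grid.
assert np.allclose(new_reach * new_frequency, rf_impressions)
print(rf_impressions, new_reach)     # [2500. 6000.] [1250. 3000.]
```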
meridian/model/model.py CHANGED
@@ -1447,8 +1447,7 @@ class Meridian:
  see [PRNGS and seeds]
  (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
  """
- prior_inference_data = self.prior_sampler_callable(n_draws, seed)
- self.inference_data.extend(prior_inference_data, join="right")
+ self.prior_sampler_callable(n_draws=n_draws, seed=seed)

  def sample_posterior(
  self,
@@ -1527,22 +1526,21 @@ class Meridian:
  [ResourceExhaustedError when running Meridian.sample_posterior]
  (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
  """
- posterior_inference_data = self.posterior_sampler_callable(
- n_chains,
- n_adapt,
- n_burnin,
- n_keep,
- current_state,
- init_step_size,
- dual_averaging_kwargs,
- max_tree_depth,
- max_energy_diff,
- unrolled_leapfrog_steps,
- parallel_iterations,
- seed,
+ self.posterior_sampler_callable(
+ n_chains=n_chains,
+ n_adapt=n_adapt,
+ n_burnin=n_burnin,
+ n_keep=n_keep,
+ current_state=current_state,
+ init_step_size=init_step_size,
+ dual_averaging_kwargs=dual_averaging_kwargs,
+ max_tree_depth=max_tree_depth,
+ max_energy_diff=max_energy_diff,
+ unrolled_leapfrog_steps=unrolled_leapfrog_steps,
+ parallel_iterations=parallel_iterations,
+ seed=seed,
  **pins,
  )
- self.inference_data.extend(posterior_inference_data, join="right")


  def save_mmm(mmm: Meridian, file_path: str):
meridian/model/posterior_sampler.py CHANGED
@@ -85,9 +85,13 @@ class PosteriorMCMCSampler:
  def __init__(self, meridian: "model.Meridian"):
  self._meridian = meridian

+ @property
+ def model(self) -> "model.Meridian":
+ return self._meridian
+
  def _get_joint_dist_unpinned(self) -> tfp.distributions.Distribution:
  """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
- mmm = self._meridian
+ mmm = self.model
  mmm.populate_cached_properties()

  # This lists all the derived properties and states of this Meridian object
@@ -453,7 +457,7 @@ class PosteriorMCMCSampler:
  return joint_dist_unpinned

  def _get_joint_dist(self) -> tfp.distributions.Distribution:
- mmm = self._meridian
+ mmm = self.model
  y = (
  tf.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
  if mmm.holdout_id is not None
@@ -476,7 +480,7 @@ class PosteriorMCMCSampler:
  parallel_iterations: int = 10,
  seed: Sequence[int] | int | None = None,
  **pins,
- ) -> az.InferenceData:
+ ) -> None:
  """Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions.

  For more information about the arguments, see [`windowed_adaptive_nuts`]
@@ -529,9 +533,6 @@ class PosteriorMCMCSampler:
  **pins: These are used to condition the provided joint distribution, and
  are passed directly to `joint_dist.experimental_pin(**pins)`.

- Returns:
- An Arviz `InferenceData` object containing posterior samples only.
-
  Throws:
  MCMCOOMError: If the model is out of memory. Try reducing `n_keep` or pass
  a list of integers as `n_chains` to sample chains serially. For more
@@ -589,10 +590,10 @@ class PosteriorMCMCSampler:
  if k not in constants.UNSAVED_PARAMETERS
  }
  # Create Arviz InferenceData for posterior draws.
- posterior_coords = self._meridian.create_inference_data_coords(
+ posterior_coords = self.model.create_inference_data_coords(
  total_chains, n_keep
  )
- posterior_dims = self._meridian.create_inference_data_dims()
+ posterior_dims = self.model.create_inference_data_dims()
  infdata_posterior = az.convert_to_inference_data(
  mcmc_states, coords=posterior_coords, dims=posterior_dims
  )
@@ -654,4 +655,7 @@ class PosteriorMCMCSampler:
  dims=sample_stats_dims,
  group="sample_stats",
  )
- return az.concat(infdata_posterior, infdata_trace, infdata_sample_stats)
+ posterior_inference_data = az.concat(
+ infdata_posterior, infdata_trace, infdata_sample_stats
+ )
+ self.model.inference_data.extend(posterior_inference_data, join="right")
meridian/model/prior_sampler.py CHANGED
@@ -588,22 +588,20 @@ class PriorDistributionSampler:
  | non_media_treatments_vars
  )

- def __call__(self, n_draws: int, seed: int | None = None) -> az.InferenceData:
+ def __call__(self, n_draws: int, seed: int | None = None) -> None:
  """Draws samples from prior distributions.

- Returns:
- An Arviz `InferenceData` object containing prior samples only.
-
  Args:
  n_draws: Number of samples drawn from the prior distribution.
  seed: Used to set the seed for reproducible results. For more information,
  see [PRNGS and seeds]
  (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
  """
- prior_draws = self._sample_prior(n_draws, seed=seed)
+ prior_draws = self._sample_prior(n_draws=n_draws, seed=seed)
  # Create Arviz InferenceData for prior draws.
  prior_coords = self._meridian.create_inference_data_coords(1, n_draws)
  prior_dims = self._meridian.create_inference_data_dims()
- return az.convert_to_inference_data(
+ prior_inference_data = az.convert_to_inference_data(
  prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR
  )
+ self._meridian.inference_data.extend(prior_inference_data, join="right")
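Between `model.py`, `posterior_sampler.py`, and `prior_sampler.py`, the sampler callables no longer return an `az.InferenceData`; they extend the model's `inference_data` in place, and `Meridian.sample_prior`/`sample_posterior` simply delegate with keyword arguments. A hedged sketch of the caller-facing behaviour (assuming a constructed model `mmm`; draw counts are illustrative):

```python
# Up to 1.1.2 the samplers returned InferenceData and the Meridian methods
# merged it; as of 1.1.3 the merge happens inside the sampler __call__.
mmm.sample_prior(n_draws=100, seed=123)
mmm.sample_posterior(n_chains=4, n_adapt=500, n_burnin=500, n_keep=1000, seed=1)

# Results are accumulated on the model either way.
idata = mmm.inference_data
print(idata.groups())  # expected to include "prior" and "posterior"

# Calling a sampler object directly now returns None and mutates mmm.inference_data.
```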
meridian/version.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright 2025 The Meridian Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Module for Meridian version."""
+
+ __version__ = "1.1.3"