google-meridian 1.1.2__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/METADATA +2 -2
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/RECORD +17 -16
- meridian/__init__.py +6 -4
- meridian/analysis/analyzer.py +50 -14
- meridian/analysis/optimizer.py +6 -1
- meridian/constants.py +1 -0
- meridian/data/input_data_builder.py +9 -3
- meridian/data/load.py +76 -20
- meridian/mlflow/autolog.py +158 -6
- meridian/model/media.py +7 -0
- meridian/model/model.py +14 -16
- meridian/model/posterior_sampler.py +13 -9
- meridian/model/prior_sampler.py +4 -6
- meridian/version.py +17 -0
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/top_level.txt +0 -0
{google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: google-meridian
-Version: 1.1.2
+Version: 1.1.3
 Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
 Author-email: The Meridian Authors <no-reply@google.com>
 License:
@@ -397,7 +397,7 @@ To cite this repository:
   author = {Google Meridian Marketing Mix Modeling Team},
   title = {Meridian: Marketing Mix Modeling},
   url = {https://github.com/google/meridian},
-  version = {1.1.2},
+  version = {1.1.3},
   year = {2025},
 }
 ```

{google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/RECORD CHANGED

@@ -1,10 +1,11 @@
-google_meridian-1.1.
-meridian/__init__.py,sha256=
-meridian/constants.py,sha256=
+google_meridian-1.1.3.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+meridian/__init__.py,sha256=XROKwHNVQvEa371QCXAHik5wN_YKObOdJQX9bJ2c4M4,832
+meridian/constants.py,sha256=VAVHyGfm9FyDd0dWomfqK5XYDUt9qJx7SAM4rzDh3RQ,17195
+meridian/version.py,sha256=CUTXDDaOfXFTukX_ywPK6Q3PiK9hMyJbmJRBeb5ez7c,644
 meridian/analysis/__init__.py,sha256=nGBYz7k9FVdadO_WVGMKJcfq7Yy_TuuP8zgee4i9pSA,836
-meridian/analysis/analyzer.py,sha256=
+meridian/analysis/analyzer.py,sha256=FY_SvnkmEqqCIS37UXB3bvaQi-U3BwLcSWhH1puTzdQ,206003
 meridian/analysis/formatter.py,sha256=ENIdR1CRiaVqIGEXx1HcnsA4ewgDD_nhsYCweJAThaw,7270
-meridian/analysis/optimizer.py,sha256=
+meridian/analysis/optimizer.py,sha256=P4uMcV9ByqMapqa1TEqcnu-3NyTH9fR8QLszdKxRAFc,107801
 meridian/analysis/summarizer.py,sha256=IthOUTMufGvAvbxiDhaKwe7uYCyiTyiQ8vgdmUtdevs,18855
 meridian/analysis/summary_text.py,sha256=I_smDkZJYp2j77ea-9AIbgeraDa7-qUYyb-IthP2qO4,12438
 meridian/analysis/test_utils.py,sha256=ES1r1akhRjD4pf2oTaGqzDfGNu9weAcLv6UZRuIkfEc,77699
@@ -22,25 +23,25 @@ meridian/data/__init__.py,sha256=StIe-wfYnnbfUbKtZHwnAQcRQUS8XCZk_PCaEzw90Ww,929
 meridian/data/arg_builder.py,sha256=Kqlt88bOqFj6D3xNwvWo4MBwNwcDFHzd-wMfEOmLoPU,3741
 meridian/data/data_frame_input_data_builder.py,sha256=3m6wrcC0psmD2ijsXk3R4uByA0Tu2gJxZBGaTS6Z7Io,22040
 meridian/data/input_data.py,sha256=teJPKTBfW-AzBWgf_fEO_S_Z1J_veqQkCvctINaid6I,39749
-meridian/data/input_data_builder.py,sha256=
-meridian/data/load.py,sha256=
+meridian/data/input_data_builder.py,sha256=08E_MZLrCzwfjvjPWFVs7o_094vVJ5o6VmbTfrg4NUM,25602
+meridian/data/load.py,sha256=B-12fBhsghN7wj0A9IWyT7BVogIXjuUDDvR34JJFwPM,45157
 meridian/data/nd_array_input_data_builder.py,sha256=lfpmnENGuSGKyUd7bDGAwoLqHqteOKmHdKl0VI2wCQA,16341
 meridian/data/test_utils.py,sha256=6GJrPmeaF4uzMxxRgzERGv4g1XMUHwI0s7qDVMZUjuI,55565
 meridian/data/time_coordinates.py,sha256=C5A5fscSLjPH6G9YT8OspgIlCrkMY7y8dMFEt3tNSnE,9874
 meridian/mlflow/__init__.py,sha256=elwXUqPQYi7VF9PYjelU1tydfcUrmtuoq6eJCOnV9bk,693
-meridian/mlflow/autolog.py,sha256=
+meridian/mlflow/autolog.py,sha256=s240eLGAurzaNsulwRlyM1ZdBLvUzyr2eOMYgOyWAzk,6393
 meridian/model/__init__.py,sha256=9NFfqUE5WgFc-9lQMkbfkwwV-bQIz0tsQ_3Jyq0A4SU,982
 meridian/model/adstock_hill.py,sha256=20A_6rbDUAADEkkHspB7JpCm5tYfYS1FQ6hJMLu21Pk,9283
 meridian/model/knots.py,sha256=KPEgnb-UdQQ4QBugOYEke-zBgEghgTmeCMoeiJ30meY,8054
-meridian/model/media.py,sha256=
-meridian/model/model.py,sha256=
+meridian/model/media.py,sha256=3BaPX8xYAFMEvf0mz3mBSCIDWViIs7M218nrCklc6Fk,14099
+meridian/model/model.py,sha256=BlLPyskHrEx5D71mUZFbNxS2VjkQgaiaE6hLKvQ5D3A,61489
 meridian/model/model_test_data.py,sha256=hDDTEzm72LknW9c5E_dNsy4Mm4Tfs6AirhGf_QxykFs,15552
-meridian/model/posterior_sampler.py,sha256=
+meridian/model/posterior_sampler.py,sha256=K49zWTTelME2rL1JLeFAdMPzL0OwrBvyAXA3oR-kgSI,27801
 meridian/model/prior_distribution.py,sha256=IEDU1rabcmKNY8lxwbbO4OUAlMHPIMa7flM_zsu3DLM,42417
-meridian/model/prior_sampler.py,sha256=
+meridian/model/prior_sampler.py,sha256=cmu6jG-bSEkYDkjVUxl3iSxrL7r-LN7a77cb2Vc0LoA,23218
 meridian/model/spec.py,sha256=0HNiMQUWQpYvWYOZr1_fj2ah8tH-bEyfEjoqgBZ9Lc0,18049
 meridian/model/transformers.py,sha256=nRjzq1fQG0ypldxboM7Gqok6WSAXAS1witRXoAzeH9Q,7763
-google_meridian-1.1.
-google_meridian-1.1.
-google_meridian-1.1.
-google_meridian-1.1.
+google_meridian-1.1.3.dist-info/METADATA,sha256=5W_XWui7q5gH68OC3Z-PXbDOeBftDbWuhqznNv7fOAk,22201
+google_meridian-1.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+google_meridian-1.1.3.dist-info/top_level.txt,sha256=nwaCebZvvU34EopTKZsjK0OMTFjVnkf4FfnBN_TAc0g,9
+google_meridian-1.1.3.dist-info/RECORD,,

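For reference, the `sha256=` values in RECORD follow the wheel convention (PEP 376/427): the urlsafe base64 encoding of the file's SHA-256 digest with the trailing `=` padding stripped. A small sketch of how to reproduce one of the hashes above from an installed file (the path is illustrative):

```python
import base64
import hashlib


def record_hash(path: str) -> str:
  # urlsafe base64 of the SHA-256 digest, '=' padding stripped (43 characters).
  with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).digest()
  return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


# For an installed 1.1.3 wheel, record_hash(".../meridian/version.py") should
# match the RECORD entry above: CUTXDDaOfXFTukX_ywPK6Q3PiK9hMyJbmJRBeb5ez7c
```
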
meridian/__init__.py CHANGED

@@ -13,10 +13,12 @@
 # limitations under the License.
 
 """Meridian API."""
-
-__version__ = "1.1.2"
-
-
 from meridian import analysis
 from meridian import data
 from meridian import model
+from meridian.version import __version__
+
+try:
+  from meridian import mlflow  # pylint: disable=g-import-not-at-top
+except ImportError:
+  pass

meridian/analysis/analyzer.py CHANGED

@@ -65,6 +65,8 @@ class DataTensors(tf.experimental.ExtensionType):
       time dimension `T`.
     frequency: Optional tensor with dimensions `(n_geos, T, n_rf_channels)` for
       any time dimension `T`.
+    rf_impressions: Optional tensor with dimensions `(n_geos, T, n_rf_channels)`
+      for any time dimension `T`.
     rf_spend: Optional tensor with dimensions `(n_geos, T, n_rf_channels)` for
       any time dimension `T`.
     organic_media: Optional tensor with dimensions `(n_geos, T,
@@ -86,6 +88,7 @@ class DataTensors(tf.experimental.ExtensionType):
   media_spend: Optional[tf.Tensor]
   reach: Optional[tf.Tensor]
   frequency: Optional[tf.Tensor]
+  rf_impressions: Optional[tf.Tensor]
   rf_spend: Optional[tf.Tensor]
   organic_media: Optional[tf.Tensor]
   organic_reach: Optional[tf.Tensor]
@@ -101,6 +104,7 @@ class DataTensors(tf.experimental.ExtensionType):
       media_spend: Optional[tf.Tensor] = None,
       reach: Optional[tf.Tensor] = None,
       frequency: Optional[tf.Tensor] = None,
+      rf_impressions: Optional[tf.Tensor] = None,
       rf_spend: Optional[tf.Tensor] = None,
       organic_media: Optional[tf.Tensor] = None,
       organic_reach: Optional[tf.Tensor] = None,
@@ -118,6 +122,11 @@ class DataTensors(tf.experimental.ExtensionType):
     self.frequency = (
         tf.cast(frequency, tf.float32) if frequency is not None else None
     )
+    self.rf_impressions = (
+        tf.cast(rf_impressions, tf.float32)
+        if rf_impressions is not None
+        else None
+    )
     self.rf_spend = (
         tf.cast(rf_spend, tf.float32) if rf_spend is not None else None
     )
@@ -189,7 +198,10 @@ class DataTensors(tf.experimental.ExtensionType):
     """
     for field in self._tf_extension_type_fields():
       new_tensor = getattr(self, field.name)
-      old_tensor = getattr(meridian.input_data, field.name)
+      if field.name == constants.RF_IMPRESSIONS:
+        old_tensor = getattr(meridian.rf_tensors, field.name)
+      else:
+        old_tensor = getattr(meridian.input_data, field.name)
       # The time dimension is always the second dimension, except for when spend
       # data is provided with only one dimension of (n_channels).
       if (
@@ -293,7 +305,13 @@ class DataTensors(tf.experimental.ExtensionType):
             "This is not supported and will be ignored."
         )
       if field.name in required_variables:
-        if
+        if field.name == constants.RF_IMPRESSIONS:
+          if meridian.n_rf_channels == 0:
+            raise ValueError(
+                "New `rf_impressions` is not allowed because there are no R&F"
+                " channels in the Meridian model."
+            )
+        elif getattr(meridian.input_data, field.name) is None:
           raise ValueError(
               f"New `{field.name}` is not allowed because the input data to the"
               f" Meridian model does not contain `{field.name}`."
@@ -322,7 +340,10 @@ class DataTensors(tf.experimental.ExtensionType):
       if var_name in [constants.REVENUE_PER_KPI, constants.TIME]:
         continue
       new_tensor = getattr(self, var_name)
-      old_tensor = getattr(meridian.input_data, var_name)
+      if var_name == constants.RF_IMPRESSIONS:
+        old_tensor = getattr(meridian.rf_tensors, var_name)
+      else:
+        old_tensor = getattr(meridian.input_data, var_name)
       if new_tensor is not None:
         assert old_tensor is not None
         if new_tensor.shape[-1] != old_tensor.shape[-1]:
@@ -337,7 +358,10 @@ class DataTensors(tf.experimental.ExtensionType):
     """Validates the time dimension of the specified data variables."""
     for var_name in required_fields:
       new_tensor = getattr(self, var_name)
-      old_tensor = getattr(meridian.input_data, var_name)
+      if var_name == constants.RF_IMPRESSIONS:
+        old_tensor = getattr(meridian.rf_tensors, var_name)
+      else:
+        old_tensor = getattr(meridian.input_data, var_name)
 
       # Skip spend data with only 1 dimension of (n_channels).
       if (
@@ -375,7 +399,10 @@ class DataTensors(tf.experimental.ExtensionType):
     missing_params = []
     for var_name in required_fields:
       new_tensor = getattr(self, var_name)
-      old_tensor = getattr(meridian.input_data, var_name)
+      if var_name == constants.RF_IMPRESSIONS:
+        old_tensor = getattr(meridian.rf_tensors, var_name)
+      else:
+        old_tensor = getattr(meridian.input_data, var_name)
 
       if old_tensor is None:
         continue
@@ -3415,6 +3442,7 @@ class Analyzer:
   def optimal_freq(
       self,
       new_data: DataTensors | None = None,
+      max_frequency: float | None = None,
       freq_grid: Sequence[float] | None = None,
       use_posterior: bool = True,
       use_kpi: bool = False,
@@ -3443,7 +3471,7 @@ class Analyzer:
       ROI numerator is KPI units.
 
     Args:
-      new_data: Optional `DataTensors` object containing `
+      new_data: Optional `DataTensors` object containing `rf_impressions`,
        `rf_spend`, and `revenue_per_kpi`. If provided, the optimal frequency is
        calculated using the values of the tensors passed in `new_data` and the
        original values of all the remaining tensors. If `None`, the historical
@@ -3451,6 +3479,10 @@ class Analyzer:
        tensors in `new_data` is provided with a different number of time
        periods than in `InputData`, then all tensors must be provided with the
        same number of time periods.
+      max_frequency: Maximum frequency value used to calculate the frequency
+        grid. If `None`, the maximum frequency value is calculated from the
+        historic frequency (maximum value of Meridian.input_data, not
+        `new_data`). If `freq_grid` is provided, this argument has no effect.
      freq_grid: List of frequency values. The ROI of each channel is calculated
        for each frequency value in the list. By default, the list includes
        numbers from `1.0` to the maximum frequency in increments of `0.1`.
@@ -3506,7 +3538,11 @@ class Analyzer:
     )
 
     filled_data = new_data.validate_and_fill_missing_data(
-
+        [
+            constants.RF_IMPRESSIONS,
+            constants.RF_SPEND,
+            constants.REVENUE_PER_KPI,
+        ],
         self._meridian,
     )
     # TODO: Once treatment type filtering is added, remove adding
@@ -3527,7 +3563,9 @@ class Analyzer:
         (self._meridian.n_geos, n_times, self._meridian.n_media_channels)
     )
 
-    max_freq = np.max(
+    max_freq = max_frequency or np.max(
+        np.array(self._meridian.rf_tensors.frequency)
+    )
     if freq_grid is None:
       freq_grid = np.arange(1, max_freq, 0.1)
 
@@ -3537,8 +3575,8 @@ class Analyzer:
     metric_grid = np.zeros((len(freq_grid), self._meridian.n_rf_channels, 4))
 
     for i, freq in enumerate(freq_grid):
-      new_frequency = tf.ones_like(filled_data.
-      new_reach = filled_data.
+      new_frequency = tf.ones_like(filled_data.rf_impressions) * freq
+      new_reach = filled_data.rf_impressions / new_frequency
       new_roi_data = DataTensors(
           reach=new_reach,
           frequency=new_frequency,
@@ -3568,12 +3606,10 @@ class Analyzer:
 
     optimal_frequency = [freq_grid[i] for i in optimal_freq_idx]
     optimal_frequency_tensor = tf.convert_to_tensor(
-        tf.ones_like(filled_data.
+        tf.ones_like(filled_data.rf_impressions) * optimal_frequency,
         tf.float32,
     )
-    optimal_reach = (
-        filled_data.frequency * filled_data.reach / optimal_frequency_tensor
-    )
+    optimal_reach = filled_data.rf_impressions / optimal_frequency_tensor
 
     new_summary_metrics_data = DataTensors(
         reach=optimal_reach,

meridian/analysis/optimizer.py CHANGED

@@ -1863,9 +1863,14 @@ class BudgetOptimizer:
         )
     )
     if self._meridian.n_rf_channels > 0 and use_optimal_frequency:
+      opt_freq_data = analyzer.DataTensors(
+          rf_impressions=filled_data.reach * filled_data.frequency,
+          rf_spend=filled_data.rf_spend,
+          revenue_per_kpi=filled_data.revenue_per_kpi,
+      )
       optimal_frequency = tf.convert_to_tensor(
           self._analyzer.optimal_freq(
-              new_data=
+              new_data=opt_freq_data,
               use_posterior=use_posterior,
               selected_times=selected_times,
               use_kpi=use_kpi,

meridian/constants.py CHANGED

meridian/data/input_data_builder.py CHANGED

@@ -646,12 +646,13 @@ class InputDataBuilder(abc.ABC):
     """Normalizes the given `DataArray`'s coordinates in Meridian convention.
 
     Validates that time values are in the conventional Meridian format and
-    that geos have national name if national.
+    that geos have national name if national. If geo coordinates are not string-
+    typed, they will be converted to strings.
 
     Args:
       da: The DataArray to normalize.
-      time_dimension_name: The name of the time dimension. If None, the
-
+      time_dimension_name: The name of the time dimension. If None, the will
+        skip time normalization.
 
     Returns:
       The normalized DataArray.
@@ -686,6 +687,11 @@ class InputDataBuilder(abc.ABC):
       da = da.assign_coords(
           {constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]},
       )
+    else:
+      da = da.assign_coords(
+          {constants.GEO: da.coords[constants.GEO].astype(str)}
+      )
+
     return da
 
   def _validate_set(self, component: str, da: xr.DataArray):

meridian/data/load.py CHANGED

@@ -816,12 +816,55 @@ class DataFrameDataLoader(InputDataLoader):
         'organic_frequency': 'organic_frequency_to_channel',
     })
     for coord_name, channel_dict in required_mappings.items():
+      if getattr(self.coord_to_columns, coord_name, None) is not None:
+        if getattr(self, channel_dict, None) is None:
+          raise ValueError(
+              f"When {coord_name} data is provided, '{channel_dict}' is"
+              ' required.'
+          )
+        else:
+          if set(getattr(self, channel_dict)) != set(
+              getattr(self.coord_to_columns, coord_name)
+          ):
+            raise ValueError(
+                f'The {channel_dict} keys must have the same set of values as'
+                f' the {coord_name} columns.'
+            )
+    if (
+        self.media_to_channel is not None
+        and self.media_spend_to_channel is not None
+    ):
+      if set(self.media_to_channel.values()) != set(
+          self.media_spend_to_channel.values()
+      ):
+        raise ValueError(
+            'The media and media_spend columns must have the same set of'
+            ' channels.'
+        )
+    if (
+        self.reach_to_channel is not None
+        and self.frequency_to_channel is not None
+        and self.rf_spend_to_channel is not None
+    ):
       if (
-
-
+          set(self.reach_to_channel.values())
+          != set(self.frequency_to_channel.values())
+          != set(self.rf_spend_to_channel.values())
       ):
         raise ValueError(
-
+            'The reach, frequency, and rf_spend columns must have the same set'
+            ' of channels.'
+        )
+    if (
+        self.organic_reach_to_channel is not None
+        and self.organic_frequency_to_channel is not None
+    ):
+      if set(self.organic_reach_to_channel.values()) != set(
+          self.organic_frequency_to_channel.values()
+      ):
+        raise ValueError(
+            'The organic_reach and organic_frequency columns must have the'
+            ' same set of channels.'
         )
 
   def load(self) -> input_data.InputData:
@@ -861,28 +904,36 @@ class DataFrameDataLoader(InputDataLoader):
         self.coord_to_columns.geo,
     )
     if (
-        self.
-        and self.
+        self.media_to_channel is not None
+        and self.media_spend_to_channel is not None
     ):
+      sorted_channels = sorted(self.media_to_channel.values())
+      inv_media_map = {v: k for k, v in self.media_to_channel.items()}
+      inv_spend_map = {v: k for k, v in self.media_spend_to_channel.items()}
+
       builder.with_media(
           self.df,
-
-
-
+          [inv_media_map[ch] for ch in sorted_channels],
+          [inv_spend_map[ch] for ch in sorted_channels],
+          sorted_channels,
          self.coord_to_columns.time,
          self.coord_to_columns.geo,
      )
-
     if (
-        self.
-        and self.
+        self.reach_to_channel is not None
+        and self.frequency_to_channel is not None
+        and self.rf_spend_to_channel is not None
     ):
+      sorted_channels = sorted(self.reach_to_channel.values())
+      inv_reach_map = {v: k for k, v in self.reach_to_channel.items()}
+      inv_freq_map = {v: k for k, v in self.frequency_to_channel.items()}
+      inv_rf_spend_map = {v: k for k, v in self.rf_spend_to_channel.items()}
       builder.with_reach(
          self.df,
-
-
-
-
+          [inv_reach_map[ch] for ch in sorted_channels],
+          [inv_freq_map[ch] for ch in sorted_channels],
+          [inv_rf_spend_map[ch] for ch in sorted_channels],
+          sorted_channels,
          self.coord_to_columns.time,
          self.coord_to_columns.geo,
      )
@@ -895,14 +946,19 @@ class DataFrameDataLoader(InputDataLoader):
         self.coord_to_columns.geo,
     )
     if (
-        self.
-        and self.
+        self.organic_reach_to_channel is not None
+        and self.organic_frequency_to_channel is not None
     ):
+      sorted_channels = sorted(self.organic_reach_to_channel.values())
+      inv_reach_map = {v: k for k, v in self.organic_reach_to_channel.items()}
+      inv_freq_map = {
+          v: k for k, v in self.organic_frequency_to_channel.items()
+      }
       builder.with_organic_reach(
          self.df,
-
-
-
+          [inv_reach_map[ch] for ch in sorted_channels],
+          [inv_freq_map[ch] for ch in sorted_channels],
+          sorted_channels,
          self.coord_to_columns.time,
          self.coord_to_columns.geo,
      )

meridian/mlflow/autolog.py CHANGED

@@ -12,29 +12,130 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""MLflow autologging integration for Meridian.
+"""MLflow autologging integration for Meridian.
 
+This module enables MLflow tracking for Meridian. When enabled via `autolog()`,
+parameters, metrics, and other metadata will be automatically logged to MLflow,
+allowing for improved experiment tracking and analysis.
+
+To enable MLflow autologging for your Meridian workflows, simply call
+`autolog.autolog()` once before your model run.
+
+Example usage:
+
+```python
+import mlflow
+from meridian.data import load
+from meridian.mlflow import autolog
+from meridian.model import model
+
+# Enable autologging (call this once per session)
+autolog.autolog(log_metrics=True)
+
+# Start an MLflow run (optionally name it for better grouping)
+with mlflow.start_run(run_name="my_run"):
+  # Load data
+  data = load.CsvDataLoader(...).load()
+
+  # Initialize Meridian model
+  mmm = model.Meridian(input_data=data)
+
+  # Run Meridian sampling processes
+  mmm.sample_prior(n_draws=100, seed=123)
+  mmm.sample_posterior(n_chains=7, n_adapt=500, n_burnin=500, n_keep=1000,
+                       seed=1)
+
+# After the run completes, you can retrieve run results using the MLflow client.
+client = mlflow.tracking.MlflowClient()
+
+# Get the experiment ID for the run you just launched
+experiment_id = "0"
+
+# Search for runs matching the run name
+runs = client.search_runs(
+    experiment_id,
+    max_results=1000,
+    filter_string=f"attributes.run_name = 'my_run'"
+)
+
+# Print details of the run
+if runs:
+  print(runs[0])
+else:
+  print("No runs found.")
+```
+"""
+
+import dataclasses
+import inspect
+import json
 from typing import Any, Callable
 
 import arviz as az
-import 
+from meridian.analysis import visualizer
 import mlflow
 from mlflow.utils.autologging_utils import autologging_integration, safe_patch
 from meridian.model import model
+from meridian.model import posterior_sampler
+from meridian.model import prior_sampler
+from meridian.model import spec
+from meridian.version import __version__
+import numpy as np
+import tensorflow_probability as tfp
+
 
 FLAVOR_NAME = "meridian"
 
 
+__all__ = ["autolog"]
+
+
 def _log_versions() -> None:
   """Logs Meridian and ArviZ versions."""
-  mlflow.log_param("meridian_version",
+  mlflow.log_param("meridian_version", __version__)
   mlflow.log_param("arviz_version", az.__version__)
 
 
+def _log_model_spec(model_spec: spec.ModelSpec) -> None:
+  """Logs the `ModelSpec` object."""
+  # TODO: Replace with serde api when it's available.
+  # PriorDistribution is logged separately.
+  excluded_fields = ["prior"]
+
+  for field in dataclasses.fields(model_spec):
+    if field.name in excluded_fields:
+      continue
+
+    field_value = getattr(model_spec, field.name)
+
+    # Stringify numpy arrays before logging.
+    if isinstance(field_value, np.ndarray):
+      field_value = json.dumps(field_value.tolist())
+
+    mlflow.log_param(f"spec.{field.name}", field_value)
+
+
+def _log_priors(model_spec: spec.ModelSpec) -> None:
+  """Logs the `PriorDistribution` object."""
+  # TODO: Replace with serde api when it's available.
+  priors = model_spec.prior
+  for field in dataclasses.fields(priors):
+    field_value = getattr(priors, field.name)
+
+    # Stringify Distributions and numpy arrays.
+    if isinstance(field_value, tfp.distributions.Distribution):
+      field_value = str(field_value)
+    elif isinstance(field_value, np.ndarray):
+      field_value = json.dumps(field_value.tolist())
+
+    mlflow.log_param(f"prior.{field.name}", field_value)
+
+
 @autologging_integration(FLAVOR_NAME)
 def autolog(
     disable: bool = False,  # pylint: disable=unused-argument
     silent: bool = False,  # pylint: disable=unused-argument
+    log_metrics: bool = False,
 ) -> None:
   """Enables MLflow tracking for Meridian.
 
@@ -43,12 +144,63 @@ def autolog(
   Args:
     disable: Whether to disable autologging.
     silent: Whether to suppress all event logs and warnings from MLflow.
+    log_metrics: Whether model metrics should be logged. Enabling this option
+      involves the creation of post-modeling objects to compute relevant
+      performance metrics. Metrics include R-Squared, MAPE, and wMAPE values.
   """
 
   def patch_meridian_init(
-      original: Callable[..., Any], *args, **kwargs
-  ) ->
+      original: Callable[..., Any], self, *args, **kwargs
+  ) -> model.Meridian:
     _log_versions()
-
+    mmm = original(self, *args, **kwargs)
+    _log_model_spec(self.model_spec)
+    _log_priors(self.model_spec)
+    return mmm
+
+  def patch_prior_sampling(original: Callable[..., Any], self, *args, **kwargs):
+    mlflow.log_param("sample_prior.n_draws", kwargs.get("n_draws", "default"))
+    mlflow.log_param("sample_prior.seed", kwargs.get("seed", "default"))
+    return original(self, *args, **kwargs)
+
+  def patch_posterior_sampling(
+      original: Callable[..., Any], self, *args, **kwargs
+  ):
+    excluded_fields = ["current_state", "pins"]
+    params = [
+        name
+        for name, value in inspect.signature(original).parameters.items()
+        if name != "self"
+        and value.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
+        and name not in excluded_fields
+    ]
+
+    for param in params:
+      mlflow.log_param(
+          f"sample_posterior.{param}", kwargs.get(param, "default")
+      )
+
+    original(self, *args, **kwargs)
+    if log_metrics:
+      model_diagnostics = visualizer.ModelDiagnostics(self.model)
+      df_diag = model_diagnostics.predictive_accuracy_table()
+
+      get_metric = lambda n: df_diag[df_diag.metric == n].value.to_list()[0]
+
+      mlflow.log_metric("R_Squared", get_metric("R_Squared"))
+      mlflow.log_metric("MAPE", get_metric("MAPE"))
+      mlflow.log_metric("wMAPE", get_metric("wMAPE"))
 
   safe_patch(FLAVOR_NAME, model.Meridian, "__init__", patch_meridian_init)
+  safe_patch(
+      FLAVOR_NAME,
+      prior_sampler.PriorDistributionSampler,
+      "__call__",
+      patch_prior_sampling,
+  )
+  safe_patch(
+      FLAVOR_NAME,
+      posterior_sampler.PosteriorMCMCSampler,
+      "__call__",
+      patch_posterior_sampling,
+  )

meridian/model/media.py CHANGED

@@ -207,6 +207,8 @@ class RfTensors:
   Attributes:
     reach: A tensor constructed from `InputData.reach`.
     frequency: A tensor constructed from `InputData.frequency`.
+    rf_impressions: A tensor constructed from `InputData.reach` *
+      `InputData.frequency`.
     rf_spend: A tensor constructed from `InputData.rf_spend`.
     reach_transformer: A `MediaTransformer` to scale RF tensors using the
       model's RF data.
@@ -233,6 +235,7 @@ class RfTensors:
 
   reach: tf.Tensor | None = None
   frequency: tf.Tensor | None = None
+  rf_impressions: tf.Tensor | None = None
   rf_spend: tf.Tensor | None = None
   reach_transformer: transformers.MediaTransformer | None = None
   reach_scaled: tf.Tensor | None = None
@@ -250,6 +253,9 @@ def build_rf_tensors(
 
   reach = tf.convert_to_tensor(input_data.reach, dtype=tf.float32)
   frequency = tf.convert_to_tensor(input_data.frequency, dtype=tf.float32)
+  rf_impressions = (
+      reach * frequency if reach is not None and frequency is not None else None
+  )
   rf_spend = tf.convert_to_tensor(input_data.rf_spend, dtype=tf.float32)
   reach_transformer = transformers.MediaTransformer(
       reach, tf.convert_to_tensor(input_data.population, dtype=tf.float32)
@@ -292,6 +298,7 @@ def build_rf_tensors(
   return RfTensors(
       reach=reach,
       frequency=frequency,
+      rf_impressions=rf_impressions,
       rf_spend=rf_spend,
       reach_transformer=reach_transformer,
       reach_scaled=reach_scaled,

meridian/model/model.py CHANGED

@@ -1447,8 +1447,7 @@
       see [PRNGS and seeds]
       (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
     """
-
-    self.inference_data.extend(prior_inference_data, join="right")
+    self.prior_sampler_callable(n_draws=n_draws, seed=seed)
 
   def sample_posterior(
       self,
@@ -1527,22 +1526,21 @@
     [ResourceExhaustedError when running Meridian.sample_posterior]
     (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
     """
-
-        n_chains,
-        n_adapt,
-        n_burnin,
-        n_keep,
-        current_state,
-        init_step_size,
-        dual_averaging_kwargs,
-        max_tree_depth,
-        max_energy_diff,
-        unrolled_leapfrog_steps,
-        parallel_iterations,
-        seed,
+    self.posterior_sampler_callable(
+        n_chains=n_chains,
+        n_adapt=n_adapt,
+        n_burnin=n_burnin,
+        n_keep=n_keep,
+        current_state=current_state,
+        init_step_size=init_step_size,
+        dual_averaging_kwargs=dual_averaging_kwargs,
+        max_tree_depth=max_tree_depth,
+        max_energy_diff=max_energy_diff,
+        unrolled_leapfrog_steps=unrolled_leapfrog_steps,
+        parallel_iterations=parallel_iterations,
+        seed=seed,
         **pins,
     )
-    self.inference_data.extend(posterior_inference_data, join="right")
 
 
 def save_mmm(mmm: Meridian, file_path: str):

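The public sampling API is unchanged by the delegation above: `sample_prior()` and `sample_posterior()` still populate `inference_data` in place, the samplers just write to it themselves now. A hedged end-to-end sketch (argument values follow the autolog docstring example; `input_data` construction is not shown):

```python
from meridian.model import model


def fit(input_data):
  mmm = model.Meridian(input_data=input_data)
  mmm.sample_prior(n_draws=100, seed=1)
  mmm.sample_posterior(
      n_chains=4, n_adapt=500, n_burnin=500, n_keep=1000, seed=1
  )
  return mmm.inference_data  # arviz InferenceData with prior and posterior groups
```
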
meridian/model/posterior_sampler.py CHANGED

@@ -85,9 +85,13 @@ class PosteriorMCMCSampler:
   def __init__(self, meridian: "model.Meridian"):
     self._meridian = meridian
 
+  @property
+  def model(self) -> "model.Meridian":
+    return self._meridian
+
   def _get_joint_dist_unpinned(self) -> tfp.distributions.Distribution:
     """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
-    mmm = self.
+    mmm = self.model
     mmm.populate_cached_properties()
 
     # This lists all the derived properties and states of this Meridian object
@@ -453,7 +457,7 @@ class PosteriorMCMCSampler:
     return joint_dist_unpinned
 
   def _get_joint_dist(self) -> tfp.distributions.Distribution:
-    mmm = self.
+    mmm = self.model
     y = (
         tf.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
         if mmm.holdout_id is not None
@@ -476,7 +480,7 @@ class PosteriorMCMCSampler:
       parallel_iterations: int = 10,
       seed: Sequence[int] | int | None = None,
       **pins,
-  ) ->
+  ) -> None:
     """Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions.
 
     For more information about the arguments, see [`windowed_adaptive_nuts`]
@@ -529,9 +533,6 @@ class PosteriorMCMCSampler:
       **pins: These are used to condition the provided joint distribution, and
         are passed directly to `joint_dist.experimental_pin(**pins)`.
 
-    Returns:
-      An Arviz `InferenceData` object containing posterior samples only.
-
     Throws:
       MCMCOOMError: If the model is out of memory. Try reducing `n_keep` or pass
         a list of integers as `n_chains` to sample chains serially. For more
@@ -589,10 +590,10 @@ class PosteriorMCMCSampler:
         if k not in constants.UNSAVED_PARAMETERS
     }
     # Create Arviz InferenceData for posterior draws.
-    posterior_coords = self.
+    posterior_coords = self.model.create_inference_data_coords(
         total_chains, n_keep
     )
-    posterior_dims = self.
+    posterior_dims = self.model.create_inference_data_dims()
     infdata_posterior = az.convert_to_inference_data(
         mcmc_states, coords=posterior_coords, dims=posterior_dims
     )
@@ -654,4 +655,7 @@ class PosteriorMCMCSampler:
         dims=sample_stats_dims,
         group="sample_stats",
     )
-
+    posterior_inference_data = az.concat(
+        infdata_posterior, infdata_trace, infdata_sample_stats
+    )
+    self.model.inference_data.extend(posterior_inference_data, join="right")

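A tiny sketch of the new read-only `model` property added above; `mmm` stands for an existing Meridian instance (construction not shown).

```python
from meridian.model import posterior_sampler


def check_sampler_binding(mmm):
  sampler = posterior_sampler.PosteriorMCMCSampler(mmm)
  assert sampler.model is mmm  # the property exposes the bound Meridian model
  return sampler
```
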
|
meridian/model/prior_sampler.py
CHANGED
|
@@ -588,22 +588,20 @@ class PriorDistributionSampler:
|
|
|
588
588
|
| non_media_treatments_vars
|
|
589
589
|
)
|
|
590
590
|
|
|
591
|
-
def __call__(self, n_draws: int, seed: int | None = None) ->
|
|
591
|
+
def __call__(self, n_draws: int, seed: int | None = None) -> None:
|
|
592
592
|
"""Draws samples from prior distributions.
|
|
593
593
|
|
|
594
|
-
Returns:
|
|
595
|
-
An Arviz `InferenceData` object containing prior samples only.
|
|
596
|
-
|
|
597
594
|
Args:
|
|
598
595
|
n_draws: Number of samples drawn from the prior distribution.
|
|
599
596
|
seed: Used to set the seed for reproducible results. For more information,
|
|
600
597
|
see [PRNGS and seeds]
|
|
601
598
|
(https://github.com/tensorflow/probability/blob/main/PRNGS.md).
|
|
602
599
|
"""
|
|
603
|
-
prior_draws = self._sample_prior(n_draws, seed=seed)
|
|
600
|
+
prior_draws = self._sample_prior(n_draws=n_draws, seed=seed)
|
|
604
601
|
# Create Arviz InferenceData for prior draws.
|
|
605
602
|
prior_coords = self._meridian.create_inference_data_coords(1, n_draws)
|
|
606
603
|
prior_dims = self._meridian.create_inference_data_dims()
|
|
607
|
-
|
|
604
|
+
prior_inference_data = az.convert_to_inference_data(
|
|
608
605
|
prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR
|
|
609
606
|
)
|
|
607
|
+
self._meridian.inference_data.extend(prior_inference_data, join="right")
|
meridian/version.py ADDED

@@ -0,0 +1,17 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for Meridian version."""
+
+__version__ = "1.1.3"

{google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/WHEEL: File without changes
{google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/licenses/LICENSE: File without changes
{google_meridian-1.1.2.dist-info → google_meridian-1.1.3.dist-info}/top_level.txt: File without changes