google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Serialization and deserialization of `InferenceData` container for sampled priors and posteriors."""
|
|
16
|
+
|
|
17
|
+
import io
|
|
18
|
+
|
|
19
|
+
import arviz as az
|
|
20
|
+
from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
|
|
21
|
+
from schema.serde import serde
|
|
22
|
+
import xarray as xr
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
_NETCDF_FORMAT = "NETCDF3_64BIT" # scipy only supports up to v3
|
|
26
|
+
_PRIOR_FIELD = "prior"
|
|
27
|
+
_POSTERIOR_FIELD = "posterior"
|
|
28
|
+
_CREATED_AT_ATTRIBUTE = "created_at"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _remove_created_at_attribute(dataset: xr.Dataset) -> xr.Dataset:
  """Returns a copy of `dataset` without its `created_at` attribute.

  The attribute is removed from a `dataset.copy()` rather than from `dataset`
  itself. The caller's object is therefore not edited directly, assuming
  xarray's shallow copy duplicates the `attrs` dict.

  Args:
    dataset: The dataset to strip.

  Returns:
    The copied dataset, with `created_at` removed if it was present.
  """
  stripped = dataset.copy()
  # `pop` with a default makes a missing attribute a no-op.
  stripped.attrs.pop(_CREATED_AT_ATTRIBUTE, None)
  return stripped
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class InferenceDataSerde(
    serde.Serde[meridian_pb.InferenceData, az.InferenceData]
):
  """Converts between an ArviZ `InferenceData` container and its Meridian proto.

  Meridian keeps its sampled prior and posterior draws in an ArviZ
  `InferenceData` container. Each group in that container is stored as a
  NetCDF blob:
    * `prior` and `posterior` get dedicated proto fields.
    * Every other group is stored in the `auxiliary_data` map, keyed by its
      group name.
  """

  @staticmethod
  def _encode_group(dataset: xr.Dataset) -> bytes:
    """Encodes one group as NetCDF bytes after dropping `created_at`."""
    stripped = _remove_created_at_attribute(dataset)
    return bytes(stripped.to_netcdf(format=_NETCDF_FORMAT))

  @staticmethod
  def _decode_group(payload: bytes) -> xr.Dataset:
    """Opens an in-memory NetCDF blob as an `xr.Dataset`."""
    return xr.open_dataset(io.BytesIO(payload))

  def serialize(self, obj: az.InferenceData) -> meridian_pb.InferenceData:
    """Packs a Meridian inference data container into an `InferenceData` proto.

    Args:
      obj: The ArviZ container to serialize.

    Returns:
      An `InferenceData` proto with the following contents:
        * `prior` / `posterior`: the matching group as NetCDF bytes. `None`
          is passed instead when the container lacks that group.
        * `auxiliary_data`: every remaining group name mapped to its NetCDF
          bytes.
      Each group has its `created_at` attribute removed before encoding.
    """
    # Encode prior first, then posterior, mirroring the proto field names.
    dedicated_fields = {}
    for field_name in (_PRIOR_FIELD, _POSTERIOR_FIELD):
      dedicated_fields[field_name] = (
          self._encode_group(getattr(obj, field_name))
          if hasattr(obj, field_name)
          else None
      )

    auxiliary_data = {
        group_name: self._encode_group(obj.get(group_name))
        for group_name in obj.groups()
        if group_name not in (_PRIOR_FIELD, _POSTERIOR_FIELD)
    }

    return meridian_pb.InferenceData(
        **dedicated_fields,
        auxiliary_data=auxiliary_data,
    )

  def deserialize(
      self, serialized: meridian_pb.InferenceData, serialized_version: str = ""
  ) -> az.InferenceData:
    """Rebuilds an ArviZ container from an `InferenceData` proto.

    Args:
      serialized: The `InferenceData` proto to decode.
      serialized_version: The version string of the serialized model. It is
        accepted for cross-version deserialization but is not currently
        consulted by this implementation.

    Returns:
      An `az.InferenceData` with one group per set field: `prior` and
      `posterior` when present, plus one group per `auxiliary_data` entry.
      When the proto carries none of these, the container is empty.
    """
    decoded = {}

    for field_name in (_PRIOR_FIELD, _POSTERIOR_FIELD):
      if serialized.HasField(field_name):
        decoded[field_name] = self._decode_group(
            getattr(serialized, field_name)
        )

    # Auxiliary groups are inserted after the dedicated ones, so a clashing
    # name in `auxiliary_data` takes precedence.
    for group_name, payload in serialized.auxiliary_data.items():
      decoded[group_name] = self._decode_group(payload)

    inference_data = az.InferenceData()
    if not decoded:
      return inference_data
    inference_data.add_groups(decoded)
    return inference_data
|