google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
schema/serde/eda_spec.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Serialization and deserialization of `EDASpec` objects."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import warnings
|
|
20
|
+
|
|
21
|
+
from meridian.model.eda import eda_spec
|
|
22
|
+
from mmm.v1.model.meridian.eda import eda_spec_pb2 as eda_spec_pb
|
|
23
|
+
from schema.serde import function_registry as function_registry_utils
|
|
24
|
+
from schema.serde import serde
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Local alias so annotations in this module can name the registry type directly.
FunctionRegistry = function_registry_utils.FunctionRegistry
# Name of the `EDASpec` proto field that holds the hashed function registry;
# `EDASpecSerde.deserialize` reads it with `getattr`.
_FUNCTION_REGISTRY_NAME = "function_registry"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class EDASpecSerde(serde.Serde[eda_spec_pb.EDASpec, eda_spec.EDASpec]):
  """Converts between `EDASpec` objects and `EDASpec` protos.

  Custom aggregation functions are never written out directly. Each one is
  stored as the key it is registered under in a `FunctionRegistry`, and the
  registry's per-function source hashes are embedded in the proto. On load,
  those hashes are checked against the registry supplied to this serde.
  """

  def __init__(self, function_registry: FunctionRegistry):
    """Creates a serde bound to a function registry.

    Args:
      function_registry: Lookup table from keys to the custom functions that
        `EDASpec` objects reference. Prefer explicitly defined functions over
        lambdas, as lambda functions may not be serialized successfully.
    """
    self._function_registry = function_registry

  @property
  def function_registry(self) -> FunctionRegistry:
    """Registry used to resolve custom aggregation functions."""
    return self._function_registry

  def serialize(self, obj: eda_spec.EDASpec) -> eda_spec_pb.EDASpec:
    """Converts `obj` into an `EDASpec` proto, including registry hashes."""
    aggregation_config_proto = self._to_aggregation_config_proto(
        obj.aggregation_config
    )
    vif_spec_proto = self._to_vif_spec_proto(obj.vif_spec)
    proto = eda_spec_pb.EDASpec(
        aggregation_config=aggregation_config_proto,
        vif_spec=vif_spec_proto,
    )
    # Record a hash of every registered function so deserialization can tell
    # whether any of them changed in the meantime.
    proto.function_registry.update(self.function_registry.hashed_registry)
    return proto

  def deserialize(
      self,
      serialized: eda_spec_pb.EDASpec,
      serialized_version: str = "",
      force_deserialization: bool = False,
  ) -> eda_spec.EDASpec:
    """Rebuilds an `EDASpec` object from its proto form.

    Args:
      serialized: The `EDASpec` proto to convert.
      serialized_version: Version of the serialized Meridian model; intended
        for handling deserialization changes across versions.
      force_deserialization: When True, skip checking the stored function
        hashes against `function_registry` (a warning is emitted instead).
        Use with caution.

    Returns:
      The reconstructed `EDASpec` object.

    Raises:
      ValueError: If hash validation is performed and fails, or a referenced
        function key is empty or missing from the registry.
    """
    if not force_deserialization:
      stored_hashes = getattr(serialized, _FUNCTION_REGISTRY_NAME)
      try:
        self.function_registry.validate(stored_hashes)
      except Exception as e:
        raise ValueError(
            f"An issue found during deserializing EDASpec: {e}"
        ) from e
    else:
      warnings.warn(
          "You're attempting to deserialize an EDASpec while ignoring changes"
          " to custom functions. This is a risky operation that can"
          " potentially lead to a deserialized EDASpec that behaves"
          " differently from the original, resulting in unexpected behavior."
          " Please proceed with caution."
      )

    return eda_spec.EDASpec(
        aggregation_config=self._from_aggregation_config_proto(
            serialized.aggregation_config
        ),
        vif_spec=self._from_vif_spec_proto(serialized.vif_spec),
    )

  def _to_aggregation_config_proto(
      self, config: eda_spec.AggregationConfig
  ) -> eda_spec_pb.AggregationConfig:
    """Builds an `AggregationConfig` proto from its Python counterpart."""
    proto = eda_spec_pb.AggregationConfig()
    # Both mapping fields share the same shape: variable name -> function.
    sections = (
        ("control_variables", config.control_variables),
        ("non_media_treatments", config.non_media_treatments),
    )
    for field_name, functions_by_var in sections:
      if not functions_by_var:
        continue
      target_map = getattr(proto, field_name)
      for var_name, fn in functions_by_var.items():
        target_map[var_name].CopyFrom(
            self._to_aggregation_function_proto(fn, var_name, field_name)
        )
    return proto

  def _to_aggregation_function_proto(
      self, func: eda_spec.AggregationFn, key: str, field: str
  ) -> eda_spec_pb.AggregationFunction:
    """Encodes `func` as the registry key it is registered under.

    Raises:
      ValueError: If `func` is not (by identity) in the registry.
    """
    function_key = self.function_registry.get_function_key(func)
    if function_key is None:
      raise ValueError(
          f"Custom aggregation function `{key}` in `{field}` detected, but not"
          " found in registry. Please add custom functions to registry when"
          " saving models."
      )
    return eda_spec_pb.AggregationFunction(function_key=function_key)

  def _from_aggregation_config_proto(
      self, proto: eda_spec_pb.AggregationConfig
  ) -> eda_spec.AggregationConfig:
    """Builds a Python `AggregationConfig` from its proto form."""

    def resolve_all(entries):
      return {
          var_name: self._from_aggregation_function_proto(var_name, fn_proto)
          for var_name, fn_proto in entries.items()
      }

    return eda_spec.AggregationConfig(
        control_variables=resolve_all(proto.control_variables),
        non_media_treatments=resolve_all(proto.non_media_treatments),
    )

  def _from_aggregation_function_proto(
      self, var_name: str, agg_func: eda_spec_pb.AggregationFunction
  ) -> eda_spec.AggregationFn:
    """Looks up the registry function named by an `AggregationFunction` proto.

    Raises:
      ValueError: If the proto's key is empty or absent from the registry.
    """
    key = agg_func.function_key
    if not key:
      raise ValueError(
          "Function key is required in `AggregationFunction` proto message."
          f" The function key for {var_name} is empty."
      )
    if key not in self.function_registry:
      raise ValueError(f"Function key `{key}` not found in registry.")
    return self.function_registry[key]

  def _to_vif_spec_proto(
      self, vif_spec_obj: eda_spec.VIFSpec
  ) -> eda_spec_pb.VIFSpec:
    """Copies the three VIF thresholds into a `VIFSpec` proto."""
    thresholds = {
        name: getattr(vif_spec_obj, name)
        for name in ("geo_threshold", "overall_threshold", "national_threshold")
    }
    return eda_spec_pb.VIFSpec(**thresholds)

  def _from_vif_spec_proto(
      self, vif_spec_proto: eda_spec_pb.VIFSpec
  ) -> eda_spec.VIFSpec:
    """Copies the three VIF thresholds from a proto into a Python `VIFSpec`."""
    thresholds = {
        name: getattr(vif_spec_proto, name)
        for name in ("geo_threshold", "overall_threshold", "national_threshold")
    }
    return eda_spec.VIFSpec(**thresholds)
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Function registry for Serde."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import functools
|
|
20
|
+
import hashlib
|
|
21
|
+
import inspect
|
|
22
|
+
from typing import Any, Callable
|
|
23
|
+
import warnings
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SourceCodeRetrievalError(Exception):
  """Signals that a function's source text is unavailable and cannot be hashed."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class LambdaSourceCodeWarning(UserWarning):
  """Emitted when asked for a lambda's source, which may be unreliable."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _get_func_source(func: Callable[..., Any]) -> str:
|
|
35
|
+
"""Returns the source code of a function.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
func: The function to get the source code for.
|
|
39
|
+
|
|
40
|
+
Returns:
|
|
41
|
+
The source code of the function.
|
|
42
|
+
|
|
43
|
+
Raises:
|
|
44
|
+
SourceCodeRetrievalError: If the source code of the function cannot be
|
|
45
|
+
retrieved.
|
|
46
|
+
"""
|
|
47
|
+
if hasattr(func, "__code__") and func.__code__.co_name == "<lambda>":
|
|
48
|
+
warnings.warn(
|
|
49
|
+
"Retrieving the source code of a lambda function might not work"
|
|
50
|
+
" successfully. It's recommended to explicitly define a function.",
|
|
51
|
+
LambdaSourceCodeWarning,
|
|
52
|
+
)
|
|
53
|
+
try:
|
|
54
|
+
return inspect.getsource(func)
|
|
55
|
+
except (OSError, TypeError) as e:
|
|
56
|
+
raise SourceCodeRetrievalError(
|
|
57
|
+
f"Source code of function {func} is not retrievable."
|
|
58
|
+
) from e
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _get_hash(value: str) -> str:
|
|
62
|
+
"""Returns a SHA-256 hash of the given value."""
|
|
63
|
+
encoded_string = value.encode("utf-8")
|
|
64
|
+
sha_256_hash = hashlib.sha256()
|
|
65
|
+
sha_256_hash.update(encoded_string)
|
|
66
|
+
return sha_256_hash.hexdigest()
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class FunctionRegistry(dict[str, Callable[..., Any]]):
|
|
70
|
+
"""A dictionary-like container for custom functions used in serialization.
|
|
71
|
+
|
|
72
|
+
This class extends dict and provides methods for hashing, validation,
|
|
73
|
+
and key retrieval based on function identity, required for safe
|
|
74
|
+
serialization and deserialization of models that use custom functions.
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
def __init__(self, *args, **kwargs):
|
|
78
|
+
"""Initializes the FunctionRegistry.
|
|
79
|
+
|
|
80
|
+
Accepts the same arguments as a standard dictionary constructor.
|
|
81
|
+
For example:
|
|
82
|
+
reg = FunctionRegistry({'func1': my_func1, 'func2': my_func2})
|
|
83
|
+
reg = FunctionRegistry(func1=my_func1, func2=my_func2)
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
*args: Positional arguments to pass to the dictionary constructor.
|
|
87
|
+
**kwargs: Keyword arguments to pass to the dictionary constructor.
|
|
88
|
+
"""
|
|
89
|
+
super().__init__(*args, **kwargs)
|
|
90
|
+
|
|
91
|
+
@functools.cached_property
|
|
92
|
+
def hashed_registry(self) -> dict[str, str]:
|
|
93
|
+
"""Returns hashed function registry with keys mapped to hashed function code."""
|
|
94
|
+
return {
|
|
95
|
+
key: _get_hash(_get_func_source(function))
|
|
96
|
+
for key, function in self.items()
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
def validate(self, stored_hashed_function_registry: dict[str, str]):
|
|
100
|
+
"""Validates whether functions within the registry have changed.
|
|
101
|
+
|
|
102
|
+
It checks that all functions listed in stored_hashed_function_registry
|
|
103
|
+
are present in this registry, and that their source code hash matches
|
|
104
|
+
the stored hash.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
stored_hashed_function_registry: The hashed function registry from the
|
|
108
|
+
serialized object.
|
|
109
|
+
|
|
110
|
+
Raises:
|
|
111
|
+
ValueError: If a function is missing or a hash mismatch is detected.
|
|
112
|
+
"""
|
|
113
|
+
if not stored_hashed_function_registry and self:
|
|
114
|
+
warnings.warn(
|
|
115
|
+
"A function registry was provided during loading, but none was"
|
|
116
|
+
" found on the serialized object. Custom functions will be"
|
|
117
|
+
" ignored."
|
|
118
|
+
)
|
|
119
|
+
return
|
|
120
|
+
|
|
121
|
+
for key, stored_hash in stored_hashed_function_registry.items():
|
|
122
|
+
if key not in self:
|
|
123
|
+
raise ValueError(
|
|
124
|
+
f"Function '{key}' is required by the serialized object but"
|
|
125
|
+
" is missing from the provided function registry."
|
|
126
|
+
)
|
|
127
|
+
func = self[key]
|
|
128
|
+
try:
|
|
129
|
+
source_code = _get_func_source(func)
|
|
130
|
+
except SourceCodeRetrievalError as e:
|
|
131
|
+
raise ValueError(
|
|
132
|
+
f"Failed to retrieve source code of function {key}."
|
|
133
|
+
) from e
|
|
134
|
+
evaluated_hash = _get_hash(source_code)
|
|
135
|
+
if stored_hash != evaluated_hash:
|
|
136
|
+
raise ValueError(f"Function registry hash mismatch for {key}.")
|
|
137
|
+
|
|
138
|
+
def get_function_key(self, func: Callable[..., Any]) -> str | None:
|
|
139
|
+
"""Returns the function key for the given function from the registry."""
|
|
140
|
+
for function_key, registry_func in self.items():
|
|
141
|
+
if func is registry_func:
|
|
142
|
+
return function_key
|
|
143
|
+
return None
|
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Serde for Hyperparameters."""
|
|
16
|
+
|
|
17
|
+
import warnings
|
|
18
|
+
|
|
19
|
+
from meridian import backend
|
|
20
|
+
from meridian import constants as c
|
|
21
|
+
from meridian.model import spec
|
|
22
|
+
from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
|
|
23
|
+
from schema.serde import constants as sc
|
|
24
|
+
from schema.serde import serde
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
# Short module-private aliases for the proto enum types that this module
# converts to and from Meridian's string representations.
_MediaEffectsDist = meridian_pb.MediaEffectsDistribution
_PaidMediaPriorType = meridian_pb.PaidMediaPriorType
_NonPaidTreatmentsPriorType = meridian_pb.NonPaidTreatmentsPriorType
# Enum nested under the `NonMediaBaselineValue` proto message (MIN / MAX /
# NON_MEDIA_BASELINE_FUNCTION_UNSPECIFIED are the members used here).
_NonMediaBaselineFunction = (
    meridian_pb.NonMediaBaselineValue.NonMediaBaselineFunction
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _media_effects_dist_to_proto_enum(
|
|
36
|
+
media_effect_dict: str,
|
|
37
|
+
) -> _MediaEffectsDist:
|
|
38
|
+
match media_effect_dict:
|
|
39
|
+
case c.MEDIA_EFFECTS_LOG_NORMAL:
|
|
40
|
+
return _MediaEffectsDist.LOG_NORMAL
|
|
41
|
+
case c.MEDIA_EFFECTS_NORMAL:
|
|
42
|
+
return _MediaEffectsDist.NORMAL
|
|
43
|
+
case _:
|
|
44
|
+
return _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _proto_enum_to_media_effects_dist(
    proto_enum: _MediaEffectsDist,
) -> str:
  """Maps a `MediaEffectsDistribution` enum back to Meridian's string form.

  An unspecified enum falls back to log-normal, with a warning.

  Raises:
    ValueError: If `proto_enum` is not one of the known enum values.
  """
  if proto_enum == _MediaEffectsDist.LOG_NORMAL:
    return c.MEDIA_EFFECTS_LOG_NORMAL
  if proto_enum == _MediaEffectsDist.NORMAL:
    return c.MEDIA_EFFECTS_NORMAL
  if proto_enum == _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED:
    warnings.warn(
        "Media effects distribution is unspecified. Resolving to"
        " 'log-normal'."
    )
    return c.MEDIA_EFFECTS_LOG_NORMAL
  raise ValueError(
      "Unsupported MediaEffectsDistribution proto enum value:"
      f" {proto_enum}."
  )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _paid_media_prior_type_to_proto_enum(
    paid_media_prior_type: str | None,
) -> _PaidMediaPriorType:
  """Looks up the `PaidMediaPriorType` enum whose name is the upper-cased input.

  None, or a name with no matching enum member (the latter with a warning),
  yields `PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED`.
  """
  unspecified = _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
  if paid_media_prior_type is None:
    return unspecified
  try:
    return _PaidMediaPriorType.Value(paid_media_prior_type.upper())
  except ValueError:
    warnings.warn(
        f"Invalid paid media prior type: {paid_media_prior_type}. Resolving to"
        " PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED."
    )
  return unspecified
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _proto_enum_to_paid_media_prior_type(
    proto_enum: _PaidMediaPriorType,
) -> str | None:
  """Maps a `PaidMediaPriorType` enum to its lowercased enum name.

  Returns None for `PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED`.
  """
  unspecified = _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
  return (
      None
      if proto_enum == unspecified
      else _PaidMediaPriorType.Name(proto_enum).lower()
  )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _non_paid_prior_type_to_proto_enum(
    non_paid_prior_type: str,
) -> _NonPaidTreatmentsPriorType:
  """Looks up the `NonPaidTreatmentsPriorType` enum for a prior type string.

  The enum name is the upper-cased input with the
  `NON_PAID_TREATMENTS_PRIOR_TYPE_` prefix. Unknown names fall back to the
  CONTRIBUTION member, with a warning.
  """
  enum_name = f"NON_PAID_TREATMENTS_PRIOR_TYPE_{non_paid_prior_type.upper()}"
  try:
    resolved = _NonPaidTreatmentsPriorType.Value(enum_name)
  except ValueError:
    warnings.warn(
        f"Invalid non-paid prior type: {non_paid_prior_type}. Resolving to"
        " NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION."
    )
    resolved = (
        _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
    )
  return resolved
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _proto_enum_to_non_paid_prior_type(
    proto_enum: _NonPaidTreatmentsPriorType,
) -> str:
  """Maps a `NonPaidTreatmentsPriorType` enum to its lowercase short name.

  The `NON_PAID_TREATMENTS_PRIOR_TYPE_` prefix is stripped from the enum
  name. An unspecified enum resolves to `c.TREATMENT_PRIOR_TYPE_CONTRIBUTION`,
  with a warning.
  """
  unspecified = (
      _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED
  )
  if proto_enum != unspecified:
    full_name = _NonPaidTreatmentsPriorType.Name(proto_enum)
    return full_name.replace("NON_PAID_TREATMENTS_PRIOR_TYPE_", "").lower()
  warnings.warn(
      "Non-paid prior type is unspecified. Resolving to 'contribution'."
  )
  return c.TREATMENT_PRIOR_TYPE_CONTRIBUTION
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class HyperparametersSerde(
    serde.Serde[meridian_pb.Hyperparameters, spec.ModelSpec]
):
  """Serializes and deserializes a ModelSpec into a `Hyperparameters` proto.

  Note that this Serde only handles the Hyperparameters part of ModelSpec.
  The 'prior' attribute of ModelSpec is serialized/deserialized separately
  using DistributionSerde.
  """

  def serialize(self, obj: spec.ModelSpec) -> meridian_pb.Hyperparameters:
    """Serializes the given ModelSpec into a `Hyperparameters` proto.

    Args:
      obj: The model spec to serialize. Its `prior` attribute is ignored.

    Returns:
      The `Hyperparameters` proto describing `obj`.

    Raises:
      ValueError: If an entry of `obj.non_media_baseline_values` is neither
        the string "min"/"max" (case-insensitive) nor a real number. Such an
        entry would otherwise be written as an empty value that
        `deserialize` rejects.
    """
    hyperparameters_proto = meridian_pb.Hyperparameters(
        media_effects_dist=_media_effects_dist_to_proto_enum(
            obj.media_effects_dist
        ),
        hill_before_adstock=obj.hill_before_adstock,
        unique_sigma_for_each_geo=obj.unique_sigma_for_each_geo,
        media_prior_type=_paid_media_prior_type_to_proto_enum(
            obj.media_prior_type
        ),
        rf_prior_type=_paid_media_prior_type_to_proto_enum(obj.rf_prior_type),
        paid_media_prior_type=_paid_media_prior_type_to_proto_enum(
            obj.paid_media_prior_type
        ),
        organic_media_prior_type=_non_paid_prior_type_to_proto_enum(
            obj.organic_media_prior_type
        ),
        organic_rf_prior_type=_non_paid_prior_type_to_proto_enum(
            obj.organic_rf_prior_type
        ),
        non_media_treatments_prior_type=_non_paid_prior_type_to_proto_enum(
            obj.non_media_treatments_prior_type
        ),
        enable_aks=obj.enable_aks,
    )
    if obj.max_lag is not None:
      hyperparameters_proto.max_lag = obj.max_lag

    if isinstance(obj.knots, int):
      hyperparameters_proto.knots.append(obj.knots)
    elif isinstance(obj.knots, list):
      if len(obj.knots) == 1:
        # Ints and lists share one repeated proto field, and `deserialize`
        # turns a single entry back into a bare int. A one-element list
        # therefore does not round-trip as a list; make that visible.
        warnings.warn(
            f"`knots={obj.knots}` is a single-element list; it will be"
            f" deserialized as the integer {obj.knots[0]}, not as a list."
        )
      hyperparameters_proto.knots.extend(obj.knots)

    if isinstance(obj.baseline_geo, str):
      hyperparameters_proto.baseline_geo_string = obj.baseline_geo
    elif isinstance(obj.baseline_geo, int):
      hyperparameters_proto.baseline_geo_int = obj.baseline_geo

    if obj.roi_calibration_period is not None:
      hyperparameters_proto.roi_calibration_period.CopyFrom(
          backend.make_tensor_proto(np.array(obj.roi_calibration_period))
      )
    if obj.rf_roi_calibration_period is not None:
      hyperparameters_proto.rf_roi_calibration_period.CopyFrom(
          backend.make_tensor_proto(np.array(obj.rf_roi_calibration_period))
      )
    if obj.holdout_id is not None:
      hyperparameters_proto.holdout_id.CopyFrom(
          backend.make_tensor_proto(np.array(obj.holdout_id))
      )
    if obj.control_population_scaling_id is not None:
      hyperparameters_proto.control_population_scaling_id.CopyFrom(
          backend.make_tensor_proto(np.array(obj.control_population_scaling_id))
      )
    if obj.non_media_population_scaling_id is not None:
      hyperparameters_proto.non_media_population_scaling_id.CopyFrom(
          backend.make_tensor_proto(
              np.array(obj.non_media_population_scaling_id)
          )
      )

    if isinstance(obj.adstock_decay_spec, str):
      hyperparameters_proto.global_adstock_decay = obj.adstock_decay_spec
    elif isinstance(obj.adstock_decay_spec, dict):
      hyperparameters_proto.adstock_decay_by_channel.channel_decays.update(
          obj.adstock_decay_spec
      )

    if obj.non_media_baseline_values is not None:
      for value in obj.non_media_baseline_values:
        value_proto = hyperparameters_proto.non_media_baseline_values.add()
        if isinstance(value, str) and value.lower() in ("min", "max"):
          value_proto.function_value = (
              _NonMediaBaselineFunction.MIN
              if value.lower() == "min"
              else _NonMediaBaselineFunction.MAX
          )
        # NumPy scalars such as np.float32 / np.int64 are not Python
        # float/int subclasses, so accept them explicitly.
        elif isinstance(value, (float, int, np.floating, np.integer)):
          value_proto.value = float(value)
        else:
          raise ValueError(
              f"Unsupported non-media baseline value: {value!r}. Expected"
              " 'min', 'max', or a real number."
          )

    return hyperparameters_proto

  def deserialize(
      self,
      serialized: meridian_pb.Hyperparameters,
      serialized_version: str = "",
  ) -> spec.ModelSpec:
    """Deserializes the given `Hyperparameters` proto into a ModelSpec.

    Note that this only deserializes the Hyperparameters part of ModelSpec.
    The 'prior' attribute of ModelSpec is deserialized separately
    using DistributionSerde and should be combined in the MeridianSerde.

    Args:
      serialized: The serialized `Hyperparameters` proto.
      serialized_version: The version of the serialized model. This is used to
        handle changes in deserialization logic across different versions.

    Returns:
      A Meridian model spec container.

    Raises:
      ValueError: If a non-media baseline value entry is unset or holds an
        unsupported function enum.
    """
    baseline_geo = None
    baseline_geo_field = serialized.WhichOneof(sc.BASELINE_GEO_ONEOF)
    if baseline_geo_field == sc.BASELINE_GEO_INT:
      baseline_geo = serialized.baseline_geo_int
    elif baseline_geo_field == sc.BASELINE_GEO_STRING:
      baseline_geo = serialized.baseline_geo_string

    # A single stored entry is read back as an int; several as a list.
    knots = None
    if serialized.knots:
      if len(serialized.knots) == 1:
        knots = serialized.knots[0]
      else:
        knots = list(serialized.knots)

    max_lag = serialized.max_lag if serialized.HasField(c.MAX_LAG) else None

    roi_calibration_period = (
        backend.make_ndarray(serialized.roi_calibration_period)
        if serialized.HasField(c.ROI_CALIBRATION_PERIOD)
        else None
    )
    rf_roi_calibration_period = (
        backend.make_ndarray(serialized.rf_roi_calibration_period)
        if serialized.HasField(c.RF_ROI_CALIBRATION_PERIOD)
        else None
    )

    holdout_id = (
        backend.make_ndarray(serialized.holdout_id)
        if serialized.HasField(sc.HOLDOUT_ID)
        else None
    )

    control_population_scaling_id = (
        backend.make_ndarray(serialized.control_population_scaling_id)
        if serialized.HasField(sc.CONTROL_POPULATION_SCALING_ID)
        else None
    )

    non_media_population_scaling_id = (
        backend.make_ndarray(serialized.non_media_population_scaling_id)
        if serialized.HasField(sc.NON_MEDIA_POPULATION_SCALING_ID)
        else None
    )

    non_media_baseline_values = None
    if serialized.non_media_baseline_values:
      non_media_baseline_values = []
      for value_proto in serialized.non_media_baseline_values:
        field = value_proto.WhichOneof("non_media_baseline_value")
        if field == "value":
          non_media_baseline_values.append(value_proto.value)
        elif field == "function_value":
          if value_proto.function_value == _NonMediaBaselineFunction.MIN:
            non_media_baseline_values.append("min")
          elif value_proto.function_value == _NonMediaBaselineFunction.MAX:
            non_media_baseline_values.append("max")
          elif (
              value_proto.function_value
              == _NonMediaBaselineFunction.NON_MEDIA_BASELINE_FUNCTION_UNSPECIFIED
          ):
            warnings.warn(
                "Non-media baseline function value is unspecified. Resolving to"
                " 'min'."
            )
            non_media_baseline_values.append("min")
          else:
            raise ValueError(
                "Unsupported NonMediaBaselineFunction proto enum value:"
                f" {value_proto.function_value}."
            )
        else:
          raise ValueError(
              f"Unsupported NonMediaBaselineValue proto enum value: {field}."
          )

    adstock_decay_spec_field = serialized.WhichOneof(sc.ADSTOCK_DECAY_SPEC)
    if adstock_decay_spec_field == sc.GLOBAL_ADSTOCK_DECAY:
      adstock_decay_spec = serialized.global_adstock_decay
    elif adstock_decay_spec_field == sc.ADSTOCK_DECAY_BY_CHANNEL:
      adstock_decay_spec = dict(
          serialized.adstock_decay_by_channel.channel_decays
      )
    else:
      adstock_decay_spec = sc.DEFAULT_DECAY

    return spec.ModelSpec(
        media_effects_dist=_proto_enum_to_media_effects_dist(
            serialized.media_effects_dist
        ),
        hill_before_adstock=serialized.hill_before_adstock,
        max_lag=max_lag,
        unique_sigma_for_each_geo=serialized.unique_sigma_for_each_geo,
        media_prior_type=_proto_enum_to_paid_media_prior_type(
            serialized.media_prior_type
        ),
        rf_prior_type=_proto_enum_to_paid_media_prior_type(
            serialized.rf_prior_type
        ),
        paid_media_prior_type=_proto_enum_to_paid_media_prior_type(
            serialized.paid_media_prior_type
        ),
        organic_media_prior_type=_proto_enum_to_non_paid_prior_type(
            serialized.organic_media_prior_type
        ),
        organic_rf_prior_type=_proto_enum_to_non_paid_prior_type(
            serialized.organic_rf_prior_type
        ),
        non_media_treatments_prior_type=_proto_enum_to_non_paid_prior_type(
            serialized.non_media_treatments_prior_type
        ),
        non_media_baseline_values=non_media_baseline_values,
        knots=knots,
        enable_aks=serialized.enable_aks,
        baseline_geo=baseline_geo,
        roi_calibration_period=roi_calibration_period,
        rf_roi_calibration_period=rf_roi_calibration_period,
        holdout_id=holdout_id,
        control_population_scaling_id=control_population_scaling_id,
        non_media_population_scaling_id=non_media_population_scaling_id,
        adstock_decay_spec=adstock_decay_spec,
    )
|