google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,192 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Serialization and deserialization of `EDASpec` objects."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import warnings
20
+
21
+ from meridian.model.eda import eda_spec
22
+ from mmm.v1.model.meridian.eda import eda_spec_pb2 as eda_spec_pb
23
+ from schema.serde import function_registry as function_registry_utils
24
+ from schema.serde import serde
25
+
26
+
27
# Local alias for the registry type; used in this module's annotations.
FunctionRegistry = function_registry_utils.FunctionRegistry

# Name of the `EDASpec` proto field holding the hashed function registry.
_FUNCTION_REGISTRY_NAME = "function_registry"
29
+
30
+
31
class EDASpecSerde(serde.Serde[eda_spec_pb.EDASpec, eda_spec.EDASpec]):
  """Converts `EDASpec` objects to and from `EDASpec` protos.

  Aggregation functions are not serialized by value. Each one is written as
  the key under which it is found in the supplied `FunctionRegistry`. The
  registry's per-key source hashes are written alongside, so that loading can
  check the functions for changes.
  """

  def __init__(self, function_registry: FunctionRegistry):
    """Creates a serde bound to `function_registry`.

    Args:
      function_registry: Lookup table of the custom functions referenced by
        `EDASpec` objects. Prefer named functions over lambdas, because the
        source of a lambda may not be retrievable for hashing.
    """
    self._function_registry = function_registry

  @property
  def function_registry(self) -> FunctionRegistry:
    """The registry used to map aggregation functions to and from keys."""
    return self._function_registry

  def serialize(self, obj: eda_spec.EDASpec) -> eda_spec_pb.EDASpec:
    """Builds an `EDASpec` proto from `obj`, including registry hashes."""
    aggregation_proto = self._to_aggregation_config_proto(
        obj.aggregation_config
    )
    vif_proto = self._to_vif_spec_proto(obj.vif_spec)
    result = eda_spec_pb.EDASpec(
        aggregation_config=aggregation_proto,
        vif_spec=vif_proto,
    )
    result.function_registry.update(self.function_registry.hashed_registry)
    return result

  def deserialize(
      self,
      serialized: eda_spec_pb.EDASpec,
      serialized_version: str = "",
      force_deserialization: bool = False,
  ) -> eda_spec.EDASpec:
    """Rebuilds an `EDASpec` from its proto form.

    Args:
      serialized: The `EDASpec` proto to read.
      serialized_version: Version of the serialized Meridian model. Accepted
        for interface compatibility but not read by this method.
      force_deserialization: When True, skip validating the stored function
        hashes against `function_registry` and emit a warning instead.

    Returns:
      The deserialized `EDASpec`.

    Raises:
      ValueError: If registry validation fails, or if a function key is empty
        or missing from `function_registry`.
    """
    if not force_deserialization:
      stored_hashes = getattr(serialized, _FUNCTION_REGISTRY_NAME)
      try:
        self.function_registry.validate(stored_hashes)
      except Exception as e:
        # Surface every validation failure uniformly as ValueError.
        raise ValueError(
            f"An issue found during deserializing EDASpec: {e}"
        ) from e
    else:
      warnings.warn(
          "You're attempting to deserialize an EDASpec while ignoring changes"
          " to custom functions. This is a risky operation that can"
          " potentially lead to a deserialized EDASpec that behaves"
          " differently from the original, resulting in unexpected behavior."
          " Please proceed with caution."
      )

    return eda_spec.EDASpec(
        aggregation_config=self._from_aggregation_config_proto(
            serialized.aggregation_config
        ),
        vif_spec=self._from_vif_spec_proto(serialized.vif_spec),
    )

  def _to_aggregation_config_proto(
      self, config: eda_spec.AggregationConfig
  ) -> eda_spec_pb.AggregationConfig:
    """Encodes each aggregation map of `config` as registry keys."""
    result = eda_spec_pb.AggregationConfig()
    sections = (
        ("control_variables", config.control_variables),
        ("non_media_treatments", config.non_media_treatments),
    )
    for field_name, functions in sections:
      if not functions:
        continue
      target = getattr(result, field_name)
      for name, fn in functions.items():
        target[name].CopyFrom(
            self._to_aggregation_function_proto(fn, name, field_name)
        )
    return result

  def _to_aggregation_function_proto(
      self, func: eda_spec.AggregationFn, key: str, field: str
  ) -> eda_spec_pb.AggregationFunction:
    """Looks up `func` in the registry (by identity) and wraps its key.

    Raises:
      ValueError: If `func` is not registered.
    """
    registry_key = self.function_registry.get_function_key(func)
    if registry_key is None:
      raise ValueError(
          f"Custom aggregation function `{key}` in `{field}` detected, but not"
          " found in registry. Please add custom functions to registry when"
          " saving models."
      )
    return eda_spec_pb.AggregationFunction(function_key=registry_key)

  def _from_aggregation_config_proto(
      self, proto: eda_spec_pb.AggregationConfig
  ) -> eda_spec.AggregationConfig:
    """Resolves every function key in `proto` back to a callable."""

    def resolve_all(entries):
      return {
          name: self._from_aggregation_function_proto(name, entry)
          for name, entry in entries.items()
      }

    return eda_spec.AggregationConfig(
        control_variables=resolve_all(proto.control_variables),
        non_media_treatments=resolve_all(proto.non_media_treatments),
    )

  def _from_aggregation_function_proto(
      self, var_name: str, agg_func: eda_spec_pb.AggregationFunction
  ) -> eda_spec.AggregationFn:
    """Returns the registered function named by `agg_func.function_key`.

    Raises:
      ValueError: If the key is empty or not present in the registry.
    """
    registry_key = agg_func.function_key
    if not registry_key:
      raise ValueError(
          "Function key is required in `AggregationFunction` proto message."
          f" The function key for {var_name} is empty."
      )
    if registry_key not in self.function_registry:
      raise ValueError(f"Function key `{registry_key}` not found in registry.")
    return self.function_registry[registry_key]

  def _to_vif_spec_proto(
      self, vif_spec_obj: eda_spec.VIFSpec
  ) -> eda_spec_pb.VIFSpec:
    """Copies the three VIF thresholds into a `VIFSpec` proto."""
    geo = vif_spec_obj.geo_threshold
    overall = vif_spec_obj.overall_threshold
    national = vif_spec_obj.national_threshold
    return eda_spec_pb.VIFSpec(
        geo_threshold=geo,
        overall_threshold=overall,
        national_threshold=national,
    )

  def _from_vif_spec_proto(
      self, vif_spec_proto: eda_spec_pb.VIFSpec
  ) -> eda_spec.VIFSpec:
    """Copies the three VIF thresholds out of a `VIFSpec` proto."""
    geo = vif_spec_proto.geo_threshold
    overall = vif_spec_proto.overall_threshold
    national = vif_spec_proto.national_threshold
    return eda_spec.VIFSpec(
        geo_threshold=geo,
        overall_threshold=overall,
        national_threshold=national,
    )
@@ -0,0 +1,143 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Function registry for Serde."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import functools
20
+ import hashlib
21
+ import inspect
22
+ from typing import Any, Callable
23
+ import warnings
24
+
25
+
26
class SourceCodeRetrievalError(Exception):
  """Signals that `inspect` could not produce a function's source code."""
28
+
29
+
30
class LambdaSourceCodeWarning(UserWarning):
  """Emitted before attempting to read the source code of a lambda."""
32
+
33
+
34
+ def _get_func_source(func: Callable[..., Any]) -> str:
35
+ """Returns the source code of a function.
36
+
37
+ Args:
38
+ func: The function to get the source code for.
39
+
40
+ Returns:
41
+ The source code of the function.
42
+
43
+ Raises:
44
+ SourceCodeRetrievalError: If the source code of the function cannot be
45
+ retrieved.
46
+ """
47
+ if hasattr(func, "__code__") and func.__code__.co_name == "<lambda>":
48
+ warnings.warn(
49
+ "Retrieving the source code of a lambda function might not work"
50
+ " successfully. It's recommended to explicitly define a function.",
51
+ LambdaSourceCodeWarning,
52
+ )
53
+ try:
54
+ return inspect.getsource(func)
55
+ except (OSError, TypeError) as e:
56
+ raise SourceCodeRetrievalError(
57
+ f"Source code of function {func} is not retrievable."
58
+ ) from e
59
+
60
+
61
+ def _get_hash(value: str) -> str:
62
+ """Returns a SHA-256 hash of the given value."""
63
+ encoded_string = value.encode("utf-8")
64
+ sha_256_hash = hashlib.sha256()
65
+ sha_256_hash.update(encoded_string)
66
+ return sha_256_hash.hexdigest()
67
+
68
+
69
+ class FunctionRegistry(dict[str, Callable[..., Any]]):
70
+ """A dictionary-like container for custom functions used in serialization.
71
+
72
+ This class extends dict and provides methods for hashing, validation,
73
+ and key retrieval based on function identity, required for safe
74
+ serialization and deserialization of models that use custom functions.
75
+ """
76
+
77
+ def __init__(self, *args, **kwargs):
78
+ """Initializes the FunctionRegistry.
79
+
80
+ Accepts the same arguments as a standard dictionary constructor.
81
+ For example:
82
+ reg = FunctionRegistry({'func1': my_func1, 'func2': my_func2})
83
+ reg = FunctionRegistry(func1=my_func1, func2=my_func2)
84
+
85
+ Args:
86
+ *args: Positional arguments to pass to the dictionary constructor.
87
+ **kwargs: Keyword arguments to pass to the dictionary constructor.
88
+ """
89
+ super().__init__(*args, **kwargs)
90
+
91
+ @functools.cached_property
92
+ def hashed_registry(self) -> dict[str, str]:
93
+ """Returns hashed function registry with keys mapped to hashed function code."""
94
+ return {
95
+ key: _get_hash(_get_func_source(function))
96
+ for key, function in self.items()
97
+ }
98
+
99
+ def validate(self, stored_hashed_function_registry: dict[str, str]):
100
+ """Validates whether functions within the registry have changed.
101
+
102
+ It checks that all functions listed in stored_hashed_function_registry
103
+ are present in this registry, and that their source code hash matches
104
+ the stored hash.
105
+
106
+ Args:
107
+ stored_hashed_function_registry: The hashed function registry from the
108
+ serialized object.
109
+
110
+ Raises:
111
+ ValueError: If a function is missing or a hash mismatch is detected.
112
+ """
113
+ if not stored_hashed_function_registry and self:
114
+ warnings.warn(
115
+ "A function registry was provided during loading, but none was"
116
+ " found on the serialized object. Custom functions will be"
117
+ " ignored."
118
+ )
119
+ return
120
+
121
+ for key, stored_hash in stored_hashed_function_registry.items():
122
+ if key not in self:
123
+ raise ValueError(
124
+ f"Function '{key}' is required by the serialized object but"
125
+ " is missing from the provided function registry."
126
+ )
127
+ func = self[key]
128
+ try:
129
+ source_code = _get_func_source(func)
130
+ except SourceCodeRetrievalError as e:
131
+ raise ValueError(
132
+ f"Failed to retrieve source code of function {key}."
133
+ ) from e
134
+ evaluated_hash = _get_hash(source_code)
135
+ if stored_hash != evaluated_hash:
136
+ raise ValueError(f"Function registry hash mismatch for {key}.")
137
+
138
+ def get_function_key(self, func: Callable[..., Any]) -> str | None:
139
+ """Returns the function key for the given function from the registry."""
140
+ for function_key, registry_func in self.items():
141
+ if func is registry_func:
142
+ return function_key
143
+ return None
@@ -0,0 +1,363 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Serde for Hyperparameters."""
16
+
17
+ import warnings
18
+
19
+ from meridian import backend
20
+ from meridian import constants as c
21
+ from meridian.model import spec
22
+ from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
23
+ from schema.serde import constants as sc
24
+ from schema.serde import serde
25
+ import numpy as np
26
+
27
# Short aliases for the proto enum types converted by the helpers below.
_MediaEffectsDist = meridian_pb.MediaEffectsDistribution
_PaidMediaPriorType = meridian_pb.PaidMediaPriorType
_NonPaidTreatmentsPriorType = meridian_pb.NonPaidTreatmentsPriorType
# Nested enum on the `NonMediaBaselineValue` message.
_NonMediaBaselineFunction = (
    meridian_pb.NonMediaBaselineValue.NonMediaBaselineFunction
)
33
+
34
+
35
def _media_effects_dist_to_proto_enum(
    media_effect_dict: str,
) -> _MediaEffectsDist:
  """Maps a media-effects distribution name to its proto enum.

  Any value other than `c.MEDIA_EFFECTS_LOG_NORMAL` or `c.MEDIA_EFFECTS_NORMAL`
  maps to `MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED`.
  """
  if media_effect_dict == c.MEDIA_EFFECTS_LOG_NORMAL:
    return _MediaEffectsDist.LOG_NORMAL
  if media_effect_dict == c.MEDIA_EFFECTS_NORMAL:
    return _MediaEffectsDist.NORMAL
  return _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED
45
+
46
+
47
def _proto_enum_to_media_effects_dist(
    proto_enum: _MediaEffectsDist,
) -> str:
  """Maps a `_MediaEffectsDist` enum value back to its string name.

  `MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED` resolves to log-normal with a
  warning.

  Raises:
    ValueError: For any enum value not handled above.
  """
  if proto_enum == _MediaEffectsDist.LOG_NORMAL:
    return c.MEDIA_EFFECTS_LOG_NORMAL
  if proto_enum == _MediaEffectsDist.NORMAL:
    return c.MEDIA_EFFECTS_NORMAL
  if proto_enum == _MediaEffectsDist.MEDIA_EFFECTS_DISTRIBUTION_UNSPECIFIED:
    warnings.warn(
        "Media effects distribution is unspecified. Resolving to"
        " 'log-normal'."
    )
    return c.MEDIA_EFFECTS_LOG_NORMAL
  raise ValueError(
      "Unsupported MediaEffectsDistribution proto enum value:"
      f" {proto_enum}."
  )
67
+
68
+
69
def _paid_media_prior_type_to_proto_enum(
    paid_media_prior_type: str | None,
) -> _PaidMediaPriorType:
  """Maps a paid-media prior type name (case-insensitive) to its proto enum.

  `None` maps to `PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED`. An unknown name also
  maps there, with a warning.
  """
  unspecified = _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
  if paid_media_prior_type is None:
    return unspecified
  enum_name = paid_media_prior_type.upper()
  try:
    return _PaidMediaPriorType.Value(enum_name)
  except ValueError:
    warnings.warn(
        f"Invalid paid media prior type: {paid_media_prior_type}. Resolving to"
        " PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED."
    )
    return unspecified
83
+
84
+
85
def _proto_enum_to_paid_media_prior_type(
    proto_enum: _PaidMediaPriorType,
) -> str | None:
  """Maps a `_PaidMediaPriorType` enum to its lowercase name.

  `PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED` maps to `None`.
  """
  is_unspecified = (
      proto_enum == _PaidMediaPriorType.PAID_MEDIA_PRIOR_TYPE_UNSPECIFIED
  )
  if is_unspecified:
    return None
  enum_name = _PaidMediaPriorType.Name(proto_enum)
  return enum_name.lower()
92
+
93
+
94
def _non_paid_prior_type_to_proto_enum(
    non_paid_prior_type: str,
) -> _NonPaidTreatmentsPriorType:
  """Maps a non-paid prior type name (case-insensitive) to its proto enum.

  The name is looked up as `NON_PAID_TREATMENTS_PRIOR_TYPE_<NAME>`. An unknown
  name resolves to the CONTRIBUTION value, with a warning.
  """
  enum_name = f"NON_PAID_TREATMENTS_PRIOR_TYPE_{non_paid_prior_type.upper()}"
  try:
    return _NonPaidTreatmentsPriorType.Value(enum_name)
  except ValueError:
    warnings.warn(
        f"Invalid non-paid prior type: {non_paid_prior_type}. Resolving to"
        " NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION."
    )
  fallback = (
      _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_CONTRIBUTION
  )
  return fallback
110
+
111
+
112
def _proto_enum_to_non_paid_prior_type(
    proto_enum: _NonPaidTreatmentsPriorType,
) -> str:
  """Maps a `_NonPaidTreatmentsPriorType` enum to its lowercase short name.

  The `NON_PAID_TREATMENTS_PRIOR_TYPE_` part of the enum name is removed. The
  UNSPECIFIED value resolves to `c.TREATMENT_PRIOR_TYPE_CONTRIBUTION`, with a
  warning.
  """
  unspecified = (
      _NonPaidTreatmentsPriorType.NON_PAID_TREATMENTS_PRIOR_TYPE_UNSPECIFIED
  )
  if proto_enum == unspecified:
    warnings.warn(
        "Non-paid prior type is unspecified. Resolving to 'contribution'."
    )
    return c.TREATMENT_PRIOR_TYPE_CONTRIBUTION
  full_name = _NonPaidTreatmentsPriorType.Name(proto_enum)
  short_name = full_name.replace("NON_PAID_TREATMENTS_PRIOR_TYPE_", "")
  return short_name.lower()
129
+
130
+
131
class HyperparametersSerde(
    serde.Serde[meridian_pb.Hyperparameters, spec.ModelSpec]
):
  """Serializes and deserializes a ModelSpec into a `Hyperparameters` proto.

  Note that this Serde only handles the Hyperparameters part of ModelSpec.
  The 'prior' attribute of ModelSpec is serialized/deserialized separately
  using DistributionSerde.
  """

  def serialize(self, obj: spec.ModelSpec) -> meridian_pb.Hyperparameters:
    """Serializes the given ModelSpec into a `Hyperparameters` proto.

    NumPy values are handled as follows:
      * NumPy integer scalars are accepted wherever a Python `int` is
        (`knots`, `baseline_geo`, numeric `non_media_baseline_values`).
      * NumPy floating scalars are accepted wherever a `float` is.
      * A NumPy array is accepted for `knots`.
    Without this normalization those values would fail the `isinstance`
    checks and be silently dropped.

    Args:
      obj: The model spec to serialize. Its `prior` attribute is not handled
        here.

    Returns:
      A `Hyperparameters` proto populated from `obj`.

    Raises:
      ValueError: If an entry of `obj.non_media_baseline_values` is neither
        the string 'min'/'max' (case-insensitive) nor a real number. This is
        raised instead of writing an empty entry that `deserialize` cannot
        read back.
    """
    hyperparameters_proto = meridian_pb.Hyperparameters(
        media_effects_dist=_media_effects_dist_to_proto_enum(
            obj.media_effects_dist
        ),
        hill_before_adstock=obj.hill_before_adstock,
        unique_sigma_for_each_geo=obj.unique_sigma_for_each_geo,
        media_prior_type=_paid_media_prior_type_to_proto_enum(
            obj.media_prior_type
        ),
        rf_prior_type=_paid_media_prior_type_to_proto_enum(obj.rf_prior_type),
        paid_media_prior_type=_paid_media_prior_type_to_proto_enum(
            obj.paid_media_prior_type
        ),
        organic_media_prior_type=_non_paid_prior_type_to_proto_enum(
            obj.organic_media_prior_type
        ),
        organic_rf_prior_type=_non_paid_prior_type_to_proto_enum(
            obj.organic_rf_prior_type
        ),
        non_media_treatments_prior_type=_non_paid_prior_type_to_proto_enum(
            obj.non_media_treatments_prior_type
        ),
        enable_aks=obj.enable_aks,
    )
    if obj.max_lag is not None:
      hyperparameters_proto.max_lag = obj.max_lag

    # Scalars and sequences are both written to the repeated `knots` field.
    # Normalize NumPy inputs first so they are not skipped by the checks below.
    knots = obj.knots
    if isinstance(knots, np.ndarray):
      knots = knots.tolist()  # A 0-d array becomes a Python scalar.
    elif isinstance(knots, np.integer):
      knots = int(knots)
    if isinstance(knots, int):
      hyperparameters_proto.knots.append(knots)
    elif isinstance(knots, (list, tuple)):
      hyperparameters_proto.knots.extend(knots)

    baseline_geo = obj.baseline_geo
    if isinstance(baseline_geo, np.integer):
      baseline_geo = int(baseline_geo)
    if isinstance(baseline_geo, str):
      hyperparameters_proto.baseline_geo_string = baseline_geo
    elif isinstance(baseline_geo, int):
      hyperparameters_proto.baseline_geo_int = baseline_geo

    if obj.roi_calibration_period is not None:
      hyperparameters_proto.roi_calibration_period.CopyFrom(
          backend.make_tensor_proto(np.array(obj.roi_calibration_period))
      )
    if obj.rf_roi_calibration_period is not None:
      hyperparameters_proto.rf_roi_calibration_period.CopyFrom(
          backend.make_tensor_proto(np.array(obj.rf_roi_calibration_period))
      )
    if obj.holdout_id is not None:
      hyperparameters_proto.holdout_id.CopyFrom(
          backend.make_tensor_proto(np.array(obj.holdout_id))
      )
    if obj.control_population_scaling_id is not None:
      hyperparameters_proto.control_population_scaling_id.CopyFrom(
          backend.make_tensor_proto(np.array(obj.control_population_scaling_id))
      )
    if obj.non_media_population_scaling_id is not None:
      hyperparameters_proto.non_media_population_scaling_id.CopyFrom(
          backend.make_tensor_proto(
              np.array(obj.non_media_population_scaling_id)
          )
      )

    if isinstance(obj.adstock_decay_spec, str):
      hyperparameters_proto.global_adstock_decay = obj.adstock_decay_spec
    elif isinstance(obj.adstock_decay_spec, dict):
      hyperparameters_proto.adstock_decay_by_channel.channel_decays.update(
          obj.adstock_decay_spec
      )

    if obj.non_media_baseline_values is not None:
      for value in obj.non_media_baseline_values:
        self._add_non_media_baseline_value(hyperparameters_proto, value)

    return hyperparameters_proto

  @staticmethod
  def _add_non_media_baseline_value(
      hyperparameters_proto: meridian_pb.Hyperparameters,
      value: float | str,
  ) -> None:
    """Appends one `non_media_baseline_values` entry for `value`.

    The value is validated before anything is appended, so an unsupported
    value never leaves an empty entry in the proto.

    Args:
      hyperparameters_proto: The proto to append the entry to.
      value: 'min' or 'max' (case-insensitive), or a real number. Python and
        NumPy ints and floats are accepted.

    Raises:
      ValueError: If `value` is an unrecognized string or an unsupported
        type.
    """
    if isinstance(value, str):
      normalized = value.lower()
      if normalized == "min":
        function_value = _NonMediaBaselineFunction.MIN
      elif normalized == "max":
        function_value = _NonMediaBaselineFunction.MAX
      else:
        raise ValueError(
            f"Unsupported non-media baseline value: {value!r}. Expected 'min',"
            " 'max', or a number."
        )
      value_proto = hyperparameters_proto.non_media_baseline_values.add()
      value_proto.function_value = function_value
    elif isinstance(value, (float, int, np.floating, np.integer)):
      value_proto = hyperparameters_proto.non_media_baseline_values.add()
      value_proto.value = float(value)
    else:
      raise ValueError(
          f"Unsupported non-media baseline value: {value!r} of type"
          f" {type(value).__name__}. Expected 'min', 'max', or a number."
      )

  def deserialize(
      self,
      serialized: meridian_pb.Hyperparameters,
      serialized_version: str = "",
  ) -> spec.ModelSpec:
    """Deserializes the given `Hyperparameters` proto into a ModelSpec.

    Note that this only deserializes the Hyperparameters part of ModelSpec.
    The 'prior' attribute of ModelSpec is deserialized separately
    using DistributionSerde and should be combined in the MeridianSerde.

    NOTE(review): `serialize` writes both an int `knots` and a list `knots`
    into the same repeated field. As a result, a single-element list such as
    `[5]` deserializes as the int `5`. If ModelSpec interprets these two forms
    differently, this does not round-trip. A proto schema change would be
    needed to disambiguate (TODO confirm).

    Args:
      serialized: The serialized `Hyperparameters` proto.
      serialized_version: The version of the serialized model. This is used to
        handle changes in deserialization logic across different versions.

    Returns:
      A Meridian model spec container.

    Raises:
      ValueError: If a non-media baseline entry has no value set or carries an
        unsupported function enum.
    """
    baseline_geo = None
    baseline_geo_field = serialized.WhichOneof(sc.BASELINE_GEO_ONEOF)
    if baseline_geo_field == sc.BASELINE_GEO_INT:
      baseline_geo = serialized.baseline_geo_int
    elif baseline_geo_field == sc.BASELINE_GEO_STRING:
      baseline_geo = serialized.baseline_geo_string

    knots = None
    if serialized.knots:
      if len(serialized.knots) == 1:
        knots = serialized.knots[0]
      else:
        knots = list(serialized.knots)

    max_lag = serialized.max_lag if serialized.HasField(c.MAX_LAG) else None

    roi_calibration_period = (
        backend.make_ndarray(serialized.roi_calibration_period)
        if serialized.HasField(c.ROI_CALIBRATION_PERIOD)
        else None
    )
    rf_roi_calibration_period = (
        backend.make_ndarray(serialized.rf_roi_calibration_period)
        if serialized.HasField(c.RF_ROI_CALIBRATION_PERIOD)
        else None
    )

    holdout_id = (
        backend.make_ndarray(serialized.holdout_id)
        if serialized.HasField(sc.HOLDOUT_ID)
        else None
    )

    control_population_scaling_id = (
        backend.make_ndarray(serialized.control_population_scaling_id)
        if serialized.HasField(sc.CONTROL_POPULATION_SCALING_ID)
        else None
    )

    non_media_population_scaling_id = (
        backend.make_ndarray(serialized.non_media_population_scaling_id)
        if serialized.HasField(sc.NON_MEDIA_POPULATION_SCALING_ID)
        else None
    )

    non_media_baseline_values = None
    if serialized.non_media_baseline_values:
      non_media_baseline_values = []
      for value_proto in serialized.non_media_baseline_values:
        field = value_proto.WhichOneof("non_media_baseline_value")
        if field == "value":
          non_media_baseline_values.append(value_proto.value)
        elif field == "function_value":
          if value_proto.function_value == _NonMediaBaselineFunction.MIN:
            non_media_baseline_values.append("min")
          elif value_proto.function_value == _NonMediaBaselineFunction.MAX:
            non_media_baseline_values.append("max")
          elif (
              value_proto.function_value
              == _NonMediaBaselineFunction.NON_MEDIA_BASELINE_FUNCTION_UNSPECIFIED
          ):
            warnings.warn(
                "Non-media baseline function value is unspecified. Resolving to"
                " 'min'."
            )
            non_media_baseline_values.append("min")
          else:
            raise ValueError(
                "Unsupported NonMediaBaselineFunction proto enum value:"
                f" {value_proto.function_value}."
            )
        else:
          raise ValueError(
              f"Unsupported NonMediaBaselineValue proto enum value: {field}."
          )

    adstock_decay_spec_field = serialized.WhichOneof(sc.ADSTOCK_DECAY_SPEC)
    if adstock_decay_spec_field == sc.GLOBAL_ADSTOCK_DECAY:
      adstock_decay_spec = serialized.global_adstock_decay
    elif adstock_decay_spec_field == sc.ADSTOCK_DECAY_BY_CHANNEL:
      adstock_decay_spec = dict(
          serialized.adstock_decay_by_channel.channel_decays
      )
    else:
      adstock_decay_spec = sc.DEFAULT_DECAY

    return spec.ModelSpec(
        media_effects_dist=_proto_enum_to_media_effects_dist(
            serialized.media_effects_dist
        ),
        hill_before_adstock=serialized.hill_before_adstock,
        max_lag=max_lag,
        unique_sigma_for_each_geo=serialized.unique_sigma_for_each_geo,
        media_prior_type=_proto_enum_to_paid_media_prior_type(
            serialized.media_prior_type
        ),
        rf_prior_type=_proto_enum_to_paid_media_prior_type(
            serialized.rf_prior_type
        ),
        paid_media_prior_type=_proto_enum_to_paid_media_prior_type(
            serialized.paid_media_prior_type
        ),
        organic_media_prior_type=_proto_enum_to_non_paid_prior_type(
            serialized.organic_media_prior_type
        ),
        organic_rf_prior_type=_proto_enum_to_non_paid_prior_type(
            serialized.organic_rf_prior_type
        ),
        non_media_treatments_prior_type=_proto_enum_to_non_paid_prior_type(
            serialized.non_media_treatments_prior_type
        ),
        non_media_baseline_values=non_media_baseline_values,
        knots=knots,
        enable_aks=serialized.enable_aks,
        baseline_geo=baseline_geo,
        roi_calibration_period=roi_calibration_period,
        rf_roi_calibration_period=rf_roi_calibration_period,
        holdout_id=holdout_id,
        control_population_scaling_id=control_population_scaling_id,
        non_media_population_scaling_id=non_media_population_scaling_id,
        adstock_decay_spec=adstock_decay_spec,
    )