google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Constants shared across the Meridian serde library."""
|
|
16
|
+
|
|
17
|
+
# Constants for hyperparameters protobuf structure.
# NOTE(review): these look like oneof/field names of the Meridian
# hyperparameters proto (e.g. `BASELINE_GEO_ONEOF` with its `_INT` / `_STRING`
# alternatives) — confirm against the proto definition before renaming any.
BASELINE_GEO_ONEOF = 'baseline_geo_oneof'
BASELINE_GEO_INT = 'baseline_geo_int'
BASELINE_GEO_STRING = 'baseline_geo_string'
CONTROL_POPULATION_SCALING_ID = 'control_population_scaling_id'
HOLDOUT_ID = 'holdout_id'
NON_MEDIA_POPULATION_SCALING_ID = 'non_media_population_scaling_id'
ADSTOCK_DECAY_SPEC = 'adstock_decay_spec'
GLOBAL_ADSTOCK_DECAY = 'global_adstock_decay'
ADSTOCK_DECAY_BY_CHANNEL = 'adstock_decay_by_channel'
# Presumably the adstock decay function name assumed when none is given —
# TODO confirm against the model spec / hyperparameters serde.
DEFAULT_DECAY = 'geometric'

# Constants for marketing data protobuf structure.
GEO_INFO = 'geo_info'
METADATA = 'metadata'
REACH_FREQUENCY = 'reach_frequency'

# Constants for distribution protobuf structure.
# `DISTRIBUTION_TYPE` is the oneof group name of the legacy `Distribution`
# proto, and the `*_DISTRIBUTION` values are its member field names:
# `distribution._from_legacy_distribution_proto` calls
# `WhichOneof(DISTRIBUTION_TYPE)` and dispatches on these values.
DISTRIBUTION_TYPE = 'distribution_type'
BATCH_BROADCAST_DISTRIBUTION = 'batch_broadcast'
DETERMINISTIC_DISTRIBUTION = 'deterministic'
HALF_NORMAL_DISTRIBUTION = 'half_normal'
LOG_NORMAL_DISTRIBUTION = 'log_normal'
NORMAL_DISTRIBUTION = 'normal'
TRANSFORMED_DISTRIBUTION = 'transformed'
TRUNCATED_NORMAL_DISTRIBUTION = 'truncated_normal'
UNIFORM_DISTRIBUTION = 'uniform'
BETA_DISTRIBUTION = 'beta'
# `BIJECTOR_TYPE` is the oneof group name of the legacy `Distribution.Bijector`
# proto; the `*_BIJECTOR` values are its member field names, dispatched on by
# `distribution._from_legacy_bijector_proto`.
BIJECTOR_TYPE = 'bijector_type'
SHIFT_BIJECTOR = 'shift'
SCALE_BIJECTOR = 'scale'
RECIPROCAL_BIJECTOR = 'reciprocal'
|
|
@@ -0,0 +1,515 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Serialization and deserialization of `Distribution` objects for priors."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import inspect
|
|
20
|
+
import types
|
|
21
|
+
from typing import Any, Sequence, TypeVar
|
|
22
|
+
import warnings
|
|
23
|
+
|
|
24
|
+
from meridian import backend
|
|
25
|
+
from meridian import constants
|
|
26
|
+
from meridian.model import prior_distribution as pd
|
|
27
|
+
from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
|
|
28
|
+
from schema.serde import constants as sc
|
|
29
|
+
from schema.serde import function_registry as function_registry_utils
|
|
30
|
+
from schema.serde import serde
|
|
31
|
+
|
|
32
|
+
from tensorflow.core.framework import tensor_shape_pb2 # pylint: disable=g-direct-tensorflow-import
|
|
33
|
+
|
|
34
|
+
# Local alias for the registry type that maps custom Python functions used as
# distribution parameters to lookup keys and exposes a hashed form of itself.
FunctionRegistry = function_registry_utils.FunctionRegistry


# The container proto that `DistributionSerde.serialize` produces and
# `DistributionSerde.deserialize` consumes.
MeridianPriorDistributions = (
    meridian_pb.PriorTfpDistributions
)

# Name of the `PriorTfpDistributions` field that holds the hashed function
# registry; read via `getattr` when validating registries on deserialization.
_FUNCTION_REGISTRY_NAME = "function_registry"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# TODO: b/436637084 - Delete enumerated schema.
|
|
45
|
+
class DistributionSerde(
|
|
46
|
+
serde.Serde[MeridianPriorDistributions, pd.PriorDistribution]
|
|
47
|
+
):
|
|
48
|
+
"""Serializes and deserializes a Meridian prior distributions container into a `Distribution` proto."""
|
|
49
|
+
|
|
50
|
+
def __init__(self, function_registry: FunctionRegistry):
|
|
51
|
+
"""Initializes a `DistributionSerde` instance.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
function_registry: A lookup table containing custom functions used by
|
|
55
|
+
various `backend.tfd` classes. It's recommended to explicitly define the
|
|
56
|
+
custom functions instead of using lambdas, as lambda functions may not
|
|
57
|
+
be registered successfully.
|
|
58
|
+
"""
|
|
59
|
+
self._function_registry = function_registry
|
|
60
|
+
|
|
61
|
+
@property
|
|
62
|
+
def function_registry(self) -> FunctionRegistry:
|
|
63
|
+
return self._function_registry
|
|
64
|
+
|
|
65
|
+
def serialize(
|
|
66
|
+
self, obj: pd.PriorDistribution
|
|
67
|
+
) -> meridian_pb.PriorTfpDistributions:
|
|
68
|
+
"""Serializes the given Meridian priors container into a `MeridianPriorDistributions` proto."""
|
|
69
|
+
proto = meridian_pb.PriorTfpDistributions()
|
|
70
|
+
for param in constants.ALL_PRIOR_DISTRIBUTION_PARAMETERS:
|
|
71
|
+
if not hasattr(obj, param):
|
|
72
|
+
continue
|
|
73
|
+
getattr(proto, param).CopyFrom(
|
|
74
|
+
self._to_distribution_proto(getattr(obj, param))
|
|
75
|
+
)
|
|
76
|
+
proto.function_registry.update(self.function_registry.hashed_registry)
|
|
77
|
+
return proto
|
|
78
|
+
|
|
79
|
+
def deserialize(
|
|
80
|
+
self,
|
|
81
|
+
serialized: MeridianPriorDistributions,
|
|
82
|
+
serialized_version: str = "",
|
|
83
|
+
force_deserialization: bool = False,
|
|
84
|
+
) -> pd.PriorDistribution:
|
|
85
|
+
"""Deserializes the `PriorTfpDistributions` proto.
|
|
86
|
+
|
|
87
|
+
WARNING: If any custom functions in the function registry are modified after
|
|
88
|
+
serialization, the deserialized model can differ from the original model, as
|
|
89
|
+
the original function's behavior is no longer guaranteed. This will result
|
|
90
|
+
in an error during deserialization.
|
|
91
|
+
|
|
92
|
+
For users who are intentionally changing functions and are confident that
|
|
93
|
+
the changes will not affect the deserialized model, you can bypass safety
|
|
94
|
+
mechanisms to force deserialization. See example:
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
serialized: A serialized `PriorDistributions` object.
|
|
98
|
+
serialized_version: The version of the serialized Meridian model. This is
|
|
99
|
+
used to handle changes in deserialization logic across different
|
|
100
|
+
versions.
|
|
101
|
+
force_deserialization: If True, bypasses the safety check that validates
|
|
102
|
+
whether functions within `function_registry` have changed after
|
|
103
|
+
serialization. Use with caution. This should only be used if you have
|
|
104
|
+
intentionally modified a custom function and are confident that the
|
|
105
|
+
changes will not affect the deserialized model. A safer alternative is
|
|
106
|
+
to first deserialize the model with the original functions and then
|
|
107
|
+
serialize it with the new ones.
|
|
108
|
+
|
|
109
|
+
Returns:
|
|
110
|
+
A deserialized `PriorDistribution` object.
|
|
111
|
+
"""
|
|
112
|
+
kwargs = {}
|
|
113
|
+
for param in constants.ALL_PRIOR_DISTRIBUTION_PARAMETERS:
|
|
114
|
+
if not hasattr(serialized, param):
|
|
115
|
+
continue
|
|
116
|
+
# A parameter may be unspecified in a serialized proto message because:
|
|
117
|
+
# (1) It is left unset for Meridian to set its default value.
|
|
118
|
+
# (2) The message was created from a previous Meridian version after
|
|
119
|
+
# introducing a new parameter.
|
|
120
|
+
if not serialized.HasField(param):
|
|
121
|
+
continue
|
|
122
|
+
param_name = getattr(serialized, param)
|
|
123
|
+
if isinstance(serialized, meridian_pb.PriorTfpDistributions):
|
|
124
|
+
if force_deserialization:
|
|
125
|
+
warnings.warn(
|
|
126
|
+
"You're attempting to deserialize a model while ignoring changes"
|
|
127
|
+
" to custom functions. This is a risky operation that can"
|
|
128
|
+
" potentially lead to a deserialized model that behaves"
|
|
129
|
+
" differently from the original, resulting in unexpected behavior"
|
|
130
|
+
" or model failure. We strongly recommend a safer two-step"
|
|
131
|
+
" process: deserialize the model using the original function"
|
|
132
|
+
" registry and reserialize the model using the updated registry."
|
|
133
|
+
" Please proceed with caution."
|
|
134
|
+
)
|
|
135
|
+
else:
|
|
136
|
+
stored_hashed_function_registry = getattr(
|
|
137
|
+
serialized, _FUNCTION_REGISTRY_NAME
|
|
138
|
+
)
|
|
139
|
+
try:
|
|
140
|
+
self.function_registry.validate(stored_hashed_function_registry)
|
|
141
|
+
except ValueError as e:
|
|
142
|
+
raise ValueError(
|
|
143
|
+
f"An issue found during deserializing Distribution: {e}"
|
|
144
|
+
) from e
|
|
145
|
+
|
|
146
|
+
kwargs[param] = self._from_distribution_proto(param_name)
|
|
147
|
+
# copybara: strip_begin(legacy proto)
|
|
148
|
+
elif isinstance(serialized, meridian_pb.PriorDistributions):
|
|
149
|
+
kwargs[param] = _from_legacy_distribution_proto(param_name)
|
|
150
|
+
# copybara: strip_end
|
|
151
|
+
return pd.PriorDistribution(**kwargs)
|
|
152
|
+
|
|
153
|
+
def _to_distribution_proto(
|
|
154
|
+
self,
|
|
155
|
+
dist: backend.tfd.Distribution,
|
|
156
|
+
) -> meridian_pb.TfpDistribution:
|
|
157
|
+
"""Converts a TensorFlow `Distribution` object to a `TfpDistribution` proto."""
|
|
158
|
+
dist_name = type(dist).__name__
|
|
159
|
+
dist_class = getattr(backend.tfd, dist_name)
|
|
160
|
+
return meridian_pb.TfpDistribution(
|
|
161
|
+
distribution_type=dist_name,
|
|
162
|
+
parameters={
|
|
163
|
+
name: self._to_parameter_value_proto(name, value, dist_class)
|
|
164
|
+
for name, value in dist.parameters.items()
|
|
165
|
+
},
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
def _to_bijector_proto(
|
|
169
|
+
self,
|
|
170
|
+
bijector: backend.bijectors.Bijector,
|
|
171
|
+
) -> meridian_pb.TfpBijector:
|
|
172
|
+
"""Converts a TensorFlow `Bijector` object to a `TfpBijector` proto."""
|
|
173
|
+
bij_name = type(bijector).__name__
|
|
174
|
+
bij_class = getattr(backend.bijectors, bij_name)
|
|
175
|
+
return meridian_pb.TfpBijector(
|
|
176
|
+
bijector_type=bij_name,
|
|
177
|
+
parameters={
|
|
178
|
+
name: self._to_parameter_value_proto(name, value, bij_class)
|
|
179
|
+
for name, value in bijector.parameters.items()
|
|
180
|
+
},
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
def _to_parameter_value_proto(
|
|
184
|
+
self,
|
|
185
|
+
param_name: str,
|
|
186
|
+
value: Any,
|
|
187
|
+
dist: backend.tfd.Distribution | backend.bijectors.Bijector,
|
|
188
|
+
) -> meridian_pb.TfpParameterValue:
|
|
189
|
+
"""Converts a TensorFlow `Distribution` parameter value to a `TfpParameterValue` proto."""
|
|
190
|
+
# Handle built-in types.
|
|
191
|
+
match value:
|
|
192
|
+
case float():
|
|
193
|
+
return meridian_pb.TfpParameterValue(scalar_value=value)
|
|
194
|
+
case int():
|
|
195
|
+
return meridian_pb.TfpParameterValue(int_value=value)
|
|
196
|
+
case bool():
|
|
197
|
+
return meridian_pb.TfpParameterValue(bool_value=value)
|
|
198
|
+
case str():
|
|
199
|
+
return meridian_pb.TfpParameterValue(string_value=value)
|
|
200
|
+
case None:
|
|
201
|
+
return meridian_pb.TfpParameterValue(none_value=True)
|
|
202
|
+
case list():
|
|
203
|
+
value_generator = (
|
|
204
|
+
self._to_parameter_value_proto(param_name, v, dist) for v in value
|
|
205
|
+
)
|
|
206
|
+
return meridian_pb.TfpParameterValue(
|
|
207
|
+
list_value=meridian_pb.TfpParameterValue.List(
|
|
208
|
+
values=value_generator
|
|
209
|
+
)
|
|
210
|
+
)
|
|
211
|
+
case dict():
|
|
212
|
+
dict_value = {
|
|
213
|
+
k: self._to_parameter_value_proto(param_name, v, dist)
|
|
214
|
+
for k, v in value.items()
|
|
215
|
+
}
|
|
216
|
+
return meridian_pb.TfpParameterValue(
|
|
217
|
+
dict_value=meridian_pb.TfpParameterValue.Dict(value_map=dict_value)
|
|
218
|
+
)
|
|
219
|
+
case backend.Tensor():
|
|
220
|
+
return meridian_pb.TfpParameterValue(
|
|
221
|
+
tensor_value=backend.make_tensor_proto(value)
|
|
222
|
+
)
|
|
223
|
+
case backend.tfd.Distribution():
|
|
224
|
+
return meridian_pb.TfpParameterValue(
|
|
225
|
+
distribution_value=self._to_distribution_proto(value)
|
|
226
|
+
)
|
|
227
|
+
case backend.bijectors.Bijector():
|
|
228
|
+
return meridian_pb.TfpParameterValue(
|
|
229
|
+
bijector_value=self._to_bijector_proto(value)
|
|
230
|
+
)
|
|
231
|
+
case backend.tfd.ReparameterizationType():
|
|
232
|
+
fully_reparameterized = value == backend.tfd.FULLY_REPARAMETERIZED
|
|
233
|
+
return meridian_pb.TfpParameterValue(
|
|
234
|
+
fully_reparameterized=fully_reparameterized
|
|
235
|
+
)
|
|
236
|
+
case types.FunctionType():
|
|
237
|
+
# Check for default value
|
|
238
|
+
signature = inspect.signature(dist.__init__)
|
|
239
|
+
param = signature.parameters[param_name]
|
|
240
|
+
if param.default and param.default is value:
|
|
241
|
+
return meridian_pb.TfpParameterValue(
|
|
242
|
+
function_param=meridian_pb.TfpParameterValue.FunctionParam(
|
|
243
|
+
uses_default=True
|
|
244
|
+
)
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
# Check against registry.
|
|
248
|
+
function_key = self.function_registry.get_function_key(value)
|
|
249
|
+
if function_key is not None:
|
|
250
|
+
return meridian_pb.TfpParameterValue(
|
|
251
|
+
function_param=meridian_pb.TfpParameterValue.FunctionParam(
|
|
252
|
+
function_key=function_key
|
|
253
|
+
)
|
|
254
|
+
)
|
|
255
|
+
raise ValueError(
|
|
256
|
+
f"Custom function `{param_name}` detected for"
|
|
257
|
+
f" {type(dist).__name__}, but not found in registry. Please"
|
|
258
|
+
" add custom functions to registry when saving models."
|
|
259
|
+
)
|
|
260
|
+
|
|
261
|
+
# Handle unsupported types.
|
|
262
|
+
raise TypeError(f"Unsupported type: {type(value)}, {value}")
|
|
263
|
+
|
|
264
|
+
def _from_distribution_proto(
|
|
265
|
+
self,
|
|
266
|
+
dist_proto: meridian_pb.TfpDistribution,
|
|
267
|
+
) -> backend.tfd.Distribution:
|
|
268
|
+
"""Converts a `Distribution` proto to a TensorFlow `Distribution` object."""
|
|
269
|
+
dist_class_name = dist_proto.distribution_type
|
|
270
|
+
dist_class = getattr(backend.tfd, dist_class_name)
|
|
271
|
+
dist_parameters = dist_proto.parameters
|
|
272
|
+
input_parameters = {
|
|
273
|
+
k: self._unpack_tfp_parameters(k, v, dist_class)
|
|
274
|
+
for k, v in dist_parameters.items()
|
|
275
|
+
}
|
|
276
|
+
return dist_class(**input_parameters)
|
|
277
|
+
|
|
278
|
+
def _from_bijector_proto(
|
|
279
|
+
self,
|
|
280
|
+
dist_proto: meridian_pb.TfpBijector,
|
|
281
|
+
) -> backend.bijectors.Bijector:
|
|
282
|
+
"""Converts a `Bijector` proto to a TensorFlow `Bijector` object."""
|
|
283
|
+
dist_class_name = dist_proto.bijector_type
|
|
284
|
+
dist_class = getattr(backend.bijectors, dist_class_name)
|
|
285
|
+
dist_parameters = dist_proto.parameters
|
|
286
|
+
input_parameters = {
|
|
287
|
+
name: self._unpack_tfp_parameters(name, value, dist_class)
|
|
288
|
+
for name, value in dist_parameters.items()
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
return dist_class(**input_parameters)
|
|
292
|
+
|
|
293
|
+
def _unpack_tfp_parameters(
|
|
294
|
+
self,
|
|
295
|
+
param_name: str,
|
|
296
|
+
param_value: meridian_pb.TfpParameterValue,
|
|
297
|
+
dist_class: backend.tfd.Distribution,
|
|
298
|
+
) -> Any:
|
|
299
|
+
"""Unpacks a `TfpParameterValue` proto into a Python value."""
|
|
300
|
+
match param_value.WhichOneof("value_type"):
|
|
301
|
+
# Handle built-in types.
|
|
302
|
+
case "scalar_value":
|
|
303
|
+
return param_value.scalar_value
|
|
304
|
+
case "int_value":
|
|
305
|
+
return param_value.int_value
|
|
306
|
+
case "bool_value":
|
|
307
|
+
return param_value.bool_value
|
|
308
|
+
case "string_value":
|
|
309
|
+
return param_value.string_value
|
|
310
|
+
case "none_value":
|
|
311
|
+
return None
|
|
312
|
+
case "list_value":
|
|
313
|
+
return [
|
|
314
|
+
self._unpack_tfp_parameters(param_name, v, dist_class)
|
|
315
|
+
for v in param_value.list_value.values
|
|
316
|
+
]
|
|
317
|
+
case "dict_value":
|
|
318
|
+
items = param_value.dict_value.value_map.items()
|
|
319
|
+
return {
|
|
320
|
+
key: self._unpack_tfp_parameters(key, value, dist_class)
|
|
321
|
+
for key, value in items
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
# Handle custom types.
|
|
325
|
+
case "tensor_value":
|
|
326
|
+
return backend.to_tensor(backend.make_ndarray(param_value.tensor_value))
|
|
327
|
+
case "distribution_value":
|
|
328
|
+
return self._from_distribution_proto(param_value.distribution_value)
|
|
329
|
+
case "bijector_value":
|
|
330
|
+
return self._from_bijector_proto(param_value.bijector_value)
|
|
331
|
+
case "fully_reparameterized":
|
|
332
|
+
if param_value.fully_reparameterized:
|
|
333
|
+
return backend.tfd.FULLY_REPARAMETERIZED
|
|
334
|
+
else:
|
|
335
|
+
return backend.tfd.NOT_FULLY_REPARAMETERIZED
|
|
336
|
+
|
|
337
|
+
# Handle functions.
|
|
338
|
+
case "function_param":
|
|
339
|
+
function_param = param_value.function_param
|
|
340
|
+
# Check against registry.
|
|
341
|
+
if function_param.HasField("function_key"):
|
|
342
|
+
registry = self.function_registry
|
|
343
|
+
if function_param.function_key in registry:
|
|
344
|
+
return registry.get(function_param.function_key)
|
|
345
|
+
# Check for default value.
|
|
346
|
+
if (
|
|
347
|
+
function_param.HasField("uses_default")
|
|
348
|
+
and function_param.uses_default
|
|
349
|
+
):
|
|
350
|
+
signature = inspect.signature(dist_class.__init__)
|
|
351
|
+
return signature.parameters[param_name].default
|
|
352
|
+
raise ValueError(f"No function found for {param_name}")
|
|
353
|
+
|
|
354
|
+
# Handle unsupported types.
|
|
355
|
+
case _:
|
|
356
|
+
raise ValueError(
|
|
357
|
+
f"Unsupported TFP distribution parameter type: {type(param_value)}"
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
# copybara: strip_begin
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def _from_legacy_bijector_proto(
    bijector_proto: meridian_pb.Distribution.Bijector,
) -> backend.bijectors.Bijector:
  """Builds a backend `Bijector` from a legacy `Distribution.Bijector` proto.

  Args:
    bijector_proto: The legacy bijector message; its `bijector_type` oneof
      selects which bijector is constructed.

  Returns:
    A `Shift`, `Scale`, or `Reciprocal` bijector.

  Raises:
    ValueError: If the oneof holds an unsupported member or none at all.
  """
  bijector_kind = bijector_proto.WhichOneof(sc.BIJECTOR_TYPE)

  if bijector_kind == sc.SHIFT_BIJECTOR:
    shift_spec = bijector_proto.shift
    return backend.bijectors.Shift(
        shift=_deserialize_sequence(shift_spec.shifts)
    )

  if bijector_kind == sc.SCALE_BIJECTOR:
    scale_spec = bijector_proto.scale
    return backend.bijectors.Scale(
        scale=_deserialize_sequence(scale_spec.scales),
        log_scale=_deserialize_sequence(scale_spec.log_scales),
    )

  if bijector_kind == sc.RECIPROCAL_BIJECTOR:
    return backend.bijectors.Reciprocal()

  raise ValueError(
      f"Unsupported Bijector proto type: {bijector_kind};"
      f" Bijector proto:\n{bijector_proto}"
  )
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def _from_legacy_distribution_proto(
|
|
388
|
+
dist_proto: meridian_pb.Distribution,
|
|
389
|
+
) -> backend.tfd.Distribution:
|
|
390
|
+
"""Converts a `Distribution` proto to a `Distribution` object."""
|
|
391
|
+
dist_type_field = dist_proto.WhichOneof(sc.DISTRIBUTION_TYPE)
|
|
392
|
+
match dist_type_field:
|
|
393
|
+
case sc.BATCH_BROADCAST_DISTRIBUTION:
|
|
394
|
+
return backend.tfd.BatchBroadcast(
|
|
395
|
+
name=dist_proto.name,
|
|
396
|
+
distribution=_from_legacy_distribution_proto(
|
|
397
|
+
dist_proto.batch_broadcast.distribution
|
|
398
|
+
),
|
|
399
|
+
with_shape=_from_shape_proto(dist_proto.batch_broadcast.batch_shape),
|
|
400
|
+
)
|
|
401
|
+
case sc.TRANSFORMED_DISTRIBUTION:
|
|
402
|
+
return backend.tfd.TransformedDistribution(
|
|
403
|
+
name=dist_proto.name,
|
|
404
|
+
distribution=_from_legacy_distribution_proto(
|
|
405
|
+
dist_proto.transformed.distribution
|
|
406
|
+
),
|
|
407
|
+
bijector=_from_legacy_bijector_proto(dist_proto.transformed.bijector),
|
|
408
|
+
)
|
|
409
|
+
case sc.DETERMINISTIC_DISTRIBUTION:
|
|
410
|
+
return backend.tfd.Deterministic(
|
|
411
|
+
name=dist_proto.name,
|
|
412
|
+
loc=_deserialize_sequence(dist_proto.deterministic.locs),
|
|
413
|
+
)
|
|
414
|
+
case sc.HALF_NORMAL_DISTRIBUTION:
|
|
415
|
+
return backend.tfd.HalfNormal(
|
|
416
|
+
name=dist_proto.name,
|
|
417
|
+
scale=_deserialize_sequence(dist_proto.half_normal.scales),
|
|
418
|
+
)
|
|
419
|
+
case sc.LOG_NORMAL_DISTRIBUTION:
|
|
420
|
+
return backend.tfd.LogNormal(
|
|
421
|
+
name=dist_proto.name,
|
|
422
|
+
loc=_deserialize_sequence(dist_proto.log_normal.locs),
|
|
423
|
+
scale=_deserialize_sequence(dist_proto.log_normal.scales),
|
|
424
|
+
)
|
|
425
|
+
case sc.NORMAL_DISTRIBUTION:
|
|
426
|
+
return backend.tfd.Normal(
|
|
427
|
+
name=dist_proto.name,
|
|
428
|
+
loc=_deserialize_sequence(dist_proto.normal.locs),
|
|
429
|
+
scale=_deserialize_sequence(dist_proto.normal.scales),
|
|
430
|
+
)
|
|
431
|
+
case sc.TRUNCATED_NORMAL_DISTRIBUTION:
|
|
432
|
+
if (
|
|
433
|
+
hasattr(dist_proto.truncated_normal, "lows")
|
|
434
|
+
and dist_proto.truncated_normal.lows
|
|
435
|
+
):
|
|
436
|
+
if dist_proto.truncated_normal.low:
|
|
437
|
+
_show_warning("low", "TruncatedNormal")
|
|
438
|
+
low = _deserialize_sequence(dist_proto.truncated_normal.lows)
|
|
439
|
+
else:
|
|
440
|
+
low = dist_proto.truncated_normal.low
|
|
441
|
+
|
|
442
|
+
if (
|
|
443
|
+
hasattr(dist_proto.truncated_normal, "highs")
|
|
444
|
+
and dist_proto.truncated_normal.highs
|
|
445
|
+
):
|
|
446
|
+
if dist_proto.truncated_normal.high:
|
|
447
|
+
_show_warning("high", "TruncatedNormal")
|
|
448
|
+
high = _deserialize_sequence(dist_proto.truncated_normal.highs)
|
|
449
|
+
else:
|
|
450
|
+
high = dist_proto.truncated_normal.high
|
|
451
|
+
return backend.tfd.TruncatedNormal(
|
|
452
|
+
name=dist_proto.name,
|
|
453
|
+
loc=_deserialize_sequence(dist_proto.truncated_normal.locs),
|
|
454
|
+
scale=_deserialize_sequence(dist_proto.truncated_normal.scales),
|
|
455
|
+
low=low,
|
|
456
|
+
high=high,
|
|
457
|
+
)
|
|
458
|
+
case sc.UNIFORM_DISTRIBUTION:
|
|
459
|
+
if hasattr(dist_proto.uniform, "lows") and dist_proto.uniform.lows:
|
|
460
|
+
if dist_proto.uniform.low:
|
|
461
|
+
_show_warning("low", "Uniform")
|
|
462
|
+
low = _deserialize_sequence(dist_proto.uniform.lows)
|
|
463
|
+
else:
|
|
464
|
+
low = dist_proto.uniform.low
|
|
465
|
+
|
|
466
|
+
if hasattr(dist_proto.uniform, "highs") and dist_proto.uniform.highs:
|
|
467
|
+
if dist_proto.uniform.high:
|
|
468
|
+
_show_warning("high", "Uniform")
|
|
469
|
+
high = _deserialize_sequence(dist_proto.uniform.highs)
|
|
470
|
+
else:
|
|
471
|
+
high = dist_proto.uniform.high
|
|
472
|
+
|
|
473
|
+
return backend.tfd.Uniform(
|
|
474
|
+
name=dist_proto.name,
|
|
475
|
+
low=low,
|
|
476
|
+
high=high,
|
|
477
|
+
)
|
|
478
|
+
case sc.BETA_DISTRIBUTION:
|
|
479
|
+
return backend.tfd.Beta(
|
|
480
|
+
name=dist_proto.name,
|
|
481
|
+
concentration1=_deserialize_sequence(dist_proto.beta.alpha),
|
|
482
|
+
concentration0=_deserialize_sequence(dist_proto.beta.beta),
|
|
483
|
+
)
|
|
484
|
+
case _:
|
|
485
|
+
raise ValueError(
|
|
486
|
+
f"Unsupported Distribution proto type: {dist_type_field};"
|
|
487
|
+
f" Distribution proto:\n{dist_proto}"
|
|
488
|
+
)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def _show_warning(field_name: str, dist_name: str) -> None:
  """Emits a `DeprecationWarning` that a deprecated scalar field was ignored.

  Args:
    field_name: The deprecated singular field name (e.g. `"low"`); the message
      names the preferred repeated field as `field_name` + `"s"`.
    dist_name: Distribution name mentioned in the message.
  """
  repeated_name = f"{field_name}s"
  message = (
      f"Both `{repeated_name}` and `{field_name}` are specified in"
      f" {dist_name} distribution proto. Prioritizing `{repeated_name}` since"
      f" `{field_name}` is deprecated."
  )
  warnings.warn(message, DeprecationWarning)
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def _from_shape_proto(
    shape_proto: tensor_shape_pb2.TensorShapeProto,
) -> backend.TensorShapeInstance:
  """Builds a backend `TensorShape` from the `size` of each `TensorShapeProto` dim."""
  dim_sizes = [dimension.size for dimension in shape_proto.dim]
  return backend.TensorShape(dim_sizes)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
T = TypeVar("T")
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _deserialize_sequence(args: Sequence[T]) -> T | Sequence[T] | None:
|
|
511
|
+
if not args:
|
|
512
|
+
return None
|
|
513
|
+
return args[0] if len(args) == 1 else list(args)
|
|
514
|
+
|
|
515
|
+
# copybara: strip_end
|