google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Serialization and deserialization of Meridian models into/from proto format.
|
|
16
|
+
|
|
17
|
+
The `meridian_serde.MeridianSerde` class provides an interface for serializing
|
|
18
|
+
and deserializing Meridian models into and from an `MmmKernel` proto message.
|
|
19
|
+
|
|
20
|
+
The Meridian model--when serialized into an `MmmKernel` proto--is internally
|
|
21
|
+
represented as the sum of the following components:
|
|
22
|
+
|
|
23
|
+
1. Marketing data: This includes the KPI, media, and control data present in
|
|
24
|
+
the input data. They are structured into an MMM-agnostic `MarketingData`
|
|
25
|
+
proto message.
|
|
26
|
+
2. Meridian model: A `MeridianModel` proto message encapsulates
|
|
27
|
+
Meridian-specific model parameters, including hyperparameters, prior
|
|
28
|
+
distributions, and sampled inference data.
|
|
29
|
+
|
|
30
|
+
Sample usage:
|
|
31
|
+
|
|
32
|
+
```python
|
|
33
|
+
from schema.serde import meridian_serde
|
|
34
|
+
|
|
35
|
+
serde = meridian_serde.MeridianSerde()
|
|
36
|
+
mmm = model.Meridian(...)
|
|
37
|
+
serialized_mmm = serde.serialize(mmm) # An `MmmKernel` proto
|
|
38
|
+
deserialized_mmm = serde.deserialize(serialized_mmm) # A `Meridian` object
|
|
39
|
+
```
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
import dataclasses
|
|
43
|
+
import os
|
|
44
|
+
import warnings
|
|
45
|
+
|
|
46
|
+
from google.protobuf import text_format
|
|
47
|
+
import meridian
|
|
48
|
+
from meridian import backend
|
|
49
|
+
from meridian.analysis import analyzer
|
|
50
|
+
from meridian.analysis import visualizer
|
|
51
|
+
from meridian.model import model
|
|
52
|
+
from mmm.v1.model import mmm_kernel_pb2 as kernel_pb
|
|
53
|
+
from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
|
|
54
|
+
from schema.serde import distribution
|
|
55
|
+
from schema.serde import eda_spec as eda_spec_serde
|
|
56
|
+
from schema.serde import function_registry as function_registry_utils
|
|
57
|
+
from schema.serde import hyperparameters
|
|
58
|
+
from schema.serde import inference_data
|
|
59
|
+
from schema.serde import marketing_data
|
|
60
|
+
from schema.serde import serde
|
|
61
|
+
import semver
|
|
62
|
+
|
|
63
|
+
from google.protobuf import any_pb2
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Version of the installed Meridian package, stamped into serialized models by
# default.
_VERSION_INFO = semver.VersionInfo.parse(meridian.__version__)

# Public alias so callers can build registries without importing
# `function_registry_utils` directly.
FunctionRegistry = function_registry_utils.FunctionRegistry

# Module-level handles on the filesystem primitives used by the save/load
# helpers in this module; presumably kept as indirections so they can be
# swapped out (e.g. patched in tests) — TODO confirm.
_file_exists, _make_dirs, _file_open = os.path.exists, os.makedirs, open
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]):
  """Serializes and deserializes a Meridian model into an `MmmKernel` proto."""

  @staticmethod
  def _registry_or_default(
      registry: FunctionRegistry | None,
  ) -> FunctionRegistry:
    """Returns `registry`, or a new empty `FunctionRegistry` if it is None."""
    if registry is not None:
      return registry
    return function_registry_utils.FunctionRegistry()

  def serialize(
      self,
      obj: model.Meridian,
      model_id: str = '',
      meridian_version: semver.VersionInfo = _VERSION_INFO,
      include_convergence_info: bool = False,
      distribution_function_registry: FunctionRegistry | None = None,
      eda_function_registry: FunctionRegistry | None = None,
  ) -> kernel_pb.MmmKernel:
    """Serializes the given Meridian model into an `MmmKernel` proto.

    Args:
      obj: The Meridian model to serialize.
      model_id: The ID of the model.
      meridian_version: The version of the Meridian model.
      include_convergence_info: Whether to include convergence information.
      distribution_function_registry: Optional. A lookup table that maps string
        keys to custom functions to be used as parameters in various
        `tfp.distributions`. Defaults to an empty registry.
      eda_function_registry: Optional. A lookup table that maps string keys to
        custom functions to be used in `EDASpec`. Defaults to an empty
        registry.

    Returns:
      An `MmmKernel` proto representing the serialized model.
    """
    meridian_model_proto = self._make_meridian_model_proto(
        mmm=obj,
        model_id=model_id,
        meridian_version=meridian_version,
        distribution_function_registry=self._registry_or_default(
            distribution_function_registry
        ),
        eda_function_registry=self._registry_or_default(eda_function_registry),
        include_convergence_info=include_convergence_info,
    )
    # The Meridian-specific model is stored in a generic `Any` field so that
    # `MmmKernel` stays MMM-agnostic.
    any_model = any_pb2.Any()
    any_model.Pack(meridian_model_proto)
    return kernel_pb.MmmKernel(
        marketing_data=marketing_data.MarketingDataSerde().serialize(
            obj.input_data
        ),
        model=any_model,
    )

  def _make_meridian_model_proto(
      self,
      mmm: model.Meridian,
      model_id: str,
      meridian_version: semver.VersionInfo,
      distribution_function_registry: FunctionRegistry,
      eda_function_registry: FunctionRegistry,
      include_convergence_info: bool = False,
  ) -> meridian_pb.MeridianModel:
    """Constructs a `MeridianModel` proto from the given Meridian model.

    Args:
      mmm: Meridian model.
      model_id: The ID of the model.
      meridian_version: The version of the Meridian model.
      distribution_function_registry: A lookup table that maps string keys to
        custom functions to be used as parameters in various
        `tfp.distributions`.
      eda_function_registry: A lookup table containing custom functions used by
        `EDASpec` objects.
      include_convergence_info: Whether to include convergence information.

    Returns:
      A `MeridianModel` proto.
    """
    model_proto = meridian_pb.MeridianModel(
        model_id=model_id,
        model_version=str(meridian_version),
        hyperparameters=hyperparameters.HyperparametersSerde().serialize(
            mmm.model_spec
        ),
        prior_tfp_distributions=distribution.DistributionSerde(
            distribution_function_registry
        ).serialize(mmm.model_spec.prior),
        inference_data=inference_data.InferenceDataSerde().serialize(
            mmm.inference_data
        ),
    )
    # For backwards compatibility, only serialize EDA spec if it exists.
    if hasattr(mmm, 'eda_spec'):
      model_proto.eda_spec.CopyFrom(
          eda_spec_serde.EDASpecSerde(eda_function_registry).serialize(
              mmm.eda_spec
          )
      )

    if include_convergence_info:
      convergence_proto = self._make_model_convergence_proto(mmm)
      # None means the model has not been fitted; leave the field unset.
      if convergence_proto is not None:
        model_proto.convergence_info.CopyFrom(convergence_proto)

    return model_proto

  def _make_model_convergence_proto(
      self, mmm: model.Meridian
  ) -> meridian_pb.ModelConvergence | None:
    """Creates a `ModelConvergence` proto for the given model.

    Args:
      mmm: Meridian model.

    Returns:
      A `ModelConvergence` proto, or None if the model raises
      `NotFittedModelError` while computing diagnostics.
    """
    model_convergence_proto = meridian_pb.ModelConvergence()
    try:
      # NotFittedModelError can be raised below. If raised, return None.
      # Otherwise, set convergence status based on MCMCSamplingError (caught
      # in the except block).
      rhats = analyzer.Analyzer(mmm).get_rhat()
      rhat_proto = meridian_pb.RHatDiagnostic()
      for name, tensor in rhats.items():
        rhat_proto.parameter_r_hats.add(
            name=name, tensor=backend.make_tensor_proto(tensor)
        )
      model_convergence_proto.r_hat_diagnostic.CopyFrom(rhat_proto)

      # The plot itself is discarded; presumably this call is made only
      # because it raises MCMCSamplingError for non-converged chains — TODO
      # confirm.
      visualizer.ModelDiagnostics(mmm).plot_rhat_boxplot()
      model_convergence_proto.convergence = True
    except model.MCMCSamplingError:
      model_convergence_proto.convergence = False
    except model.NotFittedModelError:
      return None

    # Sampler trace statistics are only recorded when the inference data
    # exposes a `trace` attribute.
    if hasattr(mmm.inference_data, 'trace'):
      trace = mmm.inference_data.trace
      mcmc_sampling_trace = meridian_pb.McmcSamplingTrace(
          num_chains=len(trace.chain),
          num_draws=len(trace.draw),
          step_size=backend.make_tensor_proto(trace.step_size),
          tune=backend.make_tensor_proto(trace.tune),
          target_log_prob=backend.make_tensor_proto(trace.target_log_prob),
          diverging=backend.make_tensor_proto(trace.diverging),
          accept_ratio=backend.make_tensor_proto(trace.accept_ratio),
          n_steps=backend.make_tensor_proto(trace.n_steps),
          is_accepted=backend.make_tensor_proto(trace.is_accepted),
      )
      model_convergence_proto.mcmc_sampling_trace.CopyFrom(mcmc_sampling_trace)

    return model_convergence_proto

  def deserialize(
      self,
      serialized: kernel_pb.MmmKernel,
      serialized_version: str = '',
      distribution_function_registry: FunctionRegistry | None = None,
      eda_function_registry: FunctionRegistry | None = None,
      force_deserialization: bool = False,
  ) -> model.Meridian:
    """Deserializes the given `MmmKernel` proto into a Meridian model.

    Args:
      serialized: The serialized object in the form of an `MmmKernel` proto.
      serialized_version: Fallback version of the serialized model, used only
        when the embedded `MeridianModel` proto does not record a
        `model_version`. When the proto records a version, that version takes
        precedence. The version is used to handle changes in deserialization
        logic across different versions.
      distribution_function_registry: Optional. A lookup table that maps string
        keys to custom functions to be used as parameters in various
        `tfp.distributions`. Defaults to an empty registry.
      eda_function_registry: Optional. A lookup table containing custom
        functions used by `EDASpec` objects. Defaults to an empty registry.
      force_deserialization: If True, bypasses the safety check that validates
        whether functions within a function registry have changed after
        serialization. Use with caution. This should only be used if you have
        intentionally modified a custom function and are confident that the
        changes will not affect the deserialized model. A safer alternative is
        to first deserialize the model with the original functions and then
        serialize it with the new ones.

    Returns:
      A Meridian model object.

    Raises:
      ValueError: If `serialized.model` is not a `MeridianModel`, if no model
        version is available, or if the model contains no priors.
    """
    if not serialized.model.Is(meridian_pb.MeridianModel.DESCRIPTOR):
      raise ValueError('`serialized.model` is not a `MeridianModel`.')
    ser_meridian = meridian_pb.MeridianModel()
    serialized.model.Unpack(ser_meridian)

    # The version stored in the proto is authoritative; the caller-supplied
    # version is only a fallback for protos that do not record one.
    version_source = ser_meridian.model_version or serialized_version
    if not version_source:
      raise ValueError(
          'Cannot determine the version of the serialized model: the'
          ' `MeridianModel` proto has no `model_version` and no'
          ' `serialized_version` was provided.'
      )
    version = str(semver.VersionInfo.parse(str(version_source)))

    deserialized_hyperparameters = (
        hyperparameters.HyperparametersSerde().deserialize(
            ser_meridian.hyperparameters, version
        )
    )

    if ser_meridian.HasField('prior_distributions'):
      ser_meridian_priors = ser_meridian.prior_distributions
    elif ser_meridian.HasField('prior_tfp_distributions'):
      ser_meridian_priors = ser_meridian.prior_tfp_distributions
    else:
      raise ValueError('MeridianModel does not contain any priors.')

    deserialized_prior_distributions = distribution.DistributionSerde(
        self._registry_or_default(distribution_function_registry)
    ).deserialize(
        ser_meridian_priors,
        version,
        force_deserialization=force_deserialization,
    )
    deserialized_marketing_data = (
        marketing_data.MarketingDataSerde().deserialize(
            serialized.marketing_data, version
        )
    )
    deserialized_inference_data = (
        inference_data.InferenceDataSerde().deserialize(
            ser_meridian.inference_data, version
        )
    )

    deserialized_model_spec = dataclasses.replace(
        deserialized_hyperparameters, prior=deserialized_prior_distributions
    )

    meridian_kwargs = dict(
        input_data=deserialized_marketing_data,
        model_spec=deserialized_model_spec,
        inference_data=deserialized_inference_data,
    )

    # For backwards compatibility, only deserialize EDA spec if it exists in the
    # serialized model. Otherwise, warn the user and create a model with default
    # EDA spec.
    if ser_meridian.HasField('eda_spec'):
      meridian_kwargs['eda_spec'] = eda_spec_serde.EDASpecSerde(
          self._registry_or_default(eda_function_registry)
      ).deserialize(
          ser_meridian.eda_spec,
          version,
          force_deserialization=force_deserialization,
      )
    else:
      warnings.warn('MeridianModel does not contain an EDA spec.')

    return model.Meridian(**meridian_kwargs)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def save_meridian(
    mmm: model.Meridian,
    file_path: str,
    distribution_function_registry: FunctionRegistry | None = None,
    eda_function_registry: FunctionRegistry | None = None,
):
  """Save the model object as an `MmmKernel` proto in the given filepath.

  Supported file types:
    - `binpb` (wire-format proto)
    - `txtpb` (text-format proto)
    - `textproto` (text-format proto)

  The extension is validated and the model is serialized before anything is
  written. An unsupported extension or a serialization error therefore never
  creates directories or truncates an existing file. Missing parent
  directories are created.

  Args:
    mmm: Model object to save.
    file_path: File path to save a serialized model object. If the file name
      ends with `.binpb`, it will be saved in the wire-format. If the filename
      ends with `.txtpb` or `.textproto`, it will be saved in the text-format.
      A bare file name (no directory component) is written relative to the
      current working directory.
    distribution_function_registry: Optional. A lookup table that maps string
      keys to custom functions to be used as parameters in various
      `tfp.distributions`.
    eda_function_registry: Optional. A lookup table that maps string keys to
      custom functions to be used in `EDASpec`.

  Raises:
    ValueError: If the file extension is not supported.
  """
  is_binary = file_path.endswith('.binpb')
  is_text = file_path.endswith(('.textproto', '.txtpb'))
  if not (is_binary or is_text):
    raise ValueError(f'Unsupported file type: {file_path}')

  # Creates an MmmKernel.
  serialized_kernel = MeridianSerde().serialize(
      mmm,
      distribution_function_registry=distribution_function_registry,
      eda_function_registry=eda_function_registry,
  )
  if is_binary:
    payload = serialized_kernel.SerializeToString()
  else:
    # `MessageToString` returns `str`; encode it because the file is opened
    # in binary mode.
    payload = text_format.MessageToString(serialized_kernel).encode('utf-8')

  # `os.path.dirname` is '' for a bare file name; creating '' would fail.
  dir_name = os.path.dirname(file_path)
  if dir_name and not _file_exists(dir_name):
    _make_dirs(dir_name)

  with _file_open(file_path, 'wb') as f:
    f.write(payload)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def load_meridian(
    file_path: str,
    distribution_function_registry: FunctionRegistry | None = None,
    eda_function_registry: FunctionRegistry | None = None,
    force_deserialization=False,
) -> model.Meridian:
  """Load the model object from an `MmmKernel` proto file path.

  Supported file types:
    - `binpb` (wire-format proto)
    - `txtpb` (text-format proto)
    - `textproto` (text-format proto)

  Args:
    file_path: File path to load a serialized model object from.
    distribution_function_registry: A lookup table that maps string keys to
      custom functions to be used as parameters in various `tfp.distributions`.
    eda_function_registry: A lookup table that maps string keys to custom
      functions to be used in `EDASpec`.
    force_deserialization: If True, bypasses the safety check that validates
      whether functions within a function registry have changed after
      serialization. Use with caution. This should only be used if you have
      intentionally modified a custom function and are confident that the
      changes will not affect the deserialized model. A safer alternative is to
      first deserialize the model with the original functions and then serialize
      it with the new ones.

  Returns:
    Model object loaded from the file path.

  Raises:
    ValueError: If the file extension is not supported. This is checked before
      the file is opened.
  """
  if file_path.endswith('.binpb'):
    with _file_open(file_path, 'rb') as f:
      serialized_model = kernel_pb.MmmKernel.FromString(f.read())
  elif file_path.endswith(('.textproto', '.txtpb')):
    serialized_model = kernel_pb.MmmKernel()
    with _file_open(file_path, 'rb') as f:
      # `text_format.Parse` accepts the raw bytes read here.
      text_format.Parse(f.read(), serialized_model)
  else:
    raise ValueError(f'Unsupported file type: {file_path}')

  # The file is already closed; deserialization works on the in-memory proto.
  return MeridianSerde().deserialize(
      serialized_model,
      distribution_function_registry=distribution_function_registry,
      eda_function_registry=eda_function_registry,
      force_deserialization=force_deserialization,
  )
|
schema/serde/serde.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Serialization and deserialization of Meridian models."""
|
|
16
|
+
|
|
17
|
+
import abc
|
|
18
|
+
from typing import Generic, TypeVar
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Type of the serialized representation (e.g. a proto message).
WireFormat = TypeVar("WireFormat")
# Type of the in-memory Python object being (de)serialized.
PythonType = TypeVar("PythonType")


class Serde(Generic[WireFormat, PythonType], abc.ABC):
  """Base class for converting a Python object to and from a wire format.

  Concrete subclasses override `serialize` and/or `deserialize`. The base
  implementations below simply raise `NotImplementedError`.
  """

  def serialize(self, obj: PythonType, **kwargs) -> WireFormat:
    """Converts `obj` into its wire-format representation.

    Args:
      obj: The Python object to convert.
      **kwargs: Subclass-specific serialization options.

    Returns:
      The wire-format representation of `obj`.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    del obj, kwargs  # Unused by the base implementation.
    raise NotImplementedError()

  def deserialize(
      self, serialized: WireFormat, serialized_version: str = "", **kwargs
  ) -> PythonType:
    """Reconstructs a Python object from its wire-format representation.

    Args:
      serialized: The wire-format value to convert back.
      serialized_version: Version string of the serialized value. Subclasses
        use it to adapt to deserialization-logic changes across versions.
      **kwargs: Subclass-specific deserialization options.

    Returns:
      The reconstructed Python object.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    del serialized, serialized_version, kwargs  # Unused by the base class.
    raise NotImplementedError()
|