google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
schema/serde/meridian_serde.py ADDED
@@ -0,0 +1,413 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Serialization and deserialization of Meridian models into/from proto format.
16
+
17
+ The `meridian_serde.MeridianSerde` class provides an interface for serializing
18
+ and deserializing Meridian models into and from an `MmmKernel` proto message.
19
+
20
+ The Meridian model--when serialized into an `MmmKernel` proto--is internally
21
+ represented as the sum of the following components:
22
+
23
+ 1. Marketing data: This includes the KPI, media, and control data present in
24
+ the input data. They are structured into an MMM-agnostic `MarketingData`
25
+ proto message.
26
+ 2. Meridian model: A `MeridianModel` proto message encapsulates
27
+ Meridian-specific model parameters, including hyperparameters, prior
28
+ distributions, and sampled inference data.
29
+
30
+ Sample usage:
31
+
32
+ ```python
33
+ from schema.serde import meridian_serde
34
+
35
+ serde = meridian_serde.MeridianSerde()
36
+ mmm = model.Meridian(...)
37
+ serialized_mmm = serde.serialize(mmm) # An `MmmKernel` proto
38
+ deserialized_mmm = serde.deserialize(serialized_mmm) # A `Meridian` object
39
+ ```
40
+ """
41
+
42
+ import dataclasses
43
+ import os
44
+ import warnings
45
+
46
+ from google.protobuf import text_format
47
+ import meridian
48
+ from meridian import backend
49
+ from meridian.analysis import analyzer
50
+ from meridian.analysis import visualizer
51
+ from meridian.model import model
52
+ from mmm.v1.model import mmm_kernel_pb2 as kernel_pb
53
+ from mmm.v1.model.meridian import meridian_model_pb2 as meridian_pb
54
+ from schema.serde import distribution
55
+ from schema.serde import eda_spec as eda_spec_serde
56
+ from schema.serde import function_registry as function_registry_utils
57
+ from schema.serde import hyperparameters
58
+ from schema.serde import inference_data
59
+ from schema.serde import marketing_data
60
+ from schema.serde import serde
61
+ import semver
62
+
63
+ from google.protobuf import any_pb2
64
+
65
+
66
# Default Meridian version written into `MeridianModel.model_version` when a
# model is serialized without an explicit `meridian_version`.
_VERSION_INFO = semver.VersionInfo.parse(meridian.__version__)

# Short alias for the registry type used throughout this module's signatures.
FunctionRegistry = function_registry_utils.FunctionRegistry

# Module-level handles on the filesystem helpers used by `save_meridian` and
# `load_meridian`. NOTE(review): presumably kept as indirections so tests can
# patch them — confirm before inlining.
_file_exists, _make_dirs, _file_open = os.path.exists, os.makedirs, open
+
74
+
75
class MeridianSerde(serde.Serde[kernel_pb.MmmKernel, model.Meridian]):
  """Serializes and deserializes a Meridian model into an `MmmKernel` proto."""

  @staticmethod
  def _registry_or_default(
      registry: FunctionRegistry | None,
  ) -> FunctionRegistry:
    """Returns `registry`, or a default-constructed registry if it is None."""
    if registry is not None:
      return registry
    return function_registry_utils.FunctionRegistry()

  def serialize(
      self,
      obj: model.Meridian,
      model_id: str = '',
      meridian_version: semver.VersionInfo = _VERSION_INFO,
      include_convergence_info: bool = False,
      distribution_function_registry: FunctionRegistry | None = None,
      eda_function_registry: FunctionRegistry | None = None,
  ) -> kernel_pb.MmmKernel:
    """Serializes the given Meridian model into an `MmmKernel` proto.

    Args:
      obj: The Meridian model to serialize.
      model_id: The ID of the model.
      meridian_version: The Meridian version recorded in the serialized
        `MeridianModel.model_version` field. Defaults to the installed
        `meridian.__version__`.
      include_convergence_info: Whether to include convergence information
        (R-hat diagnostics and, when available, the MCMC sampling trace).
      distribution_function_registry: Optional. A lookup table that maps string
        keys to custom functions to be used as parameters in various
        `tfp.distributions`.
      eda_function_registry: Optional. A lookup table that maps string keys to
        custom functions to be used in `EDASpec`.

    Returns:
      An `MmmKernel` proto whose `marketing_data` holds the serialized input
      data and whose `model` field packs a `MeridianModel` proto.
    """
    meridian_model_proto = self._make_meridian_model_proto(
        mmm=obj,
        model_id=model_id,
        meridian_version=meridian_version,
        distribution_function_registry=self._registry_or_default(
            distribution_function_registry
        ),
        eda_function_registry=self._registry_or_default(eda_function_registry),
        include_convergence_info=include_convergence_info,
    )
    # `MmmKernel.model` is an `Any`, so the Meridian-specific message has to be
    # packed into it; `deserialize` unpacks it again.
    any_model = any_pb2.Any()
    any_model.Pack(meridian_model_proto)
    return kernel_pb.MmmKernel(
        marketing_data=marketing_data.MarketingDataSerde().serialize(
            obj.input_data
        ),
        model=any_model,
    )

  def _make_meridian_model_proto(
      self,
      mmm: model.Meridian,
      model_id: str,
      meridian_version: semver.VersionInfo,
      distribution_function_registry: FunctionRegistry,
      eda_function_registry: FunctionRegistry,
      include_convergence_info: bool = False,
  ) -> meridian_pb.MeridianModel:
    """Constructs a `MeridianModel` proto from the given Meridian model.

    Args:
      mmm: Meridian model.
      model_id: The ID of the model.
      meridian_version: The version of the Meridian model.
      distribution_function_registry: A lookup table that maps string keys to
        custom functions to be used as parameters in various
        `tfp.distributions`.
      eda_function_registry: A lookup table containing custom functions used by
        `EDASpec` objects.
      include_convergence_info: Whether to include convergence information.

    Returns:
      A `MeridianModel` proto.
    """
    model_proto = meridian_pb.MeridianModel(
        model_id=model_id,
        model_version=str(meridian_version),
        hyperparameters=hyperparameters.HyperparametersSerde().serialize(
            mmm.model_spec
        ),
        prior_tfp_distributions=distribution.DistributionSerde(
            distribution_function_registry
        ).serialize(mmm.model_spec.prior),
        inference_data=inference_data.InferenceDataSerde().serialize(
            mmm.inference_data
        ),
    )
    # For backwards compatibility, only serialize EDA spec if it exists.
    if hasattr(mmm, 'eda_spec'):
      model_proto.eda_spec.CopyFrom(
          eda_spec_serde.EDASpecSerde(eda_function_registry).serialize(
              mmm.eda_spec
          )
      )

    if include_convergence_info:
      convergence_proto = self._make_model_convergence_proto(mmm)
      if convergence_proto is not None:
        model_proto.convergence_info.CopyFrom(convergence_proto)

    return model_proto

  def _make_model_convergence_proto(
      self, mmm: model.Meridian
  ) -> meridian_pb.ModelConvergence | None:
    """Creates a `ModelConvergence` proto.

    Args:
      mmm: Meridian model.

    Returns:
      A `ModelConvergence` proto holding R-hat diagnostics, a convergence flag
      and, when `mmm.inference_data` has a `trace` attribute, the MCMC sampling
      trace. Returns None when `NotFittedModelError` is raised.
    """
    model_convergence_proto = meridian_pb.ModelConvergence()
    try:
      # NotFittedModelError can be raised below. If raised, return None.
      # Otherwise, set convergence status based on MCMCSamplingError (caught in
      # the except block).
      rhats = analyzer.Analyzer(mmm).get_rhat()
      rhat_proto = meridian_pb.RHatDiagnostic()
      for name, tensor in rhats.items():
        rhat_proto.parameter_r_hats.add(
            name=name, tensor=backend.make_tensor_proto(tensor)
        )
      model_convergence_proto.r_hat_diagnostic.CopyFrom(rhat_proto)

      # The returned chart is discarded; the call is kept for its check.
      # NOTE(review): presumably this raises `MCMCSamplingError` when sampling
      # failed to converge — confirm against `visualizer.ModelDiagnostics`.
      visualizer.ModelDiagnostics(mmm).plot_rhat_boxplot()
      model_convergence_proto.convergence = True
    except model.MCMCSamplingError:
      model_convergence_proto.convergence = False
    except model.NotFittedModelError:
      return None

    if hasattr(mmm.inference_data, 'trace'):
      trace = mmm.inference_data.trace
      mcmc_sampling_trace = meridian_pb.McmcSamplingTrace(
          num_chains=len(trace.chain),
          num_draws=len(trace.draw),
          step_size=backend.make_tensor_proto(trace.step_size),
          tune=backend.make_tensor_proto(trace.tune),
          target_log_prob=backend.make_tensor_proto(trace.target_log_prob),
          diverging=backend.make_tensor_proto(trace.diverging),
          accept_ratio=backend.make_tensor_proto(trace.accept_ratio),
          n_steps=backend.make_tensor_proto(trace.n_steps),
          is_accepted=backend.make_tensor_proto(trace.is_accepted),
      )
      model_convergence_proto.mcmc_sampling_trace.CopyFrom(mcmc_sampling_trace)

    return model_convergence_proto

  def deserialize(
      self,
      serialized: kernel_pb.MmmKernel,
      serialized_version: str = '',
      distribution_function_registry: FunctionRegistry | None = None,
      eda_function_registry: FunctionRegistry | None = None,
      force_deserialization=False,
  ) -> model.Meridian:
    """Deserializes the given `MmmKernel` proto into a Meridian model.

    Args:
      serialized: The serialized object in the form of an `MmmKernel` proto.
      serialized_version: Fallback version of the serialized model. It is used
        only when the proto's `MeridianModel.model_version` field is empty. The
        version recorded in the proto (written by `serialize`) takes
        precedence.
      distribution_function_registry: Optional. A lookup table that maps string
        keys to custom functions to be used as parameters in various
        `tfp.distributions`.
      eda_function_registry: A lookup table containing custom functions used by
        `EDASpec` objects.
      force_deserialization: If True, bypasses the safety check that validates
        whether functions within a function registry have changed after
        serialization. Use with caution. This should only be used if you have
        intentionally modified a custom function and are confident that the
        changes will not affect the deserialized model. A safer alternative is
        to first deserialize the model with the original functions and then
        serialize it with the new ones.

    Returns:
      A Meridian model object.

    Raises:
      ValueError: If `serialized.model` is not a `MeridianModel`, if no model
        version can be determined, or if the model contains no priors.
    """
    if not serialized.model.Is(meridian_pb.MeridianModel.DESCRIPTOR):
      raise ValueError('`serialized.model` is not a `MeridianModel`.')
    ser_meridian = meridian_pb.MeridianModel()
    serialized.model.Unpack(ser_meridian)

    # Previously `serialized_version` was silently overwritten. It now acts as
    # a fallback so that protos without a recorded version can still be read.
    raw_version = ser_meridian.model_version or serialized_version
    if not raw_version:
      raise ValueError(
          'Cannot determine the version of the serialized model: '
          '`MeridianModel.model_version` is empty and no `serialized_version` '
          'was provided.'
      )
    version = str(semver.VersionInfo.parse(raw_version))

    deserialized_hyperparameters = (
        hyperparameters.HyperparametersSerde().deserialize(
            ser_meridian.hyperparameters, version
        )
    )

    # NOTE(review): `prior_distributions` is presumably a legacy field kept for
    # older serialized models; `serialize` writes `prior_tfp_distributions`.
    if ser_meridian.HasField('prior_distributions'):
      ser_meridian_priors = ser_meridian.prior_distributions
    elif ser_meridian.HasField('prior_tfp_distributions'):
      ser_meridian_priors = ser_meridian.prior_tfp_distributions
    else:
      raise ValueError('MeridianModel does not contain any priors.')

    deserialized_prior_distributions = distribution.DistributionSerde(
        self._registry_or_default(distribution_function_registry)
    ).deserialize(
        ser_meridian_priors,
        version,
        force_deserialization=force_deserialization,
    )
    deserialized_marketing_data = (
        marketing_data.MarketingDataSerde().deserialize(
            serialized.marketing_data, version
        )
    )
    deserialized_inference_data = (
        inference_data.InferenceDataSerde().deserialize(
            ser_meridian.inference_data, version
        )
    )

    deserialized_model_spec = dataclasses.replace(
        deserialized_hyperparameters, prior=deserialized_prior_distributions
    )

    meridian_kwargs = dict(
        input_data=deserialized_marketing_data,
        model_spec=deserialized_model_spec,
        inference_data=deserialized_inference_data,
    )

    # For backwards compatibility, only deserialize EDA spec if it exists in the
    # serialized model. Otherwise, warn the user and create a model with default
    # EDA spec.
    if ser_meridian.HasField('eda_spec'):
      meridian_kwargs['eda_spec'] = eda_spec_serde.EDASpecSerde(
          self._registry_or_default(eda_function_registry)
      ).deserialize(
          ser_meridian.eda_spec,
          version,
          force_deserialization=force_deserialization,
      )
    else:
      warnings.warn('MeridianModel does not contain an EDA spec.')

    return model.Meridian(**meridian_kwargs)
326
+
327
+
328
def save_meridian(
    mmm: model.Meridian,
    file_path: str,
    distribution_function_registry: FunctionRegistry | None = None,
    eda_function_registry: FunctionRegistry | None = None,
):
  """Save the model object as an `MmmKernel` proto in the given filepath.

  Supported file types:
  - `binpb` (wire-format proto)
  - `txtpb` (text-format proto)
  - `textproto` (text-format proto)

  The file type is validated and the model is serialized before the file is
  opened. An unsupported extension or a serialization error therefore never
  leaves an empty or truncated file behind. Missing parent directories are
  created.

  Args:
    mmm: Model object to save.
    file_path: File path to save a serialized model object. If the file name
      ends with `.binpb`, it will be saved in the wire-format. If the filename
      ends with `.txtpb` or `.textproto`, it will be saved in the text-format
      (UTF-8 encoded). A bare file name saves into the current directory.
    distribution_function_registry: Optional. A lookup table that maps string
      keys to custom functions to be used as parameters in various
      `tfp.distributions`.
    eda_function_registry: A lookup table that maps string keys to custom
      functions to be used in `EDASpec`.

  Raises:
    ValueError: If the file extension is not one of the supported types.
  """
  is_binary = file_path.endswith('.binpb')
  if not is_binary and not file_path.endswith(('.textproto', '.txtpb')):
    raise ValueError(f'Unsupported file type: {file_path}')

  # Creates an MmmKernel.
  serialized_kernel = MeridianSerde().serialize(
      mmm,
      distribution_function_registry=distribution_function_registry,
      eda_function_registry=eda_function_registry,
  )
  if is_binary:
    payload = serialized_kernel.SerializeToString()
  else:
    # `MessageToString` returns `str`; encode it because the file is opened in
    # binary mode (writing a `str` there would raise TypeError).
    payload = text_format.MessageToString(serialized_kernel).encode('utf-8')

  # `os.path.dirname` is '' for a bare file name, and `os.makedirs('')` raises,
  # so only create directories when there is a directory component.
  dir_name = os.path.dirname(file_path)
  if dir_name and not _file_exists(dir_name):
    _make_dirs(dir_name)

  with _file_open(file_path, 'wb') as f:
    f.write(payload)
368
+
369
+
370
def load_meridian(
    file_path: str,
    distribution_function_registry: FunctionRegistry | None = None,
    eda_function_registry: FunctionRegistry | None = None,
    force_deserialization=False,
) -> model.Meridian:
  """Reads an `MmmKernel` proto file and rebuilds the Meridian model from it.

  Supported file types:
  - `binpb` (wire-format proto)
  - `txtpb` (text-format proto)
  - `textproto` (text-format proto)

  Args:
    file_path: File path to load a serialized model object from. The format is
      chosen from the file extension.
    distribution_function_registry: A lookup table that maps string keys to
      custom functions to be used as parameters in various `tfp.distributions`.
    eda_function_registry: A lookup table that maps string keys to custom
      functions to be used in `EDASpec`.
    force_deserialization: If True, bypasses the safety check that validates
      whether functions within a function registry have changed after
      serialization. Use with caution. This should only be used if you have
      intentionally modified a custom function and are confident that the
      changes will not affect the deserialized model. A safer alternative is to
      first deserialize the model with the original functions and then serialize
      it with the new ones.

  Returns:
    Model object loaded from the file path.

  Raises:
    ValueError: If the file extension is not one of the supported types.
  """
  wire_format = file_path.endswith('.binpb')
  text_format_file = file_path.endswith(('.textproto', '.txtpb'))

  with _file_open(file_path, 'rb') as f:
    if not (wire_format or text_format_file):
      raise ValueError(f'Unsupported file type: {file_path}')
    raw_bytes = f.read()

  if wire_format:
    kernel = kernel_pb.MmmKernel.FromString(raw_bytes)
  else:
    # `text_format.Parse` fills the given message in place and returns it.
    kernel = text_format.Parse(raw_bytes, kernel_pb.MmmKernel())

  return MeridianSerde().deserialize(
      kernel,
      distribution_function_registry=distribution_function_registry,
      eda_function_registry=eda_function_registry,
      force_deserialization=force_deserialization,
  )
schema/serde/serde.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Serialization and deserialization of Meridian models."""
16
+
17
+ import abc
18
+ from typing import Generic, TypeVar
19
+
20
+
21
# Type of the serialized (wire) representation.
WireFormat = TypeVar("WireFormat")
# Type of the in-memory Python object being (de)serialized.
PythonType = TypeVar("PythonType")


class Serde(Generic[WireFormat, PythonType], abc.ABC):
  """Base interface for converting `PythonType` objects to/from `WireFormat`.

  Subclasses override `serialize` and/or `deserialize`. The base implementations
  only raise `NotImplementedError`. Neither method is declared abstract, so the
  base class itself remains instantiable.
  """

  def serialize(self, obj: PythonType, **kwargs) -> WireFormat:
    """Converts `obj` into its wire-format representation.

    Args:
      obj: The object to serialize.
      **kwargs: Subclass-specific serialization options.

    Returns:
      The serialized representation of `obj`.

    Raises:
      NotImplementedError: Unless overridden by a subclass.
    """
    raise NotImplementedError

  def deserialize(
      self, serialized: WireFormat, serialized_version: str = "", **kwargs
  ) -> PythonType:
    """Reconstructs a Python object from its wire-format representation.

    Args:
      serialized: The serialized object.
      serialized_version: Version of the serialized object, letting subclasses
        adapt to deserialization-logic changes across versions.
      **kwargs: Subclass-specific deserialization options.

    Returns:
      The deserialized object.

    Raises:
      NotImplementedError: Unless overridden by a subclass.
    """
    raise NotImplementedError